Skip to content

Commit d64464c

Browse files
authored
Add support for OAuth connectors that require user input (#3571)
* Add support for OAuth connectors that require user input * Cleanup * Fix linear * Small re-naming * Remove console.log
1 parent ccd3983 commit d64464c

File tree

11 files changed

+389
-81
lines changed

11 files changed

+389
-81
lines changed

backend/onyx/configs/app_configs.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,6 @@
374374
CONNECTOR_LOCALHOST_OVERRIDE = os.getenv("CONNECTOR_LOCALHOST_OVERRIDE")
375375

376376
# Egnyte specific configs
377-
EGNYTE_BASE_DOMAIN = os.getenv("EGNYTE_DOMAIN")
378377
EGNYTE_CLIENT_ID = os.getenv("EGNYTE_CLIENT_ID")
379378
EGNYTE_CLIENT_SECRET = os.getenv("EGNYTE_CLIENT_SECRET")
380379

backend/onyx/connectors/egnyte/connector.py

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
from typing import IO
88
from urllib.parse import quote
99

10-
from onyx.configs.app_configs import EGNYTE_BASE_DOMAIN
10+
from pydantic import Field
11+
1112
from onyx.configs.app_configs import EGNYTE_CLIENT_ID
1213
from onyx.configs.app_configs import EGNYTE_CLIENT_SECRET
1314
from onyx.configs.app_configs import INDEX_BATCH_SIZE
@@ -124,6 +125,15 @@ def _process_egnyte_file(
124125

125126

126127
class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
128+
class AdditionalOauthKwargs(OAuthConnector.AdditionalOauthKwargs):
129+
egnyte_domain: str = Field(
130+
title="Egnyte Domain",
131+
description=(
132+
"The domain for the Egnyte instance "
133+
"(e.g. 'company' for company.egnyte.com)"
134+
),
135+
)
136+
127137
def __init__(
128138
self,
129139
folder_path: str | None = None,
@@ -139,15 +149,20 @@ def oauth_id(cls) -> DocumentSource:
139149
return DocumentSource.EGNYTE
140150

141151
@classmethod
142-
def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
152+
def oauth_authorization_url(
153+
cls,
154+
base_domain: str,
155+
state: str,
156+
additional_kwargs: dict[str, str],
157+
) -> str:
143158
if not EGNYTE_CLIENT_ID:
144159
raise ValueError("EGNYTE_CLIENT_ID environment variable must be set")
145-
if not EGNYTE_BASE_DOMAIN:
146-
raise ValueError("EGNYTE_DOMAIN environment variable must be set")
160+
161+
oauth_kwargs = cls.AdditionalOauthKwargs(**additional_kwargs)
147162

148163
callback_uri = get_oauth_callback_uri(base_domain, "egnyte")
149164
return (
150-
f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token"
165+
f"https://{oauth_kwargs.egnyte_domain}.egnyte.com/puboauth/token"
151166
f"?client_id={EGNYTE_CLIENT_ID}"
152167
f"&redirect_uri={callback_uri}"
153168
f"&scope=Egnyte.filesystem"
@@ -156,17 +171,23 @@ def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
156171
)
157172

158173
@classmethod
159-
def oauth_code_to_token(cls, base_domain: str, code: str) -> dict[str, Any]:
174+
def oauth_code_to_token(
175+
cls,
176+
base_domain: str,
177+
code: str,
178+
additional_kwargs: dict[str, str],
179+
) -> dict[str, Any]:
160180
if not EGNYTE_CLIENT_ID:
161181
raise ValueError("EGNYTE_CLIENT_ID environment variable must be set")
162182
if not EGNYTE_CLIENT_SECRET:
163183
raise ValueError("EGNYTE_CLIENT_SECRET environment variable must be set")
164-
if not EGNYTE_BASE_DOMAIN:
165-
raise ValueError("EGNYTE_DOMAIN environment variable must be set")
184+
185+
oauth_kwargs = cls.AdditionalOauthKwargs(**additional_kwargs)
166186

167187
# Exchange code for token
168-
url = f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token"
188+
url = f"https://{oauth_kwargs.egnyte_domain}.egnyte.com/puboauth/token"
169189
redirect_uri = get_oauth_callback_uri(base_domain, "egnyte")
190+
170191
data = {
171192
"client_id": EGNYTE_CLIENT_ID,
172193
"client_secret": EGNYTE_CLIENT_SECRET,
@@ -191,7 +212,7 @@ def oauth_code_to_token(cls, base_domain: str, code: str) -> dict[str, Any]:
191212

192213
token_data = response.json()
193214
return {
194-
"domain": EGNYTE_BASE_DOMAIN,
215+
"domain": oauth_kwargs.egnyte_domain,
195216
"access_token": token_data["access_token"],
196217
}
197218

@@ -215,7 +236,7 @@ def _get_files_list(
215236
"list_content": True,
216237
}
217238

218-
url_encoded_path = quote(path or "", safe="")
239+
url_encoded_path = quote(path or "")
219240
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs/{url_encoded_path}"
220241
response = request_with_retries(
221242
method="GET", url=url, headers=headers, params=params
@@ -271,7 +292,7 @@ def _process_files(
271292
headers = {
272293
"Authorization": f"Bearer {self.access_token}",
273294
}
274-
url_encoded_path = quote(file["path"], safe="")
295+
url_encoded_path = quote(file["path"])
275296
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs-content/{url_encoded_path}"
276297
response = request_with_retries(
277298
method="GET",

backend/onyx/connectors/interfaces.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from collections.abc import Iterator
33
from typing import Any
44

5+
from pydantic import BaseModel
6+
57
from onyx.configs.constants import DocumentSource
68
from onyx.connectors.models import Document
79
from onyx.connectors.models import SlimDocument
@@ -66,19 +68,33 @@ def retrieve_all_slim_documents(
6668

6769

6870
class OAuthConnector(BaseConnector):
71+
class AdditionalOauthKwargs(BaseModel):
72+
# if overridden, all fields should be str type
73+
pass
74+
6975
@classmethod
7076
@abc.abstractmethod
7177
def oauth_id(cls) -> DocumentSource:
7278
raise NotImplementedError
7379

7480
@classmethod
7581
@abc.abstractmethod
76-
def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
82+
def oauth_authorization_url(
83+
cls,
84+
base_domain: str,
85+
state: str,
86+
additional_kwargs: dict[str, str],
87+
) -> str:
7788
raise NotImplementedError
7889

7990
@classmethod
8091
@abc.abstractmethod
81-
def oauth_code_to_token(cls, base_domain: str, code: str) -> dict[str, Any]:
92+
def oauth_code_to_token(
93+
cls,
94+
base_domain: str,
95+
code: str,
96+
additional_kwargs: dict[str, str],
97+
) -> dict[str, Any]:
8298
raise NotImplementedError
8399

84100

backend/onyx/connectors/linear/connector.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,9 @@ def oauth_id(cls) -> DocumentSource:
7777
return DocumentSource.LINEAR
7878

7979
@classmethod
80-
def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
80+
def oauth_authorization_url(
81+
cls, base_domain: str, state: str, additional_kwargs: dict[str, str]
82+
) -> str:
8183
if not LINEAR_CLIENT_ID:
8284
raise ValueError("LINEAR_CLIENT_ID environment variable must be set")
8385

@@ -92,7 +94,9 @@ def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
9294
)
9395

9496
@classmethod
95-
def oauth_code_to_token(cls, base_domain: str, code: str) -> dict[str, Any]:
97+
def oauth_code_to_token(
98+
cls, base_domain: str, code: str, additional_kwargs: dict[str, str]
99+
) -> dict[str, Any]:
96100
data = {
97101
"code": code,
98102
"redirect_uri": get_oauth_callback_uri(

backend/onyx/server/documents/standard_oauth.py

Lines changed: 98 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import json
12
import uuid
23
from typing import Annotated
34
from typing import cast
@@ -6,7 +7,9 @@
67
from fastapi import Depends
78
from fastapi import HTTPException
89
from fastapi import Query
10+
from fastapi import Request
911
from pydantic import BaseModel
12+
from pydantic import ValidationError
1013
from sqlalchemy.orm import Session
1114

1215
from onyx.auth.users import current_user
@@ -28,6 +31,8 @@
2831

2932
_OAUTH_STATE_KEY_FMT = "oauth_state:{state}"
3033
_OAUTH_STATE_EXPIRATION_SECONDS = 10 * 60 # 10 minutes
34+
_DESIRED_RETURN_URL_KEY = "desired_return_url"
35+
_ADDITIONAL_KWARGS_KEY = "additional_kwargs"
3136

3237
# Cache for OAuth connectors, populated at module load time
3338
_OAUTH_CONNECTORS: dict[DocumentSource, type[OAuthConnector]] = {}
@@ -51,12 +56,36 @@ def _discover_oauth_connectors() -> dict[DocumentSource, type[OAuthConnector]]:
5156
_discover_oauth_connectors()
5257

5358

59+
def _get_additional_kwargs(
60+
request: Request, connector_cls: type[OAuthConnector], args_to_ignore: list[str]
61+
) -> dict[str, str]:
62+
# get additional kwargs from request
63+
# e.g. anything except for desired_return_url
64+
additional_kwargs_dict = {
65+
k: v for k, v in request.query_params.items() if k not in args_to_ignore
66+
}
67+
try:
68+
# validate
69+
connector_cls.AdditionalOauthKwargs(**additional_kwargs_dict)
70+
except ValidationError:
71+
raise HTTPException(
72+
status_code=400,
73+
detail=(
74+
f"Invalid additional kwargs. Got {additional_kwargs_dict}, expected "
75+
f"{connector_cls.AdditionalOauthKwargs.model_json_schema()}"
76+
),
77+
)
78+
79+
return additional_kwargs_dict
80+
81+
5482
class AuthorizeResponse(BaseModel):
5583
redirect_url: str
5684

5785

5886
@router.get("/authorize/{source}")
5987
def oauth_authorize(
88+
request: Request,
6089
source: DocumentSource,
6190
desired_return_url: Annotated[str | None, Query()] = None,
6291
_: User = Depends(current_user),
@@ -71,19 +100,32 @@ def oauth_authorize(
71100
connector_cls = oauth_connectors[source]
72101
base_url = WEB_DOMAIN
73102

103+
# get additional kwargs from request
104+
# e.g. anything except for desired_return_url
105+
additional_kwargs = _get_additional_kwargs(
106+
request, connector_cls, ["desired_return_url"]
107+
)
108+
74109
# store state in redis
75110
if not desired_return_url:
76111
desired_return_url = f"{base_url}/admin/connectors/{source}?step=0"
77112
redis_client = get_redis_client(tenant_id=tenant_id)
78113
state = str(uuid.uuid4())
79114
redis_client.set(
80115
_OAUTH_STATE_KEY_FMT.format(state=state),
81-
desired_return_url,
116+
json.dumps(
117+
{
118+
_DESIRED_RETURN_URL_KEY: desired_return_url,
119+
_ADDITIONAL_KWARGS_KEY: additional_kwargs,
120+
}
121+
),
82122
ex=_OAUTH_STATE_EXPIRATION_SECONDS,
83123
)
84124

85125
return AuthorizeResponse(
86-
redirect_url=connector_cls.oauth_authorization_url(base_url, state)
126+
redirect_url=connector_cls.oauth_authorization_url(
127+
base_url, state, additional_kwargs
128+
)
87129
)
88130

89131

@@ -110,15 +152,18 @@ def oauth_callback(
110152

111153
# get state from redis
112154
redis_client = get_redis_client(tenant_id=tenant_id)
113-
original_url_bytes = cast(
155+
oauth_state_bytes = cast(
114156
bytes, redis_client.get(_OAUTH_STATE_KEY_FMT.format(state=state))
115157
)
116-
if not original_url_bytes:
158+
if not oauth_state_bytes:
117159
raise HTTPException(status_code=400, detail="Invalid OAuth state")
118-
original_url = original_url_bytes.decode("utf-8")
160+
oauth_state = json.loads(oauth_state_bytes.decode("utf-8"))
161+
162+
desired_return_url = cast(str, oauth_state[_DESIRED_RETURN_URL_KEY])
163+
additional_kwargs = cast(dict[str, str], oauth_state[_ADDITIONAL_KWARGS_KEY])
119164

120165
base_url = WEB_DOMAIN
121-
token_info = connector_cls.oauth_code_to_token(base_url, code)
166+
token_info = connector_cls.oauth_code_to_token(base_url, code, additional_kwargs)
122167

123168
# Create a new credential with the token info
124169
credential_data = CredentialBase(
@@ -136,8 +181,52 @@ def oauth_callback(
136181

137182
return CallbackResponse(
138183
redirect_url=(
139-
f"{original_url}?credentialId={credential.id}"
140-
if "?" not in original_url
141-
else f"{original_url}&credentialId={credential.id}"
184+
f"{desired_return_url}?credentialId={credential.id}"
185+
if "?" not in desired_return_url
186+
else f"{desired_return_url}&credentialId={credential.id}"
187+
)
188+
)
189+
190+
191+
class OAuthAdditionalKwargDescription(BaseModel):
192+
name: str
193+
display_name: str
194+
description: str
195+
196+
197+
class OAuthDetails(BaseModel):
198+
oauth_enabled: bool
199+
additional_kwargs: list[OAuthAdditionalKwargDescription]
200+
201+
202+
@router.get("/details/{source}")
203+
def oauth_details(
204+
source: DocumentSource,
205+
_: User = Depends(current_user),
206+
) -> OAuthDetails:
207+
oauth_connectors = _discover_oauth_connectors()
208+
209+
if source not in oauth_connectors:
210+
return OAuthDetails(
211+
oauth_enabled=False,
212+
additional_kwargs=[],
213+
)
214+
215+
connector_cls = oauth_connectors[source]
216+
217+
additional_kwarg_descriptions = []
218+
for key, value in connector_cls.AdditionalOauthKwargs.model_json_schema()[
219+
"properties"
220+
].items():
221+
additional_kwarg_descriptions.append(
222+
OAuthAdditionalKwargDescription(
223+
name=key,
224+
display_name=value.get("title", key),
225+
description=value.get("description", ""),
226+
)
142227
)
228+
229+
return OAuthDetails(
230+
oauth_enabled=True,
231+
additional_kwargs=additional_kwarg_descriptions,
143232
)

deployment/docker_compose/docker-compose.dev.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,6 @@ services:
196196
# Egnyte OAuth Configs
197197
- EGNYTE_CLIENT_ID=${EGNYTE_CLIENT_ID:-}
198198
- EGNYTE_CLIENT_SECRET=${EGNYTE_CLIENT_SECRET:-}
199-
- EGNYTE_BASE_DOMAIN=${EGNYTE_BASE_DOMAIN:-}
200199
- EGNYTE_LOCALHOST_OVERRIDE=${EGNYTE_LOCALHOST_OVERRIDE:-}
201200
# Celery Configs (defaults are set in the supervisord.conf file.
202201
# prefer doing that to have one source of defaults)

0 commit comments

Comments
 (0)