1
+ import json
1
2
import uuid
2
3
from typing import Annotated
3
4
from typing import cast
6
7
from fastapi import Depends
7
8
from fastapi import HTTPException
8
9
from fastapi import Query
10
+ from fastapi import Request
9
11
from pydantic import BaseModel
12
+ from pydantic import ValidationError
10
13
from sqlalchemy .orm import Session
11
14
12
15
from onyx .auth .users import current_user
28
31
29
32
_OAUTH_STATE_KEY_FMT = "oauth_state:{state}"
30
33
_OAUTH_STATE_EXPIRATION_SECONDS = 10 * 60 # 10 minutes
34
+ _DESIRED_RETURN_URL_KEY = "desired_return_url"
35
+ _ADDITIONAL_KWARGS_KEY = "additional_kwargs"
31
36
32
37
# Cache for OAuth connectors, populated at module load time
33
38
_OAUTH_CONNECTORS : dict [DocumentSource , type [OAuthConnector ]] = {}
@@ -51,12 +56,36 @@ def _discover_oauth_connectors() -> dict[DocumentSource, type[OAuthConnector]]:
51
56
_discover_oauth_connectors ()
52
57
53
58
59
+ def _get_additional_kwargs (
60
+ request : Request , connector_cls : type [OAuthConnector ], args_to_ignore : list [str ]
61
+ ) -> dict [str , str ]:
62
+ # get additional kwargs from request
63
+ # e.g. anything except for desired_return_url
64
+ additional_kwargs_dict = {
65
+ k : v for k , v in request .query_params .items () if k not in args_to_ignore
66
+ }
67
+ try :
68
+ # validate
69
+ connector_cls .AdditionalOauthKwargs (** additional_kwargs_dict )
70
+ except ValidationError :
71
+ raise HTTPException (
72
+ status_code = 400 ,
73
+ detail = (
74
+ f"Invalid additional kwargs. Got { additional_kwargs_dict } , expected "
75
+ f"{ connector_cls .AdditionalOauthKwargs .model_json_schema ()} "
76
+ ),
77
+ )
78
+
79
+ return additional_kwargs_dict
80
+
81
+
54
82
class AuthorizeResponse (BaseModel ):
55
83
redirect_url : str
56
84
57
85
58
86
@router .get ("/authorize/{source}" )
59
87
def oauth_authorize (
88
+ request : Request ,
60
89
source : DocumentSource ,
61
90
desired_return_url : Annotated [str | None , Query ()] = None ,
62
91
_ : User = Depends (current_user ),
@@ -71,19 +100,32 @@ def oauth_authorize(
71
100
connector_cls = oauth_connectors [source ]
72
101
base_url = WEB_DOMAIN
73
102
103
+ # get additional kwargs from request
104
+ # e.g. anything except for desired_return_url
105
+ additional_kwargs = _get_additional_kwargs (
106
+ request , connector_cls , ["desired_return_url" ]
107
+ )
108
+
74
109
# store state in redis
75
110
if not desired_return_url :
76
111
desired_return_url = f"{ base_url } /admin/connectors/{ source } ?step=0"
77
112
redis_client = get_redis_client (tenant_id = tenant_id )
78
113
state = str (uuid .uuid4 ())
79
114
redis_client .set (
80
115
_OAUTH_STATE_KEY_FMT .format (state = state ),
81
- desired_return_url ,
116
+ json .dumps (
117
+ {
118
+ _DESIRED_RETURN_URL_KEY : desired_return_url ,
119
+ _ADDITIONAL_KWARGS_KEY : additional_kwargs ,
120
+ }
121
+ ),
82
122
ex = _OAUTH_STATE_EXPIRATION_SECONDS ,
83
123
)
84
124
85
125
return AuthorizeResponse (
86
- redirect_url = connector_cls .oauth_authorization_url (base_url , state )
126
+ redirect_url = connector_cls .oauth_authorization_url (
127
+ base_url , state , additional_kwargs
128
+ )
87
129
)
88
130
89
131
@@ -110,15 +152,18 @@ def oauth_callback(
110
152
111
153
# get state from redis
112
154
redis_client = get_redis_client (tenant_id = tenant_id )
113
- original_url_bytes = cast (
155
+ oauth_state_bytes = cast (
114
156
bytes , redis_client .get (_OAUTH_STATE_KEY_FMT .format (state = state ))
115
157
)
116
- if not original_url_bytes :
158
+ if not oauth_state_bytes :
117
159
raise HTTPException (status_code = 400 , detail = "Invalid OAuth state" )
118
- original_url = original_url_bytes .decode ("utf-8" )
160
+ oauth_state = json .loads (oauth_state_bytes .decode ("utf-8" ))
161
+
162
+ desired_return_url = cast (str , oauth_state [_DESIRED_RETURN_URL_KEY ])
163
+ additional_kwargs = cast (dict [str , str ], oauth_state [_ADDITIONAL_KWARGS_KEY ])
119
164
120
165
base_url = WEB_DOMAIN
121
- token_info = connector_cls .oauth_code_to_token (base_url , code )
166
+ token_info = connector_cls .oauth_code_to_token (base_url , code , additional_kwargs )
122
167
123
168
# Create a new credential with the token info
124
169
credential_data = CredentialBase (
@@ -136,8 +181,52 @@ def oauth_callback(
136
181
137
182
return CallbackResponse (
138
183
redirect_url = (
139
- f"{ original_url } ?credentialId={ credential .id } "
140
- if "?" not in original_url
141
- else f"{ original_url } &credentialId={ credential .id } "
184
+ f"{ desired_return_url } ?credentialId={ credential .id } "
185
+ if "?" not in desired_return_url
186
+ else f"{ desired_return_url } &credentialId={ credential .id } "
187
+ )
188
+ )
189
+
190
+
191
+ class OAuthAdditionalKwargDescription (BaseModel ):
192
+ name : str
193
+ display_name : str
194
+ description : str
195
+
196
+
197
+ class OAuthDetails (BaseModel ):
198
+ oauth_enabled : bool
199
+ additional_kwargs : list [OAuthAdditionalKwargDescription ]
200
+
201
+
202
+ @router .get ("/details/{source}" )
203
+ def oauth_details (
204
+ source : DocumentSource ,
205
+ _ : User = Depends (current_user ),
206
+ ) -> OAuthDetails :
207
+ oauth_connectors = _discover_oauth_connectors ()
208
+
209
+ if source not in oauth_connectors :
210
+ return OAuthDetails (
211
+ oauth_enabled = False ,
212
+ additional_kwargs = [],
213
+ )
214
+
215
+ connector_cls = oauth_connectors [source ]
216
+
217
+ additional_kwarg_descriptions = []
218
+ for key , value in connector_cls .AdditionalOauthKwargs .model_json_schema ()[
219
+ "properties"
220
+ ].items ():
221
+ additional_kwarg_descriptions .append (
222
+ OAuthAdditionalKwargDescription (
223
+ name = key ,
224
+ display_name = value .get ("title" , key ),
225
+ description = value .get ("description" , "" ),
226
+ )
142
227
)
228
+
229
+ return OAuthDetails (
230
+ oauth_enabled = True ,
231
+ additional_kwargs = additional_kwarg_descriptions ,
143
232
)
0 commit comments