Skip to content

Commit b09b9a8

Browse files
Implement Groups feature for non-EE version
- Add backend user group CRUD operations and API endpoints - Create Groups management UI with creation/editing forms - Enable user groups functionality for non-Enterprise version - Add group assignment support for users, connectors, document sets, assistants - Update token rate limits page to support user groups without EE dependency - Remove Enterprise-only conditional logic from useUserGroups hook This implements a complete Groups feature modeled after the Enterprise version but without reusing any EE code, making it available in the non-EE version. Co-Authored-By: Sanjay Akut <sanjay@tukatek.com>
1 parent 8ffd51a commit b09b9a8

File tree

9 files changed

+1212
-93
lines changed

9 files changed

+1212
-93
lines changed

backend/onyx/db/user_group.py

Lines changed: 314 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,314 @@
1+
from sqlalchemy import and_
2+
from sqlalchemy import select
3+
from sqlalchemy.orm import Session
4+
5+
from onyx.db.models import User
6+
from onyx.db.models import UserGroup
7+
from onyx.db.models import User__UserGroup
8+
from onyx.db.models import UserGroup__ConnectorCredentialPair
9+
from onyx.db.models import Persona__UserGroup
10+
from onyx.db.models import DocumentSet__UserGroup
11+
from onyx.db.models import Credential__UserGroup
12+
from onyx.db.models import ConnectorCredentialPair
13+
from onyx.db.models import Persona
14+
from onyx.db.models import DocumentSet
15+
from onyx.db.models import Credential
16+
17+
18+
def fetch_user_groups(db_session: Session) -> list[UserGroup]:
19+
"""Fetch all user groups"""
20+
return list(db_session.scalars(select(UserGroup)).all())
21+
22+
23+
def fetch_user_group_by_id(db_session: Session, user_group_id: int) -> UserGroup | None:
24+
"""Fetch a specific user group by ID"""
25+
return db_session.scalar(select(UserGroup).where(UserGroup.id == user_group_id))
26+
27+
28+
def insert_user_group(
29+
db_session: Session,
30+
name: str,
31+
user_ids: list[str] | None = None,
32+
cc_pair_ids: list[int] | None = None,
33+
persona_ids: list[int] | None = None,
34+
document_set_ids: list[int] | None = None,
35+
credential_ids: list[int] | None = None,
36+
) -> UserGroup:
37+
"""Create a new user group with optional initial assignments"""
38+
user_group = UserGroup(
39+
name=name,
40+
is_up_to_date=True,
41+
is_up_for_deletion=False,
42+
)
43+
db_session.add(user_group)
44+
db_session.flush()
45+
46+
if user_ids:
47+
for user_id in user_ids:
48+
user_group_relationship = User__UserGroup(
49+
user_group_id=user_group.id,
50+
user_id=user_id,
51+
)
52+
db_session.add(user_group_relationship)
53+
54+
if cc_pair_ids:
55+
for cc_pair_id in cc_pair_ids:
56+
cc_pair_relationship = UserGroup__ConnectorCredentialPair(
57+
user_group_id=user_group.id,
58+
cc_pair_id=cc_pair_id,
59+
is_current=True,
60+
)
61+
db_session.add(cc_pair_relationship)
62+
63+
if persona_ids:
64+
for persona_id in persona_ids:
65+
persona_relationship = Persona__UserGroup(
66+
persona_id=persona_id,
67+
user_group_id=user_group.id,
68+
)
69+
db_session.add(persona_relationship)
70+
71+
if document_set_ids:
72+
for document_set_id in document_set_ids:
73+
document_set_relationship = DocumentSet__UserGroup(
74+
document_set_id=document_set_id,
75+
user_group_id=user_group.id,
76+
)
77+
db_session.add(document_set_relationship)
78+
79+
if credential_ids:
80+
for credential_id in credential_ids:
81+
credential_relationship = Credential__UserGroup(
82+
credential_id=credential_id,
83+
user_group_id=user_group.id,
84+
)
85+
db_session.add(credential_relationship)
86+
87+
db_session.commit()
88+
return user_group
89+
90+
91+
def update_user_group(
92+
db_session: Session,
93+
user_group_id: int,
94+
name: str | None = None,
95+
user_ids: list[str] | None = None,
96+
cc_pair_ids: list[int] | None = None,
97+
persona_ids: list[int] | None = None,
98+
document_set_ids: list[int] | None = None,
99+
credential_ids: list[int] | None = None,
100+
) -> UserGroup | None:
101+
"""Update an existing user group"""
102+
user_group = fetch_user_group_by_id(db_session, user_group_id)
103+
if not user_group:
104+
return None
105+
106+
if name is not None:
107+
user_group.name = name
108+
109+
if user_ids is not None:
110+
db_session.execute(
111+
select(User__UserGroup).where(
112+
User__UserGroup.user_group_id == user_group_id
113+
)
114+
)
115+
for relationship in db_session.scalars(
116+
select(User__UserGroup).where(
117+
User__UserGroup.user_group_id == user_group_id
118+
)
119+
):
120+
db_session.delete(relationship)
121+
122+
for user_id in user_ids:
123+
user_group_relationship = User__UserGroup(
124+
user_group_id=user_group.id,
125+
user_id=user_id,
126+
)
127+
db_session.add(user_group_relationship)
128+
129+
if cc_pair_ids is not None:
130+
for relationship in db_session.scalars(
131+
select(UserGroup__ConnectorCredentialPair).where(
132+
UserGroup__ConnectorCredentialPair.user_group_id == user_group_id
133+
)
134+
):
135+
db_session.delete(relationship)
136+
137+
for cc_pair_id in cc_pair_ids:
138+
cc_pair_relationship = UserGroup__ConnectorCredentialPair(
139+
user_group_id=user_group.id,
140+
cc_pair_id=cc_pair_id,
141+
is_current=True,
142+
)
143+
db_session.add(cc_pair_relationship)
144+
145+
if persona_ids is not None:
146+
for relationship in db_session.scalars(
147+
select(Persona__UserGroup).where(
148+
Persona__UserGroup.user_group_id == user_group_id
149+
)
150+
):
151+
db_session.delete(relationship)
152+
153+
for persona_id in persona_ids:
154+
persona_relationship = Persona__UserGroup(
155+
persona_id=persona_id,
156+
user_group_id=user_group.id,
157+
)
158+
db_session.add(persona_relationship)
159+
160+
if document_set_ids is not None:
161+
for relationship in db_session.scalars(
162+
select(DocumentSet__UserGroup).where(
163+
DocumentSet__UserGroup.user_group_id == user_group_id
164+
)
165+
):
166+
db_session.delete(relationship)
167+
168+
for document_set_id in document_set_ids:
169+
document_set_relationship = DocumentSet__UserGroup(
170+
document_set_id=document_set_id,
171+
user_group_id=user_group.id,
172+
)
173+
db_session.add(document_set_relationship)
174+
175+
if credential_ids is not None:
176+
for relationship in db_session.scalars(
177+
select(Credential__UserGroup).where(
178+
Credential__UserGroup.user_group_id == user_group_id
179+
)
180+
):
181+
db_session.delete(relationship)
182+
183+
for credential_id in credential_ids:
184+
credential_relationship = Credential__UserGroup(
185+
credential_id=credential_id,
186+
user_group_id=user_group.id,
187+
)
188+
db_session.add(credential_relationship)
189+
190+
db_session.commit()
191+
return user_group
192+
193+
194+
def delete_user_group(db_session: Session, user_group_id: int) -> bool:
195+
"""Delete a user group and all its relationships"""
196+
user_group = fetch_user_group_by_id(db_session, user_group_id)
197+
if not user_group:
198+
return False
199+
200+
for relationship in db_session.scalars(
201+
select(User__UserGroup).where(User__UserGroup.user_group_id == user_group_id)
202+
):
203+
db_session.delete(relationship)
204+
205+
for relationship in db_session.scalars(
206+
select(UserGroup__ConnectorCredentialPair).where(
207+
UserGroup__ConnectorCredentialPair.user_group_id == user_group_id
208+
)
209+
):
210+
db_session.delete(relationship)
211+
212+
for relationship in db_session.scalars(
213+
select(Persona__UserGroup).where(
214+
Persona__UserGroup.user_group_id == user_group_id
215+
)
216+
):
217+
db_session.delete(relationship)
218+
219+
for relationship in db_session.scalars(
220+
select(DocumentSet__UserGroup).where(
221+
DocumentSet__UserGroup.user_group_id == user_group_id
222+
)
223+
):
224+
db_session.delete(relationship)
225+
226+
for relationship in db_session.scalars(
227+
select(Credential__UserGroup).where(
228+
Credential__UserGroup.user_group_id == user_group_id
229+
)
230+
):
231+
db_session.delete(relationship)
232+
233+
db_session.delete(user_group)
234+
db_session.commit()
235+
return True
236+
237+
238+
def fetch_user_groups_for_user(db_session: Session, user_id: str) -> list[UserGroup]:
239+
"""Fetch all user groups that a user belongs to"""
240+
return list(
241+
db_session.scalars(
242+
select(UserGroup)
243+
.join(User__UserGroup)
244+
.where(User__UserGroup.user_id == user_id)
245+
).all()
246+
)
247+
248+
249+
def fetch_users_for_user_group(db_session: Session, user_group_id: int) -> list[User]:
250+
"""Fetch all users in a user group"""
251+
return list(
252+
db_session.scalars(
253+
select(User)
254+
.join(User__UserGroup)
255+
.where(User__UserGroup.user_group_id == user_group_id)
256+
).all()
257+
)
258+
259+
260+
def fetch_cc_pairs_for_user_group(
261+
db_session: Session, user_group_id: int
262+
) -> list[ConnectorCredentialPair]:
263+
"""Fetch all connector credential pairs assigned to a user group"""
264+
return list(
265+
db_session.scalars(
266+
select(ConnectorCredentialPair)
267+
.join(UserGroup__ConnectorCredentialPair)
268+
.where(
269+
and_(
270+
UserGroup__ConnectorCredentialPair.user_group_id == user_group_id,
271+
UserGroup__ConnectorCredentialPair.is_current == True,
272+
)
273+
)
274+
).all()
275+
)
276+
277+
278+
def fetch_personas_for_user_group(
279+
db_session: Session, user_group_id: int
280+
) -> list[Persona]:
281+
"""Fetch all personas assigned to a user group"""
282+
return list(
283+
db_session.scalars(
284+
select(Persona)
285+
.join(Persona__UserGroup)
286+
.where(Persona__UserGroup.user_group_id == user_group_id)
287+
).all()
288+
)
289+
290+
291+
def fetch_document_sets_for_user_group(
292+
db_session: Session, user_group_id: int
293+
) -> list[DocumentSet]:
294+
"""Fetch all document sets assigned to a user group"""
295+
return list(
296+
db_session.scalars(
297+
select(DocumentSet)
298+
.join(DocumentSet__UserGroup)
299+
.where(DocumentSet__UserGroup.user_group_id == user_group_id)
300+
).all()
301+
)
302+
303+
304+
def fetch_credentials_for_user_group(
305+
db_session: Session, user_group_id: int
306+
) -> list[Credential]:
307+
"""Fetch all credentials assigned to a user group"""
308+
return list(
309+
db_session.scalars(
310+
select(Credential)
311+
.join(Credential__UserGroup)
312+
.where(Credential__UserGroup.user_group_id == user_group_id)
313+
).all()
314+
)

backend/onyx/main.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@
101101
from onyx.server.token_rate_limits.api import (
102102
router as token_rate_limit_settings_router,
103103
)
104+
from onyx.server.user_group.api import router as user_group_router
104105
from onyx.server.user_documents.api import router as user_documents_router
105106
from onyx.server.utils import BasicAuthenticationError
106107
from onyx.setup import setup_multitenant_onyx
@@ -351,6 +352,7 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
351352
include_router_with_global_prefix_prepended(
352353
application, token_rate_limit_settings_router
353354
)
355+
include_router_with_global_prefix_prepended(application, user_group_router)
354356
include_router_with_global_prefix_prepended(
355357
application, get_full_openai_assistants_api_router()
356358
)

0 commit comments

Comments
 (0)