Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import re
from collections import deque
from typing import Any
from urllib.parse import unquote
from urllib.parse import urlparse

from office365.graph_client import GraphClient # type: ignore[import-untyped]
from office365.onedrive.driveitems.driveItem import DriveItem # type: ignore[import-untyped]
from office365.runtime.client_request import ClientRequestException # type: ignore
from office365.sharepoint.client_context import ClientContext # type: ignore[import-untyped]
from office365.sharepoint.permissions.securable_object import RoleAssignmentCollection # type: ignore[import-untyped]
from pydantic import BaseModel
Expand Down Expand Up @@ -231,6 +234,7 @@ def process_users(users: list[Any]) -> None:
nonlocal groups, user_emails

for user in users:
logger.debug(f"User: {user.to_json()}")
if user.principal_type == USER_PRINCIPAL_TYPE and hasattr(
user, "user_principal_name"
):
Expand Down Expand Up @@ -285,7 +289,7 @@ def process_members(members: list[Any]) -> None:

for member in members:
member_data = member.to_json()

logger.debug(f"Member: {member_data}")
# Check for user-specific attributes
user_principal_name = member_data.get("userPrincipalName")
mail = member_data.get("mail")
Expand Down Expand Up @@ -366,13 +370,15 @@ def _get_groups_and_members_recursively(
client_context: ClientContext,
graph_client: GraphClient,
groups: set[SharepointGroup],
is_group_sync: bool = False,
) -> GroupsResult:
"""
Get all groups and their members recursively.
"""
group_queue: deque[SharepointGroup] = deque(groups)
visited_groups: set[str] = set()
visited_group_name_to_emails: dict[str, set[str]] = {}
found_public_group = False
while group_queue:
group = group_queue.popleft()
if group.login_name in visited_groups:
Expand All @@ -390,19 +396,35 @@ def _get_groups_and_members_recursively(
if group_info:
group_queue.extend(group_info)
if group.principal_type == AZURE_AD_GROUP_PRINCIPAL_TYPE:
# if the site is public, we have default groups assigned to it, so we return early
if _is_public_login_name(group.login_name):
return GroupsResult(groups_to_emails={}, found_public_group=True)

group_info, user_emails = _get_azuread_groups(
graph_client, group.login_name
)
visited_group_name_to_emails[group.name].update(user_emails)
if group_info:
group_queue.extend(group_info)
try:
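# A public login name means the site is public (it has default groups assigned to it): record that,
# then return early unless this is a group sync, in which case the public group is skipped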
if _is_public_login_name(group.login_name):
found_public_group = True
if not is_group_sync:
return GroupsResult(
groups_to_emails={}, found_public_group=True
)
else:
# we don't want to sync public groups, so we skip them
continue
group_info, user_emails = _get_azuread_groups(
graph_client, group.login_name
)
visited_group_name_to_emails[group.name].update(user_emails)
if group_info:
group_queue.extend(group_info)
except ClientRequestException as e:
# If the group is not found, we skip it. A group may still be referenced in SharePoint
# after it has been removed from Azure AD. This behavior is not officially documented,
# but we have observed it in our testing.
if e.response is not None and e.response.status_code == 404:
logger.warning(f"Group {group.login_name} not found")
continue
raise e

return GroupsResult(
groups_to_emails=visited_group_name_to_emails, found_public_group=False
groups_to_emails=visited_group_name_to_emails,
found_public_group=found_public_group,
)


Expand All @@ -427,6 +449,7 @@ def add_user_and_group_to_sets(
) -> None:
nonlocal user_emails, groups
for assignment in role_assignments:
logger.debug(f"Assignment: {assignment.to_json()}")
if assignment.role_definition_bindings:
is_limited_access = True
for role_definition_binding in assignment.role_definition_bindings:
Expand Down Expand Up @@ -503,12 +526,19 @@ def add_user_and_group_to_sets(
)
elif site_page:
site_url = site_page.get("webUrl")
site_pages = client_context.web.lists.get_by_title("Site Pages")
client_context.load(site_pages)
client_context.execute_query()
site_pages.items.get_by_url(site_url).role_assignments.expand(
["Member", "RoleDefinitionBindings"]
).get_all(page_loaded=add_user_and_group_to_sets).execute_query()
# Prefer server-relative URL to avoid OData filters that break on apostrophes
server_relative_url = unquote(urlparse(site_url).path)
file_obj = client_context.web.get_file_by_server_relative_url(
server_relative_url
)
item = file_obj.listItemAllFields

sleep_and_retry(
item.role_assignments.expand(["Member", "RoleDefinitionBindings"]).get_all(
page_loaded=add_user_and_group_to_sets,
),
"get_external_access_from_sharepoint",
)
else:
raise RuntimeError("No drive item or site page provided")

Expand Down Expand Up @@ -595,13 +625,9 @@ def add_group_to_sets(role_assignments: RoleAssignmentCollection) -> None:
"get_sharepoint_external_groups",
)
groups_and_members: GroupsResult = _get_groups_and_members_recursively(
client_context, graph_client, groups
client_context, graph_client, groups, is_group_sync=True
)

# We don't have any direct way to check if the site is public, so we check if any public group is present
if groups_and_members.found_public_group:
return []

# get all Azure AD groups because if any group is assigned to the drive item, we don't want to miss them
# We can't assign sharepoint groups to drive items or drives, so we don't need to get all sharepoint groups
azure_ad_groups = sleep_and_retry(
Expand Down
4 changes: 2 additions & 2 deletions backend/onyx/connectors/sharepoint/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def sleep_and_retry(
return query_obj.execute_query()
except ClientRequestException as e:
if (
e.response
e.response is not None
and e.response.status_code in [429, 503]
and attempt < max_retries
):
Expand All @@ -119,7 +119,7 @@ def sleep_and_retry(
time.sleep(sleep_time)
else:
# Either not a rate limit error, or we've exhausted retries
if e.response and e.response.status_code == 429:
if e.response is not None and e.response.status_code == 429:
logger.error(
f"Rate limit retry exhausted for {method_name} after {max_retries} attempts"
)
Expand Down
Loading