Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions backend/onyx/access/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,14 @@ def truncate_set(s: set[str], max_len: int = 100) -> str:
def num_entries(self) -> int:
    """Return the combined count of external user emails and external group ids."""
    return len(self.external_user_emails) + len(self.external_user_group_ids)

@classmethod
def public(cls) -> "ExternalAccess":
    """Build an instance flagged as public, with no user emails and no group ids."""
    return cls(
        external_user_emails=set(),
        external_user_group_ids=set(),
        is_public=True,
    )

@classmethod
def empty(cls) -> "ExternalAccess":
"""
Expand Down
43 changes: 9 additions & 34 deletions backend/onyx/connectors/teams/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from office365.teams.channels.channel import Channel # type: ignore
from office365.teams.team import Team # type: ignore

from onyx.access.models import ExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
Expand All @@ -33,6 +32,7 @@
from onyx.connectors.models import TextSection
from onyx.connectors.teams.models import Message
from onyx.connectors.teams.utils import fetch_expert_infos
from onyx.connectors.teams.utils import fetch_external_access
from onyx.connectors.teams.utils import fetch_messages
from onyx.connectors.teams.utils import fetch_replies
from onyx.file_processing.html_utils import parse_html_page_basic
Expand All @@ -43,7 +43,6 @@
logger = setup_logger()

_SLIM_DOC_BATCH_SIZE = 5000
_PUBLIC_MEMBERSHIP_TYPE = "standard" # public teams channel


class TeamsCheckpoint(ConnectorCheckpoint):
Expand Down Expand Up @@ -260,18 +259,8 @@ def retrieve_all_slim_documents(
)
continue

is_public = _is_channel_public(channel=channel)
expert_infos = (
set()
if is_public
else fetch_expert_infos(
graph_client=self.graph_client, channel=channel
)
)
external_user_emails = set(
expert_info.email
for expert_info in expert_infos
if expert_info.email
external_access = fetch_external_access(
graph_client=self.graph_client, channel=channel
)

messages = fetch_messages(
Expand All @@ -287,11 +276,7 @@ def retrieve_all_slim_documents(
slim_doc_buffer.append(
SlimDocument(
id=message.id,
external_access=ExternalAccess(
external_user_emails=external_user_emails,
external_user_group_ids=set(),
is_public=is_public,
),
external_access=external_access,
)
)

Expand Down Expand Up @@ -336,9 +321,6 @@ def _convert_thread_to_document(
if len(thread) == 0:
return None

expert_infos = fetch_expert_infos(graph_client=graph_client, channel=channel)
emails = set(expert_info.email for expert_info in expert_infos if expert_info.email)

most_recent_message_datetime: datetime | None = None
top_message = thread[0]
thread_text = ""
Expand All @@ -361,7 +343,10 @@ def _convert_thread_to_document(
return None

semantic_string = _construct_semantic_identifier(channel, top_message)
is_public = _is_channel_public(channel=channel)
expert_infos = fetch_expert_infos(graph_client=graph_client, channel=channel)
external_access = fetch_external_access(
graph_client=graph_client, channel=channel, expert_infos=expert_infos
)

return Document(
id=top_message.id,
Expand All @@ -372,11 +357,7 @@ def _convert_thread_to_document(
doc_updated_at=most_recent_message_datetime,
primary_owners=expert_infos,
metadata={},
external_access=ExternalAccess(
external_user_emails=emails,
external_user_group_ids=set(),
is_public=is_public,
),
external_access=external_access,
)


Expand Down Expand Up @@ -558,12 +539,6 @@ def _collect_documents_for_channel(
)


def _is_channel_public(channel: Channel) -> bool:
    """Return True when the channel's membership type is the public one.

    The comparison always yields a real ``bool``. The previous
    ``value and value == CONST`` form returned ``None`` (or ``""``) when the
    membership type was missing, which contradicted the ``-> bool`` annotation
    and could end up stored as an ``is_public`` flag.
    """
    return channel.membership_type == _PUBLIC_MEMBERSHIP_TYPE


if __name__ == "__main__":
from tests.daily.connectors.utils import load_everything_from_checkpoint_connector

Expand Down
78 changes: 56 additions & 22 deletions backend/onyx/connectors/teams/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from office365.teams.channels.channel import Channel # type: ignore
from office365.teams.channels.channel import ConversationMember # type: ignore

from onyx.access.models import ExternalAccess
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.teams.models import Message
Expand All @@ -16,6 +17,9 @@
logger = setup_logger()


_PUBLIC_MEMBERSHIP_TYPE = "standard" # public teams channel


def _retry(
graph_client: GraphClient,
request_url: str,
Expand Down Expand Up @@ -64,6 +68,34 @@ def _get_next_url(
return next_url.removeprefix(graph_client.service_root_url()).removeprefix("/")


def _get_or_fetch_email(
    graph_client: GraphClient,
    member: ConversationMember,
) -> str | None:
    """Return the email for a channel member, looking the user up if needed.

    The member's own ``email`` property is used when present. Otherwise the
    member's ``userId`` is used to request ``users/{user_id}`` through
    ``_retry`` and the ``userPrincipalName`` field of the response is returned.

    Returns:
        The email string, or ``None`` when the member has no ``userId`` or the
        looked-up ``userPrincipalName`` is not a string.
    """
    if email := member.properties.get("email"):
        return email

    user_id = member.properties.get("userId")
    if not user_id:
        # `Logger.warn` is a deprecated alias; use `warning`.
        logger.warning(f"No user-id found for this member; {member=}")
        return None

    json_data = _retry(graph_client=graph_client, request_url=f"users/{user_id}")
    email = json_data.get("userPrincipalName")

    if not isinstance(email, str):
        logger.warning(f"Expected email to be of type str, instead got {email=}")
        return None

    return email


def _is_channel_public(channel: Channel) -> bool:
    """Return True when the channel's membership type is the public one.

    The comparison always yields a real ``bool``. The previous
    ``value and value == CONST`` form returned ``None`` (or ``""``) when the
    membership type was missing, which contradicted the ``-> bool`` annotation
    and could end up stored as an ``is_public`` flag.
    """
    return channel.membership_type == _PUBLIC_MEMBERSHIP_TYPE


def fetch_messages(
graph_client: GraphClient,
team_id: str,
Expand Down Expand Up @@ -115,28 +147,6 @@ def fetch_replies(
)


def _get_or_fetch_email(
    graph_client: GraphClient,
    member: ConversationMember,
) -> str | None:
    """Return the email for a channel member, looking the user up if needed.

    The member's own ``email`` property is used when present. Otherwise the
    member's ``userId`` is used to request ``users/{user_id}`` through
    ``_retry`` and the ``userPrincipalName`` field of the response is returned.

    Returns:
        The email string, or ``None`` when the member has no ``userId`` or the
        looked-up ``userPrincipalName`` is not a string.
    """
    if email := member.properties.get("email"):
        return email

    user_id = member.properties.get("userId")
    if not user_id:
        # `Logger.warn` is a deprecated alias; use `warning`.
        logger.warning(f"No user-id found for this member; {member=}")
        return None

    json_data = _retry(graph_client=graph_client, request_url=f"users/{user_id}")
    email = json_data.get("userPrincipalName")

    if not isinstance(email, str):
        logger.warning(f"Expected email to be of type str, instead got {email=}")
        return None

    return email


def fetch_expert_infos(
graph_client: GraphClient, channel: Channel
) -> list[BasicExpertInfo]:
Expand Down Expand Up @@ -164,3 +174,27 @@ def fetch_expert_infos(
)

return expert_infos


def fetch_external_access(
    graph_client: GraphClient,
    channel: Channel,
    expert_infos: list[BasicExpertInfo] | None = None,
) -> ExternalAccess:
    """Build the ``ExternalAccess`` for documents that live in ``channel``.

    Args:
        graph_client: Client handed to ``fetch_expert_infos`` when the channel
            members still need to be looked up.
        channel: The Teams channel whose access is being described.
        expert_infos: Members of the channel, if the caller already fetched
            them. ``None`` means "not fetched yet"; an empty list is accepted
            as a valid, already-fetched result and is not fetched again.

    Returns:
        ``ExternalAccess.public()`` for a public channel; otherwise an
        ``ExternalAccess`` restricted to the non-empty emails found in
        ``expert_infos``.
    """
    if _is_channel_public(channel=channel):
        return ExternalAccess.public()

    # Test against None rather than truthiness so that a caller-supplied empty
    # list does not trigger a redundant second lookup.
    if expert_infos is None:
        expert_infos = fetch_expert_infos(graph_client=graph_client, channel=channel)

    emails = {expert_info.email for expert_info in expert_infos if expert_info.email}

    return ExternalAccess(
        external_user_emails=emails,
        external_user_group_ids=set(),
        # This branch is only reached for non-public channels; pass an explicit
        # bool instead of forwarding a possibly-None truthiness value.
        is_public=False,
    )
Loading