diff --git a/connectors/sources/outlook.py b/connectors/sources/outlook.py index 47a495162..200a62c4f 100644 --- a/connectors/sources/outlook.py +++ b/connectors/sources/outlook.py @@ -7,6 +7,7 @@ import asyncio import os +from abc import ABC, abstractmethod from copy import copy from datetime import date from functools import cached_property, partial @@ -48,9 +49,12 @@ RETRIES = 3 RETRY_INTERVAL = 2 +WILDCARD = "*" QUEUE_MEM_SIZE = 5 * 1024 * 1024 # Size in Megabytes +GRAPH_API_BATCH_SIZE = 20 # Max batch size supported for fetching users from Graph API + OUTLOOK_SERVER = "outlook_server" OUTLOOK_CLOUD = "outlook_cloud" API_SCOPE = "https://graph.microsoft.com/.default" @@ -212,6 +216,18 @@ class SSLFailed(Exception): pass +class OutlookUserFetchFailed(Exception): + """Exception class to notify that fetching a specific user from Outlook failed.""" + + pass + + +class BatchRequestFailed(Exception): + """Exception class to notify that a batch request to fetch users failed.""" + + pass + + class ManageCertificate: async def store_certificate(self, certificate): async with aiofiles.open(CERT_FILE, "w") as file: @@ -348,13 +364,13 @@ async def get_user_accounts(self): yield user_account -class Office365Users: - """Fetch users from Office365 Active Directory""" +class BaseOffice365User(ABC): + """Abstract base class for Office 365 user management""" def __init__(self, client_id, client_secret, tenant_id): - self.tenant_id = tenant_id self.client_id = client_id self.client_secret = client_secret + self.tenant_id = tenant_id @cached_property def _get_session(self): @@ -403,6 +419,21 @@ async def _fetch_token(self): except Exception as exception: self._check_errors(response=exception) + @abstractmethod + async def get_users(self): + yield + + @abstractmethod + async def get_user_accounts(self): + yield + + +class Office365Users(BaseOffice365User): + """Fetch users from Office365 Active Directory""" + + def __init__(self, client_id, client_secret, tenant_id): + super().__init__(client_id, client_secret, tenant_id) + @retryable( retries=RETRIES, interval=RETRY_INTERVAL, @@ -456,6 +487,78 @@ async def get_user_accounts(self): yield user_account +class MultiOffice365Users(BaseOffice365User): + """Fetch multiple Office365 users based on a list of email addresses.""" + + def __init__(self, client_id, client_secret, tenant_id, client_emails): + super().__init__(client_id, client_secret, tenant_id) + self.client_emails = client_emails + + async def get_users(self): + access_token = await self._fetch_token() + errors = [] + for i in range(0, len(self.client_emails), GRAPH_API_BATCH_SIZE): + batch_emails = self.client_emails[i : i + GRAPH_API_BATCH_SIZE] + requests = [ + {"id": str(index + 1), "method": "GET", "url": f"/users/{email}"} + for index, email in enumerate(batch_emails) + ] + batch_request_body = {"requests": requests} + try: + async with self._get_session.post( + url="https://graph.microsoft.com/v1.0/$batch", + headers={ + "Authorization": f"Bearer {access_token}", + "Content-Type": "application/json", + }, + json=batch_request_body, + ) as response: + json_response = await response.json() + for res in json_response.get("responses", []): + user_id = res.get("id") + status = res.get("status") + if status == 200: + yield res.get("body") + else: + msg = f"Error for user {user_id}: {res.get('body')}" + errors.append(OutlookUserFetchFailed(msg)) + except Exception as e: + msg = f"Batch request failed: {str(e)}" + errors.append(BatchRequestFailed(msg)) + + if errors: + msg = "Errors occurred while fetching users: " + "\n".join( + str(e) for e in errors + ) + raise Exception(msg) + + async def get_user_accounts(self): + async for user in self.get_users(): + mail = user.get("mail") + if mail is None: + continue + + credentials = OAuth2Credentials( + client_id=self.client_id, + tenant_id=self.tenant_id, + client_secret=self.client_secret, + identity=Identity(primary_smtp_address=mail), + ) + configuration = Configuration( + credentials=credentials, + auth_type=OAUTH2, + service_endpoint=EWS_ENDPOINT, + retry_policy=FaultTolerance(max_wait=120), + ) + user_account = Account( + primary_smtp_address=mail, + config=configuration, + autodiscover=False, + access_type=IMPERSONATION, + ) + yield user_account + + class OutlookDocFormatter: """Format Outlook object documents to Elasticsearch document""" @@ -583,6 +686,28 @@ def attachment_doc_formatter(self, attachment, attachment_type, timezone): } +class UserFactory: + """Factory class for creating Office365 user instances""" + + @staticmethod + def create_user(configuration: dict) -> BaseOffice365User: + client_emails = configuration.get("client_emails", WILDCARD) + if client_emails == WILDCARD or client_emails == [WILDCARD]: + return Office365Users( + client_id=configuration["client_id"], + client_secret=configuration["client_secret"], + tenant_id=configuration["tenant_id"], + ) + else: + client_emails = [email.strip() for email in client_emails] + return MultiOffice365Users( + client_id=configuration["client_id"], + client_secret=configuration["client_secret"], + tenant_id=configuration["tenant_id"], + client_emails=client_emails, + ) + + class OutlookClient: """Outlook client to handle API calls made to Outlook""" @@ -605,11 +730,7 @@ def set_logger(self, logger_): @cached_property def _get_user_instance(self): if self.is_cloud: - return Office365Users( - client_id=self.configuration["client_id"], - client_secret=self.configuration["client_secret"], - tenant_id=self.configuration["tenant_id"], - ) + return UserFactory.create_user(self.configuration) return ExchangeUsers( ad_server=self.configuration["active_directory_server"], @@ -627,7 +748,8 @@ async def _fetch_all_users(self): yield user async def ping(self): - await anext(self._get_user_instance.get_users()) + async for _user in self._get_user_instance.get_users(): + return async def get_mails(self, account): for mail_type in MAIL_TYPES: @@ -666,9 +788,12 @@ async def get_tasks(self, account): yield task async def get_contacts(self, account): - folder = account.root / "Top of Information Store" / "Contacts" - for contact in await asyncio.to_thread(folder.all().only, *CONTACT_FIELDS): - yield contact + try: + folder = account.root / "Top of Information Store" / "Contacts" + for contact in await asyncio.to_thread(folder.all().only, *CONTACT_FIELDS): + yield contact + except Exception: + raise class OutlookDataSource(BaseDataSource): @@ -735,37 +860,46 @@ def get_default_configuration(cls): "sensitive": True, "type": "str", }, + "client_emails": { + "depends_on": [{"field": "data_source", "value": OUTLOOK_CLOUD}], + "label": "Client Email Addresses (comma-separated)", + "order": 5, + "tooltip": "Specify the email addresses to limit data fetching to specific clients. If set to *, data will be fetched for all users.", + "required": False, + "type": "list", + "value": "*", + }, "exchange_server": { "depends_on": [{"field": "data_source", "value": OUTLOOK_SERVER}], "label": "Exchange Server", - "order": 5, + "order": 6, "tooltip": "Exchange server's IP address. E.g. 127.0.0.1", "type": "str", }, "active_directory_server": { "depends_on": [{"field": "data_source", "value": OUTLOOK_SERVER}], "label": "Active Directory Server", - "order": 6, + "order": 7, "tooltip": "Active Directory server's IP address. E.g. 127.0.0.1", "type": "str", }, "username": { "depends_on": [{"field": "data_source", "value": OUTLOOK_SERVER}], "label": "Exchange server username", - "order": 7, + "order": 8, "type": "str", }, "password": { "depends_on": [{"field": "data_source", "value": OUTLOOK_SERVER}], "label": "Exchange server password", - "order": 8, + "order": 9, "sensitive": True, "type": "str", }, "domain": { "depends_on": [{"field": "data_source", "value": OUTLOOK_SERVER}], "label": "Exchange server domain name", - "order": 9, + "order": 10, "tooltip": "Domain name such as gmail.com, outlook.com", "type": "str", }, @@ -773,7 +907,7 @@ def get_default_configuration(cls): "depends_on": [{"field": "data_source", "value": OUTLOOK_SERVER}], "display": "toggle", "label": "Enable SSL", - "order": 10, + "order": 11, "type": "bool", "value": False, }, @@ -783,13 +917,13 @@ def get_default_configuration(cls): {"field": "ssl_enabled", "value": True}, ], "label": "SSL certificate", - "order": 11, + "order": 12, "type": "str", }, "use_text_extraction_service": { "display": "toggle", "label": "Use text extraction service", - "order": 12, + "order": 13, "tooltip": "Requires a separate deployment of the Elastic Text Extraction Service. Requires that pipeline settings disable text extraction.", "type": "bool", "ui_restrictions": ["advanced"], @@ -798,7 +932,7 @@ def get_default_configuration(cls): "use_document_level_security": { "display": "toggle", "label": "Enable document level security", - "order": 13, + "order": 14, "tooltip": "Document level security ensures identities and permissions set in Outlook are maintained in Elasticsearch. This enables you to restrict and personalize read-access users and groups have to documents in this index. Access control syncs ensure this metadata is kept up to date in your Elasticsearch documents.", "type": "bool", "value": False, @@ -1072,9 +1206,11 @@ async def get_docs(self, filtering=None): dictionary: dictionary containing meta-data of the files. """ async for account in self.client._get_user_instance.get_user_accounts(): + self._logger.debug(f"Processing account: {account}") timezone = account.default_timezone or DEFAULT_TIMEZONE async for mail in self._fetch_mails(account=account, timezone=timezone): + self._logger.debug(f"Fetched mail: {mail}") yield mail async for contact in self._fetch_contacts( diff --git a/tests/sources/test_outlook.py b/tests/sources/test_outlook.py index 31cc507c2..0edc7d4ad 100644 --- a/tests/sources/test_outlook.py +++ b/tests/sources/test_outlook.py @@ -14,8 +14,10 @@ from connectors.source import ConfigurableFieldValueError from connectors.sources.outlook import ( + GRAPH_API_BATCH_SIZE, OUTLOOK_CLOUD, OUTLOOK_SERVER, + WILDCARD, Forbidden, NotFound, OutlookDataSource, @@ -374,6 +376,7 @@ async def create_outlook_source( tenant_id="foo", client_id="bar", client_secret="faa", + client_emails=None, exchange_server="127.0.0.1", active_directory_server="127.0.0.1", username="fee", @@ -383,12 +386,16 @@ async def create_outlook_source( ssl_ca="", use_text_extraction_service=False, ): + if client_emails is None: + client_emails = WILDCARD + async with create_source( OutlookDataSource, data_source=data_source, tenant_id=tenant_id, client_id=client_id, client_secret=client_secret, + client_emails=client_emails, exchange_server=exchange_server, active_directory_server=active_directory_server, username=username, @@ -415,26 +422,71 @@ def get_stream_reader(): return async_mock -def side_effect_function(url, headers): +def side_effect_function(client_emails=None): """Dynamically changing return values for API calls Args: - url, ssl: Params required for get call + client_emails: Optional string of comma-separated email addresses """ - if url == "https://graph.microsoft.com/v1.0/users?$top=999": - return get_json_mock( - mock_response={ - "@odata.nextLink": "https://graph.microsoft.com/v1.0/users?$top=999&$skipToken=fake-skip-token", - "value": [{"mail": "test.user@gmail.com"}], - }, - status=200, - ) - elif ( - url - == "https://graph.microsoft.com/v1.0/users?$top=999&$skipToken=fake-skip-token" - ): - return get_json_mock( - mock_response={"value": [{"mail": "dummy.user@gmail.com"}]}, status=200 - ) + email_counter = 0 + + def inner(url, headers=None, json=None, data=None): + nonlocal email_counter + + if "oauth2/v2.0/token" in url and data: + return get_json_mock( + mock_response={"access_token": "fake-token"}, + status=200, + ) + if url == "https://graph.microsoft.com/v1.0/$batch" and json: + batch_requests = json.get("requests", [])[:GRAPH_API_BATCH_SIZE] + + if client_emails: + email_list = client_emails.split(",") + + responses = [] + for request in batch_requests: + if email_counter < len(email_list): + responses.append( + { + "id": request.get("id"), + "status": 200, + "body": { + "value": [{"mail": email_list[email_counter]}] + }, + } + ) + email_counter += 1 + else: + break + else: + responses = [ + { + "id": request.get("id"), + "status": 200, + "body": {"value": [{"mail": f"user{email_counter}@test.com"}]}, + } + for request in batch_requests + ] + email_counter += len(batch_requests) + + return get_json_mock(mock_response={"responses": responses}, status=200) + elif url == "https://graph.microsoft.com/v1.0/users?$top=999": + return get_json_mock( + mock_response={ + "@odata.nextLink": "https://graph.microsoft.com/v1.0/users?$top=999&$skipToken=fake-skip-token", + "value": [{"mail": "test.user@gmail.com"}], + }, + status=200, + ) + elif ( + url + == "https://graph.microsoft.com/v1.0/users?$top=999&$skipToken=fake-skip-token" + ): + return get_json_mock( + mock_response={"value": [{"mail": "dummy.user@gmail.com"}]}, status=200 + ) + + return inner @pytest.mark.asyncio @@ -459,6 +511,7 @@ def side_effect_function(url, headers): "tenant_id": "foo", "client_id": "bar", "client_secret": "", + "client_emails": WILDCARD, } ), ], @@ -497,6 +550,17 @@ async def test_validate_configuration_with_invalid_dependency_fields_raises_erro "tenant_id": "foo", "client_id": "bar", "client_secret": "foo.bar", + "client_emails": WILDCARD, + } + ), + ( + # Outlook Cloud with non-blank dependent fields & client_emails provided + { + "data_source": OUTLOOK_CLOUD, + "tenant_id": "foo", + "client_id": "bar", + "client_secret": "foo.bar", + "client_emails": "test.user@gmail.com", } ), ], @@ -552,7 +616,7 @@ async def test_ping_for_cloud(): ): with mock.patch( "aiohttp.ClientSession.get", - side_effect=side_effect_function, + side_effect=side_effect_function(), ): await source.ping() @@ -597,7 +661,7 @@ async def test_get_users_for_cloud(): ): with mock.patch( "aiohttp.ClientSession.get", - side_effect=side_effect_function, + side_effect=side_effect_function(), ): async for response in source.client._get_user_instance.get_users(): user_mails = [user["mail"] for user in response["value"]] @@ -605,6 +669,22 @@ async def test_get_users_for_cloud(): assert users == ["test.user@gmail.com", "dummy.user@gmail.com"] +@pytest.mark.asyncio +async def test_get_users_for_cloud_with_client_emails(): + client_emails = ",".join([f"test.user{i}@gmail.com" for i in range(25)]) + async with create_outlook_source(client_emails=client_emails) as source: + users = [] + with mock.patch( + "aiohttp.ClientSession.post", + side_effect=side_effect_function(client_emails), + ): + async for response in source.client._get_user_instance.get_users(): + user_mails = [user["mail"] for user in response["value"]] + users.extend(user_mails) + assert users == list(client_emails.split(",")) + assert len(users) == 25 + + @pytest.mark.asyncio @patch("connectors.sources.outlook.Connection") async def test_fetch_admin_users_negative(mock_connection):