From 0ba04aebf0d3201be8afb67d0abc305636770806 Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Sun, 27 Oct 2024 08:31:15 -0700 Subject: [PATCH 01/23] refactoring changes --- backend/danswer/configs/app_configs.py | 3 -- .../connectors/confluence/connector.py | 7 ++- .../connectors/google_drive/connector_auth.py | 50 ++++++++++++++++--- .../connectors/google_drive/constants.py | 1 - backend/danswer/connectors/interfaces.py | 6 ++- .../connectors/salesforce/connector.py | 6 ++- backend/danswer/connectors/slack/connector.py | 6 ++- backend/danswer/db/credentials.py | 6 +-- 8 files changed, 65 insertions(+), 20 deletions(-) diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py index caf7a103b94..25b35838982 100644 --- a/backend/danswer/configs/app_configs.py +++ b/backend/danswer/configs/app_configs.py @@ -249,9 +249,6 @@ # for some connectors ENABLE_EXPENSIVE_EXPERT_CALLS = False -GOOGLE_DRIVE_INCLUDE_SHARED = False -GOOGLE_DRIVE_FOLLOW_SHORTCUTS = False -GOOGLE_DRIVE_ONLY_ORG_PUBLIC = False # TODO these should be available for frontend configuration, via advanced options expandable WEB_CONNECTOR_IGNORED_CLASSES = os.environ.get( diff --git a/backend/danswer/connectors/confluence/connector.py b/backend/danswer/connectors/confluence/connector.py index fe52862982d..d3cbe25c3d6 100644 --- a/backend/danswer/connectors/confluence/connector.py +++ b/backend/danswer/connectors/confluence/connector.py @@ -17,6 +17,7 @@ from danswer.connectors.interfaces import GenerateSlimDocumentOutput from danswer.connectors.interfaces import LoadConnector from danswer.connectors.interfaces import PollConnector +from danswer.connectors.interfaces import SecondsSinceUnixEpoch from danswer.connectors.interfaces import SlimConnector from danswer.connectors.models import BasicExpertInfo from danswer.connectors.models import ConnectorMissingCredentialError @@ -247,7 +248,11 @@ def poll_source(self, start: float, end: float) -> GenerateDocumentsOutput: self.cql_time_filter += f" and lastmodified <= '{formatted_end_time}'" return self._fetch_document_batches() - def retrieve_all_slim_documents(self) -> GenerateSlimDocumentOutput: + def retrieve_all_slim_documents( + self, + start: SecondsSinceUnixEpoch | None = None, + end: SecondsSinceUnixEpoch | None = None, + ) -> GenerateSlimDocumentOutput: if self.confluence_client is None: raise ConnectorMissingCredentialError("Confluence") diff --git a/backend/danswer/connectors/google_drive/connector_auth.py b/backend/danswer/connectors/google_drive/connector_auth.py index 777deae990a..464e59cd798 100644 --- a/backend/danswer/connectors/google_drive/connector_auth.py +++ b/backend/danswer/connectors/google_drive/connector_auth.py @@ -20,9 +20,6 @@ from danswer.connectors.google_drive.constants import ( DB_CREDENTIALS_DICT_DELEGATED_USER_KEY, ) -from danswer.connectors.google_drive.constants import ( - DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY, -) from danswer.connectors.google_drive.constants import DB_CREDENTIALS_DICT_TOKEN_KEY from danswer.connectors.google_drive.constants import FETCH_GROUPS_SCOPES from danswer.connectors.google_drive.constants import FETCH_PERMISSIONS_SCOPES @@ -84,6 +81,45 @@ def _get_google_drive_creds_for_service_account( return creds if creds.valid else None +def get_service_account_credentials( + credentials: dict[str, str], + scopes: list[str] = build_gdrive_scopes(), +) -> ServiceAccountCredentials: + service_account_key_json_str = credentials[KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY] + service_creds = 
_get_google_drive_creds_for_service_account( + service_account_key_json_str=service_account_key_json_str, + scopes=scopes, + ) + + # "Impersonate" a user if one is specified + delegated_user_email = cast( + str | None, credentials.get(DB_CREDENTIALS_DICT_DELEGATED_USER_KEY) + ) + if delegated_user_email: + service_creds = ( + service_creds.with_subject(delegated_user_email) if service_creds else None + ) + return service_creds + + +def get_oauth_credentials( + credentials: dict[str, str], + scopes: list[str] = build_gdrive_scopes(), +) -> tuple[OAuthCredentials | None, dict[str, str] | None]: + new_creds_dict = None + access_token_json_str = cast(str, credentials[DB_CREDENTIALS_DICT_TOKEN_KEY]) + oauth_creds = get_google_drive_creds_for_authorized_user( + token_json_str=access_token_json_str, scopes=scopes + ) + + # tell caller to update token stored in DB if it has changed + # (e.g. the token has been refreshed) + new_creds_json_str = oauth_creds.to_json() if oauth_creds else "" + if new_creds_json_str != access_token_json_str: + new_creds_dict = {DB_CREDENTIALS_DICT_TOKEN_KEY: new_creds_json_str} + return oauth_creds, new_creds_dict + + def get_google_drive_creds( credentials: dict[str, str], scopes: list[str] = build_gdrive_scopes() ) -> tuple[ServiceAccountCredentials | OAuthCredentials, dict[str, str] | None]: @@ -102,10 +138,8 @@ def get_google_drive_creds( if new_creds_json_str != access_token_json_str: new_creds_dict = {DB_CREDENTIALS_DICT_TOKEN_KEY: new_creds_json_str} - elif DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY in credentials: - service_account_key_json_str = credentials[ - DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY - ] + elif KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY in credentials: + service_account_key_json_str = credentials[KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY] service_creds = _get_google_drive_creds_for_service_account( service_account_key_json_str=service_account_key_json_str, scopes=scopes, @@ -189,7 +223,7 @@ def build_service_account_creds( service_account_key = get_service_account_key() credential_dict = { - DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY: service_account_key.json(), + KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY: service_account_key.json(), } if delegated_user_email: credential_dict[DB_CREDENTIALS_DICT_DELEGATED_USER_KEY] = delegated_user_email diff --git a/backend/danswer/connectors/google_drive/constants.py b/backend/danswer/connectors/google_drive/constants.py index 0cca65c13df..563f2c63b47 100644 --- a/backend/danswer/connectors/google_drive/constants.py +++ b/backend/danswer/connectors/google_drive/constants.py @@ -1,5 +1,4 @@ DB_CREDENTIALS_DICT_TOKEN_KEY = "google_drive_tokens" -DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_drive_service_account_key" DB_CREDENTIALS_DICT_DELEGATED_USER_KEY = "google_drive_delegated_user" BASE_SCOPES = ["https://www.googleapis.com/auth/drive.readonly"] diff --git a/backend/danswer/connectors/interfaces.py b/backend/danswer/connectors/interfaces.py index 4734212147e..c53b3de5f2f 100644 --- a/backend/danswer/connectors/interfaces.py +++ b/backend/danswer/connectors/interfaces.py @@ -56,7 +56,11 @@ def poll_source( class SlimConnector(BaseConnector): @abc.abstractmethod - def retrieve_all_slim_documents(self) -> GenerateSlimDocumentOutput: + def retrieve_all_slim_documents( + self, + start: SecondsSinceUnixEpoch | None = None, + end: SecondsSinceUnixEpoch | None = None, + ) -> GenerateSlimDocumentOutput: raise NotImplementedError diff --git a/backend/danswer/connectors/salesforce/connector.py 
b/backend/danswer/connectors/salesforce/connector.py index 78d73d44766..1e0fe9e1d3a 100644 --- a/backend/danswer/connectors/salesforce/connector.py +++ b/backend/danswer/connectors/salesforce/connector.py @@ -251,7 +251,11 @@ def poll_source( end_datetime = datetime.utcfromtimestamp(end) return self._fetch_from_salesforce(start=start_datetime, end=end_datetime) - def retrieve_all_slim_documents(self) -> GenerateSlimDocumentOutput: + def retrieve_all_slim_documents( + self, + start: SecondsSinceUnixEpoch | None = None, + end: SecondsSinceUnixEpoch | None = None, + ) -> GenerateSlimDocumentOutput: if self.sf_client is None: raise ConnectorMissingCredentialError("Salesforce") doc_metadata_list: list[SlimDocument] = [] diff --git a/backend/danswer/connectors/slack/connector.py b/backend/danswer/connectors/slack/connector.py index f5728950e4f..ff92f361f4e 100644 --- a/backend/danswer/connectors/slack/connector.py +++ b/backend/danswer/connectors/slack/connector.py @@ -391,7 +391,11 @@ def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None self.client = WebClient(token=bot_token) return None - def retrieve_all_slim_documents(self) -> GenerateSlimDocumentOutput: + def retrieve_all_slim_documents( + self, + start: SecondsSinceUnixEpoch | None = None, + end: SecondsSinceUnixEpoch | None = None, + ) -> GenerateSlimDocumentOutput: if self.client is None: raise ConnectorMissingCredentialError("Slack") diff --git a/backend/danswer/db/credentials.py b/backend/danswer/db/credentials.py index 5da5099f1e3..71dd569ac66 100644 --- a/backend/danswer/db/credentials.py +++ b/backend/danswer/db/credentials.py @@ -10,12 +10,10 @@ from danswer.auth.schemas import UserRole from danswer.configs.constants import DocumentSource +from danswer.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY from danswer.connectors.gmail.constants import ( GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY, ) -from danswer.connectors.google_drive.constants import ( - DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY, -) from danswer.db.models import ConnectorCredentialPair from danswer.db.models import Credential from danswer.db.models import Credential__UserGroup @@ -441,7 +439,7 @@ def delete_google_drive_service_account_credentials( ) -> None: credentials = fetch_credentials(db_session=db_session, user=user) for credential in credentials: - if credential.credential_json.get(DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY): + if credential.credential_json.get(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY): db_session.delete(credential) db_session.commit() From 121ea2e83ab03be085f4fa5e97604f04a31c87b7 Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Sun, 27 Oct 2024 12:36:51 -0700 Subject: [PATCH 02/23] everything working for service account --- .../connectors/google_drive/connector.py | 738 ++++++++---------- .../google_drive/doc_sync.py | 190 +++-- 2 files changed, 420 insertions(+), 508 deletions(-) diff --git a/backend/danswer/connectors/google_drive/connector.py b/backend/danswer/connectors/google_drive/connector.py index 9f8c6fbfda8..7d2697049fe 100644 --- a/backend/danswer/connectors/google_drive/connector.py +++ b/backend/danswer/connectors/google_drive/connector.py @@ -1,44 +1,41 @@ import io +from collections.abc import Callable from collections.abc import Iterator -from collections.abc import Sequence from datetime import datetime from datetime import timezone from enum import Enum -from itertools import chain from typing import Any from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore 
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore -from googleapiclient import discovery # type: ignore +from googleapiclient.discovery import build # type: ignore +from googleapiclient.discovery import Resource # type: ignore from googleapiclient.errors import HttpError # type: ignore from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE -from danswer.configs.app_configs import GOOGLE_DRIVE_FOLLOW_SHORTCUTS -from danswer.configs.app_configs import GOOGLE_DRIVE_INCLUDE_SHARED -from danswer.configs.app_configs import GOOGLE_DRIVE_ONLY_ORG_PUBLIC from danswer.configs.app_configs import INDEX_BATCH_SIZE from danswer.configs.constants import DocumentSource from danswer.configs.constants import IGNORE_FOR_QA +from danswer.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder from danswer.connectors.google_drive.connector_auth import get_google_drive_creds from danswer.connectors.google_drive.constants import ( DB_CREDENTIALS_DICT_DELEGATED_USER_KEY, ) -from danswer.connectors.google_drive.constants import ( - DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY, -) from danswer.connectors.interfaces import GenerateDocumentsOutput +from danswer.connectors.interfaces import GenerateSlimDocumentOutput from danswer.connectors.interfaces import LoadConnector from danswer.connectors.interfaces import PollConnector from danswer.connectors.interfaces import SecondsSinceUnixEpoch +from danswer.connectors.interfaces import SlimConnector from danswer.connectors.models import Document from danswer.connectors.models import Section +from danswer.connectors.models import SlimDocument from danswer.file_processing.extract_file_text import docx_to_text from danswer.file_processing.extract_file_text import pptx_to_text from danswer.file_processing.extract_file_text import read_pdf_file from danswer.file_processing.unstructured import get_unstructured_api_key from danswer.file_processing.unstructured import unstructured_to_text -from danswer.utils.batching import batch_generator from danswer.utils.logger import setup_logger logger = setup_logger() @@ -47,6 +44,23 @@ DRIVE_SHORTCUT_TYPE = "application/vnd.google-apps.shortcut" UNSUPPORTED_FILE_TYPE_CONTENT = "" # keep empty for now +FILE_FIELDS = "nextPageToken, files(mimeType, id, name, permissions, modifiedTime, webViewLink, shortcutDetails)" +SLIM_FILE_FIELDS = ( + "nextPageToken, files(permissions(emailAddress, type), webViewLink), permissionIds" +) +FOLDER_FIELDS = "nextPageToken, files(id, name, permissions, modifiedTime, webViewLink, shortcutDetails)" + +# these errors don't represent a failure in the connector, but simply files +# that can't / shouldn't be indexed +ERRORS_TO_CONTINUE_ON = [ + "cannotExportFile", + "exportSizeLimitExceeded", + "cannotDownloadFile", +] +_SLIM_BATCH_SIZE = 500 + +_TRAVERSED_PARENT_IDS: set[str] = set() + class GDriveMimeType(str, Enum): DOC = "application/vnd.google-apps.document" @@ -69,239 +83,31 @@ class GDriveMimeType(str, Enum): add_retries = retry_builder(tries=50, max_delay=30) -def _run_drive_file_query( - service: discovery.Resource, - query: str, - continue_on_failure: bool, - include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED, - follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS, - batch_size: int = INDEX_BATCH_SIZE, +def _execute_paginated_retrieval( + retrieval_function: Callable[..., Any], + list_key: str, + **kwargs: Any, ) -> Iterator[GoogleDriveFileType]: + """Execute a 
paginated retrieval from Google Drive API + Args: + retrieval_function: The specific list function to call (e.g., service.files().list) + **kwargs: Arguments to pass to the list function + """ + print("\n -------------------------------") next_page_token = "" while next_page_token is not None: - logger.debug(f"Running Google Drive fetch with query: {query}") - results = add_retries( - lambda: ( - service.files() - .list( - corpora="allDrives" - if include_shared - else "user", # needed to search through shared drives - pageSize=batch_size, - supportsAllDrives=include_shared, - includeItemsFromAllDrives=include_shared, - fields=( - "nextPageToken, files(mimeType, id, name, permissions, " - "modifiedTime, webViewLink, shortcutDetails)" - ), - pageToken=next_page_token, - q=query, - ) - .execute() - ) - )() - next_page_token = results.get("nextPageToken") - files = results["files"] - for file in files: - if follow_shortcuts and "shortcutDetails" in file: - try: - file_shortcut_points_to = add_retries( - lambda: ( - service.files() - .get( - fileId=file["shortcutDetails"]["targetId"], - supportsAllDrives=include_shared, - fields="mimeType, id, name, modifiedTime, webViewLink, permissions, shortcutDetails", - ) - .execute() - ) - )() - yield file_shortcut_points_to - except HttpError: - logger.error( - f"Failed to follow shortcut with details: {file['shortcutDetails']}" - ) - if continue_on_failure: - continue - raise - else: - yield file - - -def _get_folder_id( - service: discovery.Resource, - parent_id: str, - folder_name: str, - include_shared: bool, - follow_shortcuts: bool, -) -> str | None: - """ - Get the ID of a folder given its name and the ID of its parent folder. - """ - query = f"'{parent_id}' in parents and name='{folder_name}' and " - if follow_shortcuts: - query += f"(mimeType='{DRIVE_FOLDER_TYPE}' or mimeType='{DRIVE_SHORTCUT_TYPE}')" - else: - query += f"mimeType='{DRIVE_FOLDER_TYPE}'" - - # TODO: support specifying folder path in shared drive rather than just `My Drive` - results = add_retries( - lambda: ( - service.files() - .list( - q=query, - spaces="drive", - fields="nextPageToken, files(id, name, shortcutDetails)", - supportsAllDrives=include_shared, - includeItemsFromAllDrives=include_shared, - ) - .execute() - ) - )() - items = results.get("files", []) - - folder_id = None - if items: - if follow_shortcuts and "shortcutDetails" in items[0]: - folder_id = items[0]["shortcutDetails"]["targetId"] - else: - folder_id = items[0]["id"] - return folder_id - - -def _get_folders( - service: discovery.Resource, - continue_on_failure: bool, - folder_id: str | None = None, # if specified, only fetches files within this folder - include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED, - follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS, - batch_size: int = INDEX_BATCH_SIZE, -) -> Iterator[GoogleDriveFileType]: - query = f"mimeType = '{DRIVE_FOLDER_TYPE}' " - if follow_shortcuts: - query = "(" + query + f" or mimeType = '{DRIVE_SHORTCUT_TYPE}'" + ") " - - if folder_id: - query += f"and '{folder_id}' in parents " - query = query.rstrip() # remove the trailing space(s) - - for file in _run_drive_file_query( - service=service, - query=query, - continue_on_failure=continue_on_failure, - include_shared=include_shared, - follow_shortcuts=follow_shortcuts, - batch_size=batch_size, - ): - # Need to check this since file may have been a target of a shortcut - # and not necessarily a folder - if file["mimeType"] == DRIVE_FOLDER_TYPE: - yield file - else: - pass - - -def _get_files( - service: 
discovery.Resource, - continue_on_failure: bool, - time_range_start: SecondsSinceUnixEpoch | None = None, - time_range_end: SecondsSinceUnixEpoch | None = None, - folder_id: str | None = None, # if specified, only fetches files within this folder - include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED, - follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS, - batch_size: int = INDEX_BATCH_SIZE, -) -> Iterator[GoogleDriveFileType]: - query = f"mimeType != '{DRIVE_FOLDER_TYPE}' " - if time_range_start is not None: - time_start = datetime.utcfromtimestamp(time_range_start).isoformat() + "Z" - query += f"and modifiedTime >= '{time_start}' " - if time_range_end is not None: - time_stop = datetime.utcfromtimestamp(time_range_end).isoformat() + "Z" - query += f"and modifiedTime <= '{time_stop}' " - if folder_id: - query += f"and '{folder_id}' in parents " - query = query.rstrip() # remove the trailing space(s) - - files = _run_drive_file_query( - service=service, - query=query, - continue_on_failure=continue_on_failure, - include_shared=include_shared, - follow_shortcuts=follow_shortcuts, - batch_size=batch_size, - ) + request_kwargs = kwargs.copy() + if next_page_token: + request_kwargs["pageToken"] = next_page_token - return files - - -def get_all_files_batched( - service: discovery.Resource, - continue_on_failure: bool, - include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED, - follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS, - batch_size: int = INDEX_BATCH_SIZE, - time_range_start: SecondsSinceUnixEpoch | None = None, - time_range_end: SecondsSinceUnixEpoch | None = None, - folder_id: str | None = None, # if specified, only fetches files within this folder - # if True, will fetch files in sub-folders of the specified folder ID. - # Only applies if folder_id is specified. - traverse_subfolders: bool = True, - folder_ids_traversed: list[str] | None = None, -) -> Iterator[list[GoogleDriveFileType]]: - """Gets all files matching the criteria specified by the args from Google Drive - in batches of size `batch_size`. 
- """ - found_files = _get_files( - service=service, - continue_on_failure=continue_on_failure, - time_range_start=time_range_start, - time_range_end=time_range_end, - folder_id=folder_id, - include_shared=include_shared, - follow_shortcuts=follow_shortcuts, - batch_size=batch_size, - ) - yield from batch_generator( - items=found_files, - batch_size=batch_size, - pre_batch_yield=lambda batch_files: logger.debug( - f"Parseable Documents in batch: {[file['name'] for file in batch_files]}" - ), - ) + results = add_retries(lambda: retrieval_function(**request_kwargs).execute())() - if traverse_subfolders and folder_id is not None: - folder_ids_traversed = folder_ids_traversed or [] - subfolders = _get_folders( - service=service, - folder_id=folder_id, - continue_on_failure=continue_on_failure, - include_shared=include_shared, - follow_shortcuts=follow_shortcuts, - batch_size=batch_size, - ) - for subfolder in subfolders: - if subfolder["id"] not in folder_ids_traversed: - logger.info("Fetching all files in subfolder: " + subfolder["name"]) - folder_ids_traversed.append(subfolder["id"]) - yield from get_all_files_batched( - service=service, - continue_on_failure=continue_on_failure, - include_shared=include_shared, - follow_shortcuts=follow_shortcuts, - batch_size=batch_size, - time_range_start=time_range_start, - time_range_end=time_range_end, - folder_id=subfolder["id"], - traverse_subfolders=traverse_subfolders, - folder_ids_traversed=folder_ids_traversed, - ) - else: - logger.debug( - "Skipping subfolder since already traversed: " + subfolder["name"] - ) + next_page_token = results.get("nextPageToken") + for item in results.get(list_key, []): + yield item -def extract_text(file: dict[str, str], service: discovery.Resource) -> str: +def extract_text(file: dict[str, str], service: Resource) -> str: mime_type = file["mimeType"] if mime_type not in set(item.value for item in GDriveMimeType): @@ -351,57 +157,71 @@ def extract_text(file: dict[str, str], service: discovery.Resource) -> str: return UNSUPPORTED_FILE_TYPE_CONTENT -class GoogleDriveConnector(LoadConnector, PollConnector): +def _convert_drive_item_to_document( + file: GoogleDriveFileType, service: Resource +) -> Document | None: + try: + # Skip files that are shortcuts + if file.get("mimeType") == DRIVE_SHORTCUT_TYPE: + logger.info("Ignoring Drive Shortcut Filetype") + return None + try: + text_contents = extract_text(file, service) or "" + except HttpError as e: + reason = e.error_details[0]["reason"] if e.error_details else e.reason + message = e.error_details[0]["message"] if e.error_details else e.reason + if e.status_code == 403 and reason in ERRORS_TO_CONTINUE_ON: + logger.warning( + f"Could not export file '{file['name']}' due to '{message}', skipping..." 
+ ) + return None + + raise + + return Document( + id=file["webViewLink"], + sections=[Section(link=file["webViewLink"], text=text_contents)], + source=DocumentSource.GOOGLE_DRIVE, + semantic_identifier=file["name"], + doc_updated_at=datetime.fromisoformat(file["modifiedTime"]).astimezone( + timezone.utc + ), + metadata={} if text_contents else {IGNORE_FOR_QA: "True"}, + additional_info=file.get("id"), + ) + except Exception as e: + if not CONTINUE_ON_CONNECTOR_FAILURE: + raise e + + logger.exception("Ran into exception when pulling a file from Google Drive") + return None + + +def _extract_parent_ids_from_urls(urls: list[str]) -> list[str]: + return [url.split("/")[-1] for url in urls] + + +class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector): def __init__( self, - # optional list of folder paths e.g. "[My Folder/My Subfolder]" - # if specified, will only index files in these folders - folder_paths: list[str] | None = None, + parent_urls: list[str] | None = None, + include_personal: bool = True, batch_size: int = INDEX_BATCH_SIZE, - include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED, - follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS, - only_org_public: bool = GOOGLE_DRIVE_ONLY_ORG_PUBLIC, - continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE, ) -> None: - self.folder_paths = folder_paths or [] self.batch_size = batch_size - self.include_shared = include_shared - self.follow_shortcuts = follow_shortcuts - self.only_org_public = only_org_public - self.continue_on_failure = continue_on_failure - self.creds: OAuthCredentials | ServiceAccountCredentials | None = None - - @staticmethod - def _process_folder_paths( - service: discovery.Resource, - folder_paths: list[str], - include_shared: bool, - follow_shortcuts: bool, - ) -> list[str]: - """['Folder/Sub Folder'] -> ['']""" - folder_ids: list[str] = [] - for path in folder_paths: - folder_names = path.split("/") - parent_id = "root" - for folder_name in folder_names: - found_parent_id = _get_folder_id( - service=service, - parent_id=parent_id, - folder_name=folder_name, - include_shared=include_shared, - follow_shortcuts=follow_shortcuts, - ) - if found_parent_id is None: - raise ValueError( - ( - f"Folder '{folder_name}' in path '{path}' " - "not found in Google Drive" - ) - ) - parent_id = found_parent_id - folder_ids.append(parent_id) - return folder_ids + self.parent_ids = ( + _extract_parent_ids_from_urls(parent_urls) if parent_urls else [] + ) + self.include_personal = include_personal + + self.service_account_email: str | None = None + self.service_account_domain: str | None = None + self.service_account_creds: ServiceAccountCredentials | None = None + + self.oauth_creds: OAuthCredentials | None = None + + self.is_slim: bool = False def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None: """Checks for two different types of credentials. @@ -410,147 +230,253 @@ def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None (2) A credential which holds a service account key JSON file, which can then be used to impersonate any user in the workspace. 
""" + creds, new_creds_dict = get_google_drive_creds(credentials) - self.creds = creds + if KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY in credentials: + self.service_account_creds = creds + self.service_account_email = credentials[ + DB_CREDENTIALS_DICT_DELEGATED_USER_KEY + ] + if self.service_account_email: + self.service_account_domain = self.service_account_email.split("@")[1] + else: + self.oauth_creds = creds return new_creds_dict - def _fetch_docs_from_drive( + def _get_folders_in_parent( self, - start: SecondsSinceUnixEpoch | None = None, - end: SecondsSinceUnixEpoch | None = None, - ) -> GenerateDocumentsOutput: - if self.creds is None: - raise PermissionError("Not logged into Google Drive") + service: Resource, + parent_id: str | None = None, + personal_drive: bool = False, + ) -> Iterator[GoogleDriveFileType]: + # Follow shortcuts to folders + query = ( + f"(mimeType = '{DRIVE_FOLDER_TYPE}' or mimeType = '{DRIVE_SHORTCUT_TYPE}')" + ) - service = discovery.build("drive", "v3", credentials=self.creds) - folder_ids: Sequence[str | None] = self._process_folder_paths( - service, self.folder_paths, self.include_shared, self.follow_shortcuts + if parent_id: + query += f" and '{parent_id}' in parents" + + for file in _execute_paginated_retrieval( + retrieval_function=service.files().list, + list_key="files", + corpora="user" if personal_drive else "domain", + supportsAllDrives=personal_drive, + includeItemsFromAllDrives=personal_drive, + fields=FOLDER_FIELDS, + q=query, + ): + yield file + + def _get_files_in_parent( + self, + service: Resource, + parent_id: str, + personal_drive: bool, + time_range_start: SecondsSinceUnixEpoch | None = None, + time_range_end: SecondsSinceUnixEpoch | None = None, + ) -> Iterator[GoogleDriveFileType]: + query = f"mimeType != '{DRIVE_FOLDER_TYPE}' and '{parent_id}' in parents" + if time_range_start is not None: + time_start = datetime.utcfromtimestamp(time_range_start).isoformat() + "Z" + query += f" and modifiedTime >= '{time_start}'" + if time_range_end is not None: + time_stop = datetime.utcfromtimestamp(time_range_end).isoformat() + "Z" + query += f" and modifiedTime <= '{time_stop}'" + + for file in _execute_paginated_retrieval( + retrieval_function=service.files().list, + list_key="files", + corpora="user" if personal_drive else "domain", + supportsAllDrives=True, + includeItemsFromAllDrives=True, + fields=SLIM_FILE_FIELDS if self.is_slim else FILE_FIELDS, + q=query, + ): + yield file + + def _crawl_drive_for_files( + self, + service: Resource, + parent_id: str, + personal_drive: bool, + time_range_start: SecondsSinceUnixEpoch | None = None, + time_range_end: SecondsSinceUnixEpoch | None = None, + ) -> Iterator[GoogleDriveFileType]: + """Gets all files matching the criteria specified by the args from Google Drive + in batches of size `batch_size`. 
+ """ + if parent_id in _TRAVERSED_PARENT_IDS: + logger.debug(f"Skipping subfolder since already traversed: {parent_id}") + return + + _TRAVERSED_PARENT_IDS.add(parent_id) + + yield from self._get_files_in_parent( + service=service, + personal_drive=personal_drive, + time_range_start=time_range_start, + time_range_end=time_range_end, + parent_id=parent_id, ) - if not folder_ids: - folder_ids = [None] - - file_batches = chain( - *[ - get_all_files_batched( - service=service, - continue_on_failure=self.continue_on_failure, - include_shared=self.include_shared, - follow_shortcuts=self.follow_shortcuts, - batch_size=self.batch_size, - time_range_start=start, - time_range_end=end, - folder_id=folder_id, - traverse_subfolders=True, - ) - for folder_id in folder_ids - ] + + for subfolder in self._get_folders_in_parent( + service=service, + parent_id=parent_id, + personal_drive=personal_drive, + ): + logger.info("Fetching all files in subfolder: " + subfolder["name"]) + yield from self._crawl_drive_for_files( + service=service, + parent_id=subfolder["id"], + personal_drive=personal_drive, + time_range_start=time_range_start, + time_range_end=time_range_end, + ) + + def _get_all_user_emails(self) -> list[str]: + if not self.service_account_creds: + raise PermissionError("No service account credentials found") + + admin_creds = self.service_account_creds.with_subject( + self.service_account_email ) - for files_batch in file_batches: - doc_batch = [] - for file in files_batch: - try: - # Skip files that are shortcuts - if file.get("mimeType") == DRIVE_SHORTCUT_TYPE: - logger.info("Ignoring Drive Shortcut Filetype") - continue - - if self.only_org_public: - if "permissions" not in file: - continue - if not any( - permission["type"] == "domain" - for permission in file["permissions"] - ): - continue - try: - text_contents = extract_text(file, service) or "" - except HttpError as e: - reason = ( - e.error_details[0]["reason"] - if e.error_details - else e.reason - ) - message = ( - e.error_details[0]["message"] - if e.error_details - else e.reason - ) - - # these errors don't represent a failure in the connector, but simply files - # that can't / shouldn't be indexed - ERRORS_TO_CONTINUE_ON = [ - "cannotExportFile", - "exportSizeLimitExceeded", - "cannotDownloadFile", - ] - if e.status_code == 403 and reason in ERRORS_TO_CONTINUE_ON: - logger.warning( - f"Could not export file '{file['name']}' due to '{message}', skipping..." 
- ) - continue - - raise - - doc_batch.append( - Document( - id=file["webViewLink"], - sections=[ - Section(link=file["webViewLink"], text=text_contents) - ], - source=DocumentSource.GOOGLE_DRIVE, - semantic_identifier=file["name"], - doc_updated_at=datetime.fromisoformat( - file["modifiedTime"] - ).astimezone(timezone.utc), - metadata={} if text_contents else {IGNORE_FOR_QA: "True"}, - additional_info=file.get("id"), - ) + admin_service = build("admin", "directory_v1", credentials=admin_creds) + emails = [] + for user in _execute_paginated_retrieval( + retrieval_function=admin_service.users().list, + list_key="users", + domain=self.service_account_domain, + ): + if email := user.get("primaryEmail"): + emails.append(email) + return emails + + def _fetch_drive_items( + self, + start: SecondsSinceUnixEpoch | None = None, + end: SecondsSinceUnixEpoch | None = None, + ) -> Iterator[GoogleDriveFileType]: + # admin_creds = self.service_account_creds.with_subject(self.service_account_email) + admin_creds = self.get_primary_user_credentials() + admin_drive_service = build("drive", "v3", credentials=admin_creds) + + parent_ids = self.parent_ids + if not parent_ids: + # if no parent ids are specified, get all shared drives using the admin account + for drive in _execute_paginated_retrieval( + retrieval_function=admin_drive_service.drives().list, + list_key="drives", + useDomainAdminAccess=True, + fields="drives(id)", + ): + parent_ids.append(drive["id"]) + + # crawl all the shared parent ids for files + for parent_id in parent_ids: + yield from self._crawl_drive_for_files( + service=admin_drive_service, + parent_id=parent_id, + personal_drive=False, + time_range_start=start, + time_range_end=end, + ) + + # get all personal docs from each users' personal drive + if self.include_personal: + if self.service_account_creds: + all_user_emails = self._get_all_user_emails() + for email in all_user_emails: + user_creds = self.service_account_creds.with_subject(email) + user_drive_service = build("drive", "v3", credentials=user_creds) + # we dont paginate here because there is only one root folder per user + # https://developers.google.com/drive/api/guides/v2-to-v3-reference + id = ( + user_drive_service.files() + .get(fileId="root", fields="id") + .execute()["id"] ) - except Exception as e: - if not self.continue_on_failure: - raise e - logger.exception( - "Ran into exception when pulling a file from Google Drive" + yield from self._crawl_drive_for_files( + service=user_drive_service, + parent_id=id, + personal_drive=True, + time_range_start=start, + time_range_end=end, ) - yield doc_batch + def get_primary_user_credentials( + self, + ) -> OAuthCredentials | ServiceAccountCredentials: + if self.service_account_creds: + creds = self.service_account_creds.with_subject(self.service_account_email) + service = build("drive", "v3", credentials=creds) + else: + service = build("drive", "v3", credentials=self.oauth_creds) + + return service + + def _fetch_docs_from_drive( + self, + start: SecondsSinceUnixEpoch | None = None, + end: SecondsSinceUnixEpoch | None = None, + ) -> GenerateDocumentsOutput: + if self.oauth_creds is None and self.service_account_creds is None: + raise PermissionError("No credentials found") + + service = self.get_primary_user_credentials() + + doc_batch = [] + for file in self._fetch_drive_items( + start=start, + end=end, + ): + if doc := _convert_drive_item_to_document(file, service): + doc_batch.append(doc) + if len(doc_batch) >= self.batch_size: + yield doc_batch + doc_batch = [] + + 
yield doc_batch + + def _fetch_slim_docs_from_drive( + self, + start: SecondsSinceUnixEpoch | None = None, + end: SecondsSinceUnixEpoch | None = None, + ) -> GenerateSlimDocumentOutput: + slim_batch = [] + for file in self._fetch_drive_items( + start=start, + end=end, + ): + slim_batch.append( + SlimDocument( + id=file["webViewLink"], + perm_sync_data={ + "permissions": file.get("permissions", []), + "permission_ids": [ + perm["id"] for perm in file.get("permissionIds", []) + ], + }, + ) + ) + if len(slim_batch) >= _SLIM_BATCH_SIZE: + yield slim_batch + slim_batch = [] + yield slim_batch def load_from_state(self) -> GenerateDocumentsOutput: yield from self._fetch_docs_from_drive() + def retrieve_all_slim_documents( + self, + start: SecondsSinceUnixEpoch | None = None, + end: SecondsSinceUnixEpoch | None = None, + ) -> GenerateSlimDocumentOutput: + self.is_slim = True + return self._fetch_slim_docs_from_drive(start, end) + def poll_source( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch ) -> GenerateDocumentsOutput: - # need to subtract 10 minutes from start time to account for modifiedTime - # propogation if a document is modified, it takes some time for the API to - # reflect these changes if we do not have an offset, then we may "miss" the - # update when polling yield from self._fetch_docs_from_drive(start, end) - - -if __name__ == "__main__": - import json - import os - - service_account_json_path = os.environ.get("GOOGLE_SERVICE_ACCOUNT_KEY_JSON_PATH") - if not service_account_json_path: - raise ValueError( - "Please set GOOGLE_SERVICE_ACCOUNT_KEY_JSON_PATH environment variable" - ) - with open(service_account_json_path) as f: - creds = json.load(f) - - credentials_dict = { - DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY: json.dumps(creds), - } - delegated_user = os.environ.get("GOOGLE_DRIVE_DELEGATED_USER") - if delegated_user: - credentials_dict[DB_CREDENTIALS_DICT_DELEGATED_USER_KEY] = delegated_user - - connector = GoogleDriveConnector(include_shared=True, follow_shortcuts=True) - connector.load_credentials(credentials_dict) - document_batch_generator = connector.load_from_state() - for document_batch in document_batch_generator: - print(document_batch) - break diff --git a/backend/ee/danswer/external_permissions/google_drive/doc_sync.py b/backend/ee/danswer/external_permissions/google_drive/doc_sync.py index 19dbb845323..5cd0280891f 100644 --- a/backend/ee/danswer/external_permissions/google_drive/doc_sync.py +++ b/backend/ee/danswer/external_permissions/google_drive/doc_sync.py @@ -1,21 +1,16 @@ -from collections.abc import Iterator from datetime import datetime from datetime import timezone from typing import Any -from typing import cast from googleapiclient.discovery import build # type: ignore -from googleapiclient.errors import HttpError # type: ignore +from googleapiclient.discovery import Resource # type: ignore from sqlalchemy.orm import Session from danswer.access.models import ExternalAccess from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder -from danswer.connectors.factory import instantiate_connector -from danswer.connectors.google_drive.connector_auth import ( - get_google_drive_creds, -) -from danswer.connectors.interfaces import PollConnector -from danswer.connectors.models import InputType +from danswer.connectors.google_drive.connector import _execute_paginated_retrieval +from danswer.connectors.google_drive.connector import GoogleDriveConnector +from danswer.connectors.models import SlimDocument from danswer.db.models 
import ConnectorCredentialPair from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit from danswer.utils.logger import setup_logger @@ -29,21 +24,20 @@ logger = setup_logger() +_PERMISSION_ID_PERMISSION_MAP: dict[str, dict[str, Any]] = {} -def _get_docs_with_additional_info( - db_session: Session, + +def _get_slim_docs( cc_pair: ConnectorCredentialPair, -) -> dict[str, Any]: +) -> tuple[list[SlimDocument], GoogleDriveConnector]: # Get all document ids that need their permissions updated - runnable_connector = instantiate_connector( - db_session=db_session, - source=cc_pair.connector.source, - input_type=InputType.POLL, - connector_specific_config=cc_pair.connector.connector_specific_config, - credential=cc_pair.credential, - ) - assert isinstance(runnable_connector, PollConnector) + drive_connector = GoogleDriveConnector( + **cc_pair.connector.connector_specific_config + ) + drive_connector.load_credentials(cc_pair.credential.credential_json) + if drive_connector.service_account_creds is None: + raise ValueError("Service account credentials not found") current_time = datetime.now(timezone.utc) start_time = ( @@ -53,86 +47,83 @@ def _get_docs_with_additional_info( ) cc_pair.last_time_perm_sync = current_time - doc_batch_generator = runnable_connector.poll_source( + doc_batch_generator = drive_connector.retrieve_all_slim_documents( start=start_time, end=current_time.timestamp() ) + slim_docs = [doc for doc_batch in doc_batch_generator for doc in doc_batch] + + return slim_docs, drive_connector + + +def _fetch_permissions_for_permission_ids( + admin_service: Resource, + doc_id: str, + permission_ids: list[str], +) -> list[dict[str, Any]]: + # Check cache first for all permission IDs + permissions = [ + _PERMISSION_ID_PERMISSION_MAP[pid] + for pid in permission_ids + if pid in _PERMISSION_ID_PERMISSION_MAP + ] + + # If we found all permissions in cache, return them + if len(permissions) == len(permission_ids): + return permissions + + # Otherwise, fetch all permissions and update cache + fetched_permissions = _execute_paginated_retrieval( + retrieval_function=admin_service.permissions().list, + list_key="permissions", + fileId=doc_id, + fields="permissions(id, emailAddress, type, domain)", + supportsAllDrives=True, + ) + + permissions_for_doc_id = [] + # Update cache and return all permissions + for permission in fetched_permissions: + permissions_for_doc_id.append(permission) + _PERMISSION_ID_PERMISSION_MAP[permission["id"]] = permission - docs_with_additional_info = { - doc.id: doc.additional_info - for doc_batch in doc_batch_generator - for doc in doc_batch - } - - return docs_with_additional_info - - -def _fetch_permissions_paginated( - drive_service: Any, drive_file_id: str -) -> Iterator[dict[str, Any]]: - next_token = None - - # Get paginated permissions for the file id - while True: - try: - permissions_resp: dict[str, Any] = add_retries( - lambda: ( - drive_service.permissions() - .list( - fileId=drive_file_id, - fields="permissions(emailAddress, type, domain)", - supportsAllDrives=True, - pageToken=next_token, - ) - .execute() - ) - )() - except HttpError as e: - if e.resp.status == 404: - logger.warning(f"Document with id {drive_file_id} not found: {e}") - break - elif e.resp.status == 403: - logger.warning( - f"Access denied for retrieving document permissions: {e}" - ) - break - else: - logger.error(f"Failed to fetch permissions: {e}") - raise - - for permission in permissions_resp.get("permissions", []): - yield permission - - next_token = 
permissions_resp.get("nextPageToken") - if not next_token: - break - - -def _fetch_google_permissions_for_document_id( + return permissions_for_doc_id + + +def _fetch_google_permissions_for_slim_doc( db_session: Session, - drive_file_id: str, - credentials_json: dict[str, str], - company_google_domains: list[str], + admin_service: Resource, + slim_doc: SlimDocument, + company_domain: str | None, ) -> ExternalAccess: - # Authenticate and construct service - google_drive_creds, _ = get_google_drive_creds( - credentials_json, - ) - if not google_drive_creds.valid: - raise ValueError("Invalid Google Drive credentials") - - drive_service = build("drive", "v3", credentials=google_drive_creds) + permission_info = slim_doc.perm_sync_data or {} + + permissions_list = permission_info.get("permissions", []) + if not permissions_list: + if permission_ids := permission_info.get("permissionIds"): + permissions_list = _fetch_permissions_for_permission_ids( + admin_service=admin_service, + doc_id=slim_doc.id, + permission_ids=permission_ids, + ) + if not permissions_list: + logger.warning(f"No permissions found for document {slim_doc.id}") + return ExternalAccess( + external_user_emails=set(), + external_user_group_ids=set(), + is_public=False, + ) user_emails: set[str] = set() group_emails: set[str] = set() public = False - for permission in _fetch_permissions_paginated(drive_service, drive_file_id): + for permission in permissions_list: permission_type = permission["type"] if permission_type == "user": user_emails.add(permission["emailAddress"]) elif permission_type == "group": group_emails.add(permission["emailAddress"]) - elif permission_type == "domain": - if permission["domain"] in company_google_domains: + elif permission_type == "domain" and company_domain: + if permission["domain"] == company_domain: public = True elif permission_type == "anyone": public = True @@ -161,27 +152,22 @@ def gdrive_doc_sync( logger.error("Sync details not found for Google Drive") raise ValueError("Sync details not found for Google Drive") - # Here we run the connector to grab all the ids - # this may grab ids before they are indexed but that is fine because - # we create a document in postgres to hold the permissions info - # until the indexing job has a chance to run - docs_with_additional_info = _get_docs_with_additional_info( - db_session=db_session, - cc_pair=cc_pair, - ) + slim_docs, google_drive_connector = _get_slim_docs(cc_pair) + + creds = google_drive_connector.get_primary_user_credentials() + admin_creds = creds.with_subject(google_drive_connector.service_account_email) + admin_service = build("admin", "directory_v1", credentials=admin_creds) - for doc_id, doc_additional_info in docs_with_additional_info.items(): - ext_access = _fetch_google_permissions_for_document_id( + for slim_doc in slim_docs: + ext_access = _fetch_google_permissions_for_slim_doc( db_session=db_session, - drive_file_id=doc_additional_info, - credentials_json=cc_pair.credential.credential_json, - company_google_domains=[ - cast(dict[str, str], sync_details)["company_domain"] - ], + admin_service=admin_service, + slim_doc=slim_doc, + company_domain=google_drive_connector.service_account_domain, ) upsert_document_external_perms__no_commit( db_session=db_session, - doc_id=doc_id, + doc_id=slim_doc.id, external_access=ext_access, source_type=cc_pair.connector.source, ) From db46ffd62d634aa08e59f6e7265ed7c3ca6e39a7 Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Mon, 28 Oct 2024 13:13:07 -0700 Subject: [PATCH 03/23] works with service 
account --- .../connectors/google_drive/connector.py | 144 ++++++++--------- .../connectors/google_drive/connector_auth.py | 145 ++++++++---------- .../connectors/google_drive/constants.py | 6 - backend/danswer/server/documents/connector.py | 2 +- .../ee/danswer/background/celery/apps/beat.py | 4 +- .../google_drive/doc_sync.py | 48 ++---- .../google_drive/group_sync.py | 137 ++++------------- .../external_permissions/permission_sync.py | 2 + .../[connector]/pages/gdrive/Credential.tsx | 8 +- .../lib/connectors/AutoSyncOptionFields.tsx | 33 +--- web/src/lib/connectors/connectors.tsx | 35 ++--- 11 files changed, 200 insertions(+), 364 deletions(-) delete mode 100644 backend/danswer/connectors/google_drive/constants.py diff --git a/backend/danswer/connectors/google_drive/connector.py b/backend/danswer/connectors/google_drive/connector.py index 7d2697049fe..cfe77428735 100644 --- a/backend/danswer/connectors/google_drive/connector.py +++ b/backend/danswer/connectors/google_drive/connector.py @@ -18,10 +18,10 @@ from danswer.configs.constants import IGNORE_FOR_QA from danswer.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder -from danswer.connectors.google_drive.connector_auth import get_google_drive_creds -from danswer.connectors.google_drive.constants import ( +from danswer.connectors.google_drive.connector_auth import ( DB_CREDENTIALS_DICT_DELEGATED_USER_KEY, ) +from danswer.connectors.google_drive.connector_auth import get_google_drive_creds from danswer.connectors.interfaces import GenerateDocumentsOutput from danswer.connectors.interfaces import GenerateSlimDocumentOutput from danswer.connectors.interfaces import LoadConnector @@ -44,10 +44,8 @@ DRIVE_SHORTCUT_TYPE = "application/vnd.google-apps.shortcut" UNSUPPORTED_FILE_TYPE_CONTENT = "" # keep empty for now -FILE_FIELDS = "nextPageToken, files(mimeType, id, name, permissions, modifiedTime, webViewLink, shortcutDetails)" -SLIM_FILE_FIELDS = ( - "nextPageToken, files(permissions(emailAddress, type), webViewLink), permissionIds" -) +FILE_FIELDS = "nextPageToken, files(mimeType, id, name, permissions, modifiedTime, webViewLink, shortcutDetails, owners)" +SLIM_FILE_FIELDS = "nextPageToken, files(id, permissions(emailAddress, type), permissionIds, webViewLink)" FOLDER_FIELDS = "nextPageToken, files(id, name, permissions, modifiedTime, webViewLink, shortcutDetails)" # these errors don't represent a failure in the connector, but simply files @@ -59,8 +57,6 @@ ] _SLIM_BATCH_SIZE = 500 -_TRAVERSED_PARENT_IDS: set[str] = set() - class GDriveMimeType(str, Enum): DOC = "application/vnd.google-apps.document" @@ -83,7 +79,7 @@ class GDriveMimeType(str, Enum): add_retries = retry_builder(tries=50, max_delay=30) -def _execute_paginated_retrieval( +def execute_paginated_retrieval( retrieval_function: Callable[..., Any], list_key: str, **kwargs: Any, @@ -93,7 +89,6 @@ def _execute_paginated_retrieval( retrieval_function: The specific list function to call (e.g., service.files().list) **kwargs: Arguments to pass to the list function """ - print("\n -------------------------------") next_page_token = "" while next_page_token is not None: request_kwargs = kwargs.copy() @@ -205,7 +200,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector): def __init__( self, parent_urls: list[str] | None = None, - include_personal: bool = True, + include_personal: bool | None = True, batch_size: int = INDEX_BATCH_SIZE, ) -> None: self.batch_size = batch_size @@ 
-213,7 +208,7 @@ def __init__(
         self.parent_ids = (
             _extract_parent_ids_from_urls(parent_urls) if parent_urls else []
         )
-        self.include_personal = include_personal
+        self.include_personal = include_personal if include_personal is not None else True
 
         self.service_account_email: str | None = None
         self.service_account_domain: str | None = None
@@ -223,6 +218,8 @@ def __init__(
 
         self.is_slim: bool = False
 
+        self._TRAVERSED_PARENT_IDS: set[str] = set()
+
     def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None:
         """Checks for two different types of credentials.
         (1) A credential which holds a token acquired via a user going thorough
@@ -230,6 +227,7 @@ def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None
         (2) A credential which holds a service account key JSON file, which
         can then be used to impersonate any user in the workspace.
         """
+        self.credentials_json = credentials
         creds, new_creds_dict = get_google_drive_creds(credentials)
 
         if KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY in credentials:
@@ -243,6 +241,22 @@ def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None
             self.oauth_creds = creds
         return new_creds_dict
 
+    def get_admin_service(
+        self,
+        service_name: str = "drive",
+        service_version: str = "v3",
+        user_email: str | None = None,
+    ) -> Resource:
+        if self.service_account_creds:
+            creds = self.service_account_creds.with_subject(
+                user_email or self.service_account_email
+            )
+            service = build(service_name, service_version, credentials=creds)
+        else:
+            service = build(service_name, service_version, credentials=self.oauth_creds)
+
+        return service
+
     def _get_folders_in_parent(
         self,
         service: Resource,
@@ -257,12 +271,12 @@ def _get_folders_in_parent(
         if parent_id:
             query += f" and '{parent_id}' in parents"
 
-        for file in _execute_paginated_retrieval(
+        for file in execute_paginated_retrieval(
             retrieval_function=service.files().list,
             list_key="files",
-            corpora="user" if personal_drive else "domain",
-            supportsAllDrives=personal_drive,
-            includeItemsFromAllDrives=personal_drive,
+            corpora="user" if personal_drive else "allDrives",
+            supportsAllDrives=not personal_drive,
+            includeItemsFromAllDrives=not personal_drive,
             fields=FOLDER_FIELDS,
             q=query,
         ):
@@ -284,12 +298,12 @@ def _get_files_in_parent(
             time_stop = datetime.utcfromtimestamp(time_range_end).isoformat() + "Z"
             query += f" and modifiedTime <= '{time_stop}'"
 
-        for file in _execute_paginated_retrieval(
+        for file in execute_paginated_retrieval(
             retrieval_function=service.files().list,
             list_key="files",
-            corpora="user" if personal_drive else "domain",
-            supportsAllDrives=True,
-            includeItemsFromAllDrives=True,
+            corpora="user" if personal_drive else "allDrives",
+            supportsAllDrives=not personal_drive,
+            includeItemsFromAllDrives=not personal_drive,
             fields=SLIM_FILE_FIELDS if self.is_slim else FILE_FIELDS,
             q=query,
         ):
@@ -306,11 +320,11 @@ def _crawl_drive_for_files(
         """Gets all files matching the criteria specified by the args from Google Drive
         in batches of size `batch_size`.
""" - if parent_id in _TRAVERSED_PARENT_IDS: + if parent_id in self._TRAVERSED_PARENT_IDS: logger.debug(f"Skipping subfolder since already traversed: {parent_id}") return - _TRAVERSED_PARENT_IDS.add(parent_id) + self._TRAVERSED_PARENT_IDS.add(parent_id) yield from self._get_files_in_parent( service=service, @@ -335,15 +349,15 @@ def _crawl_drive_for_files( ) def _get_all_user_emails(self) -> list[str]: - if not self.service_account_creds: - raise PermissionError("No service account credentials found") + # if not self.service_account_creds: + # raise PermissionError("No service account credentials found") admin_creds = self.service_account_creds.with_subject( self.service_account_email ) admin_service = build("admin", "directory_v1", credentials=admin_creds) emails = [] - for user in _execute_paginated_retrieval( + for user in execute_paginated_retrieval( retrieval_function=admin_service.users().list, list_key="users", domain=self.service_account_domain, @@ -358,13 +372,12 @@ def _fetch_drive_items( end: SecondsSinceUnixEpoch | None = None, ) -> Iterator[GoogleDriveFileType]: # admin_creds = self.service_account_creds.with_subject(self.service_account_email) - admin_creds = self.get_primary_user_credentials() - admin_drive_service = build("drive", "v3", credentials=admin_creds) + admin_drive_service = self.get_admin_service() parent_ids = self.parent_ids if not parent_ids: # if no parent ids are specified, get all shared drives using the admin account - for drive in _execute_paginated_retrieval( + for drive in execute_paginated_retrieval( retrieval_function=admin_drive_service.drives().list, list_key="drives", useDomainAdminAccess=True, @@ -374,64 +387,58 @@ def _fetch_drive_items( # crawl all the shared parent ids for files for parent_id in parent_ids: - yield from self._crawl_drive_for_files( + for file in self._crawl_drive_for_files( service=admin_drive_service, parent_id=parent_id, personal_drive=False, time_range_start=start, time_range_end=end, - ) - + ): + print(file) + yield file + logger.info(f"Fetching personal files: {self.include_personal}") # get all personal docs from each users' personal drive if self.include_personal: - if self.service_account_creds: - all_user_emails = self._get_all_user_emails() - for email in all_user_emails: - user_creds = self.service_account_creds.with_subject(email) - user_drive_service = build("drive", "v3", credentials=user_creds) - # we dont paginate here because there is only one root folder per user - # https://developers.google.com/drive/api/guides/v2-to-v3-reference - id = ( - user_drive_service.files() - .get(fileId="root", fields="id") - .execute()["id"] - ) - - yield from self._crawl_drive_for_files( - service=user_drive_service, - parent_id=id, - personal_drive=True, - time_range_start=start, - time_range_end=end, - ) - - def get_primary_user_credentials( - self, - ) -> OAuthCredentials | ServiceAccountCredentials: - if self.service_account_creds: - creds = self.service_account_creds.with_subject(self.service_account_email) - service = build("drive", "v3", credentials=creds) - else: - service = build("drive", "v3", credentials=self.oauth_creds) + all_user_emails = self._get_all_user_emails() + for email in all_user_emails: + logger.info(f"Fetching personal files for user: {email}") + user_creds = self.service_account_creds.with_subject(email) + user_drive_service = build("drive", "v3", credentials=user_creds) + # we dont paginate here because there is only one root folder per user + # 
diff --git a/backend/danswer/connectors/google_drive/connector_auth.py b/backend/danswer/connectors/google_drive/connector_auth.py
index 464e59cd798..b2a8e950e64 100644
--- a/backend/danswer/connectors/google_drive/connector_auth.py
+++ b/backend/danswer/connectors/google_drive/connector_auth.py
@@ -10,19 +10,11 @@
 from google_auth_oauthlib.flow import InstalledAppFlow  # type: ignore
 from sqlalchemy.orm import Session
 
-from danswer.configs.app_configs import ENTERPRISE_EDITION_ENABLED
 from danswer.configs.app_configs import WEB_DOMAIN
 from danswer.configs.constants import DocumentSource
 from danswer.configs.constants import KV_CRED_KEY
 from danswer.configs.constants import KV_GOOGLE_DRIVE_CRED_KEY
 from danswer.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
-from danswer.connectors.google_drive.constants import BASE_SCOPES
-from danswer.connectors.google_drive.constants import (
-    DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
-)
-from danswer.connectors.google_drive.constants import DB_CREDENTIALS_DICT_TOKEN_KEY
-from danswer.connectors.google_drive.constants import FETCH_GROUPS_SCOPES
-from danswer.connectors.google_drive.constants import FETCH_PERMISSIONS_SCOPES
 from danswer.db.credentials import update_credential_json
 from danswer.db.models import User
 from danswer.key_value_store.factory import get_kv_store
@@ -33,15 +25,16 @@
 
 logger = setup_logger()
 
-
-def build_gdrive_scopes() -> list[str]:
-    base_scopes: list[str] = BASE_SCOPES
-    permissions_scopes: list[str] = FETCH_PERMISSIONS_SCOPES
-    groups_scopes: list[str] = FETCH_GROUPS_SCOPES
-
-    if ENTERPRISE_EDITION_ENABLED:
-        return base_scopes + permissions_scopes + groups_scopes
-    return base_scopes + permissions_scopes
+GOOGLE_DRIVE_SCOPES = [
+    "https://www.googleapis.com/auth/drive.readonly",
+    "https://www.googleapis.com/auth/drive.metadata.readonly",
+    "https://www.googleapis.com/auth/admin.directory.group.readonly",
+]
+SERVICE_ACCOUNT_SCOPES = GOOGLE_DRIVE_SCOPES + [
+    "https://www.googleapis.com/auth/admin.directory.user.readonly",
+]
+DB_CREDENTIALS_DICT_TOKEN_KEY = "google_drive_tokens"
+DB_CREDENTIALS_DICT_DELEGATED_USER_KEY = "google_drive_delegated_user"
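A note on SERVICE_ACCOUNT_SCOPES: the service-account code paths in this patch only work if these exact scopes have been granted to the service account's client ID via domain-wide delegation in the Google Workspace admin console; otherwise `with_subject` typically fails at token refresh with an `unauthorized_client` error. A minimal sketch of the impersonation pattern used throughout (the key-file path and email are placeholders):

    from google.oauth2.service_account import Credentials as ServiceAccountCredentials  # type: ignore

    # Placeholder path and email, for illustration only
    sa_creds = ServiceAccountCredentials.from_service_account_file(
        "/path/to/service_account_key.json", scopes=SERVICE_ACCOUNT_SCOPES
    )
    # Impersonate a specific Workspace user via domain-wide delegation
    user_creds = sa_creds.with_subject("admin@example.com")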
+] +DB_CREDENTIALS_DICT_TOKEN_KEY = "google_drive_tokens" +DB_CREDENTIALS_DICT_DELEGATED_USER_KEY = "google_drive_delegated_user" def _build_frontend_google_drive_redirect() -> str: @@ -49,7 +42,7 @@ def _build_frontend_google_drive_redirect() -> str: def get_google_drive_creds_for_authorized_user( - token_json_str: str, scopes: list[str] = build_gdrive_scopes() + token_json_str: str, scopes: list[str] ) -> OAuthCredentials | None: creds_json = json.loads(token_json_str) creds = OAuthCredentials.from_authorized_user_info(creds_json, scopes) @@ -69,59 +62,47 @@ def get_google_drive_creds_for_authorized_user( return None -def _get_google_drive_creds_for_service_account( - service_account_key_json_str: str, scopes: list[str] = build_gdrive_scopes() -) -> ServiceAccountCredentials | None: - service_account_key = json.loads(service_account_key_json_str) - creds = ServiceAccountCredentials.from_service_account_info( - service_account_key, scopes=scopes - ) - if not creds.valid or not creds.expired: - creds.refresh(Request()) - return creds if creds.valid else None - - -def get_service_account_credentials( - credentials: dict[str, str], - scopes: list[str] = build_gdrive_scopes(), -) -> ServiceAccountCredentials: - service_account_key_json_str = credentials[KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY] - service_creds = _get_google_drive_creds_for_service_account( - service_account_key_json_str=service_account_key_json_str, - scopes=scopes, - ) - - # "Impersonate" a user if one is specified - delegated_user_email = cast( - str | None, credentials.get(DB_CREDENTIALS_DICT_DELEGATED_USER_KEY) - ) - if delegated_user_email: - service_creds = ( - service_creds.with_subject(delegated_user_email) if service_creds else None - ) - return service_creds - - -def get_oauth_credentials( - credentials: dict[str, str], - scopes: list[str] = build_gdrive_scopes(), -) -> tuple[OAuthCredentials | None, dict[str, str] | None]: - new_creds_dict = None - access_token_json_str = cast(str, credentials[DB_CREDENTIALS_DICT_TOKEN_KEY]) - oauth_creds = get_google_drive_creds_for_authorized_user( - token_json_str=access_token_json_str, scopes=scopes - ) - - # tell caller to update token stored in DB if it has changed - # (e.g. 
the token has been refreshed) - new_creds_json_str = oauth_creds.to_json() if oauth_creds else "" - if new_creds_json_str != access_token_json_str: - new_creds_dict = {DB_CREDENTIALS_DICT_TOKEN_KEY: new_creds_json_str} - return oauth_creds, new_creds_dict +# def get_service_account_credentials( +# credentials: dict[str, str], +# scopes: list[str], +# ) -> ServiceAccountCredentials: +# service_account_key_json_str = credentials[KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY] +# service_creds = _get_google_drive_creds_for_service_account( +# service_account_key_json_str=service_account_key_json_str, +# scopes=scopes, +# ) + +# # "Impersonate" a user if one is specified +# delegated_user_email = cast( +# str | None, credentials.get(DB_CREDENTIALS_DICT_DELEGATED_USER_KEY) +# ) +# if delegated_user_email: +# service_creds = ( +# service_creds.with_subject(delegated_user_email) if service_creds else None +# ) +# return service_creds + + +# def get_oauth_credentials( +# credentials: dict[str, str], +# scopes: list[str] +# ) -> tuple[OAuthCredentials | None, dict[str, str] | None]: +# new_creds_dict = None +# access_token_json_str = cast(str, credentials[DB_CREDENTIALS_DICT_TOKEN_KEY]) +# oauth_creds = get_google_drive_creds_for_authorized_user( +# token_json_str=access_token_json_str, scopes=scopes +# ) + +# # tell caller to update token stored in DB if it has changed +# # (e.g. the token has been refreshed) +# new_creds_json_str = oauth_creds.to_json() if oauth_creds else "" +# if new_creds_json_str != access_token_json_str: +# new_creds_dict = {DB_CREDENTIALS_DICT_TOKEN_KEY: new_creds_json_str} +# return oauth_creds, new_creds_dict def get_google_drive_creds( - credentials: dict[str, str], scopes: list[str] = build_gdrive_scopes() + credentials: dict[str, str], scopes: list[str] = SERVICE_ACCOUNT_SCOPES ) -> tuple[ServiceAccountCredentials | OAuthCredentials, dict[str, str] | None]: oauth_creds = None service_creds = None @@ -140,26 +121,24 @@ def get_google_drive_creds( elif KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY in credentials: service_account_key_json_str = credentials[KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY] - service_creds = _get_google_drive_creds_for_service_account( - service_account_key_json_str=service_account_key_json_str, - scopes=scopes, - ) + service_account_key = json.loads(service_account_key_json_str) - # "Impersonate" a user if one is specified - delegated_user_email = cast( - str | None, credentials.get(DB_CREDENTIALS_DICT_DELEGATED_USER_KEY) + service_creds = ServiceAccountCredentials.from_service_account_info( + service_account_key, scopes=scopes ) - if delegated_user_email: - service_creds = ( - service_creds.with_subject(delegated_user_email) - if service_creds - else None + + if not service_creds.valid or not service_creds.expired: + service_creds.refresh(Request()) + + if not service_creds.valid: + raise PermissionError( + "Unable to access Google Drive - service account credentials are invalid." ) creds: ServiceAccountCredentials | OAuthCredentials | None = ( oauth_creds or service_creds ) - if creds is None: + if service_creds is None: raise PermissionError( "Unable to access Google Drive - unknown credential structure." 
) @@ -180,7 +159,7 @@ def get_auth_url(credential_id: int) -> str: credential_json = json.loads(creds_str) flow = InstalledAppFlow.from_client_config( credential_json, - scopes=build_gdrive_scopes(), + scopes=SERVICE_ACCOUNT_SCOPES, redirect_uri=_build_frontend_google_drive_redirect(), ) auth_url, _ = flow.authorization_url(prompt="consent") @@ -203,7 +182,7 @@ def update_credential_access_tokens( app_credentials = get_google_app_cred() flow = InstalledAppFlow.from_client_config( app_credentials.model_dump(), - scopes=build_gdrive_scopes(), + scopes=SERVICE_ACCOUNT_SCOPES, redirect_uri=_build_frontend_google_drive_redirect(), ) flow.fetch_token(code=auth_code) diff --git a/backend/danswer/connectors/google_drive/constants.py b/backend/danswer/connectors/google_drive/constants.py deleted file mode 100644 index 563f2c63b47..00000000000 --- a/backend/danswer/connectors/google_drive/constants.py +++ /dev/null @@ -1,6 +0,0 @@ -DB_CREDENTIALS_DICT_TOKEN_KEY = "google_drive_tokens" -DB_CREDENTIALS_DICT_DELEGATED_USER_KEY = "google_drive_delegated_user" - -BASE_SCOPES = ["https://www.googleapis.com/auth/drive.readonly"] -FETCH_PERMISSIONS_SCOPES = ["https://www.googleapis.com/auth/drive.metadata.readonly"] -FETCH_GROUPS_SCOPES = ["https://www.googleapis.com/auth/cloud-identity.groups.readonly"] diff --git a/backend/danswer/server/documents/connector.py b/backend/danswer/server/documents/connector.py index 1ba0ab13e2c..afd41a296aa 100644 --- a/backend/danswer/server/documents/connector.py +++ b/backend/danswer/server/documents/connector.py @@ -35,6 +35,7 @@ ) from danswer.connectors.gmail.connector_auth import upsert_google_app_gmail_cred from danswer.connectors.google_drive.connector_auth import build_service_account_creds +from danswer.connectors.google_drive.connector_auth import DB_CREDENTIALS_DICT_TOKEN_KEY from danswer.connectors.google_drive.connector_auth import delete_google_app_cred from danswer.connectors.google_drive.connector_auth import delete_service_account_key from danswer.connectors.google_drive.connector_auth import get_auth_url @@ -49,7 +50,6 @@ from danswer.connectors.google_drive.connector_auth import upsert_google_app_cred from danswer.connectors.google_drive.connector_auth import upsert_service_account_key from danswer.connectors.google_drive.connector_auth import verify_csrf -from danswer.connectors.google_drive.constants import DB_CREDENTIALS_DICT_TOKEN_KEY from danswer.db.connector import create_connector from danswer.db.connector import delete_connector from danswer.db.connector import fetch_connector_by_id diff --git a/backend/ee/danswer/background/celery/apps/beat.py b/backend/ee/danswer/background/celery/apps/beat.py index bee219e2471..980eb5e3214 100644 --- a/backend/ee/danswer/background/celery/apps/beat.py +++ b/backend/ee/danswer/background/celery/apps/beat.py @@ -13,12 +13,12 @@ { "name": "sync-external-doc-permissions", "task": "check_sync_external_doc_permissions_task", - "schedule": timedelta(seconds=5), # TODO: optimize this + "schedule": timedelta(seconds=30), # TODO: optimize this }, { "name": "sync-external-group-permissions", "task": "check_sync_external_group_permissions_task", - "schedule": timedelta(seconds=5), # TODO: optimize this + "schedule": timedelta(seconds=60), # TODO: optimize this }, { "name": "autogenerate_usage_report", diff --git a/backend/ee/danswer/external_permissions/google_drive/doc_sync.py b/backend/ee/danswer/external_permissions/google_drive/doc_sync.py index 5cd0280891f..2c1351dccda 100644 --- 
a/backend/ee/danswer/external_permissions/google_drive/doc_sync.py +++ b/backend/ee/danswer/external_permissions/google_drive/doc_sync.py @@ -2,13 +2,11 @@ from datetime import timezone from typing import Any -from googleapiclient.discovery import build # type: ignore from googleapiclient.discovery import Resource # type: ignore from sqlalchemy.orm import Session from danswer.access.models import ExternalAccess -from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder -from danswer.connectors.google_drive.connector import _execute_paginated_retrieval +from danswer.connectors.google_drive.connector import execute_paginated_retrieval from danswer.connectors.google_drive.connector import GoogleDriveConnector from danswer.connectors.models import SlimDocument from danswer.db.models import ConnectorCredentialPair @@ -16,12 +14,6 @@ from danswer.utils.logger import setup_logger from ee.danswer.db.document import upsert_document_external_perms__no_commit -# Google Drive APIs are quite flakey and may 500 for an -# extended period of time. Trying to combat here by adding a very -# long retry period (~20 minutes of trying every minute) -add_retries = retry_builder(tries=5, delay=5, max_delay=30) - - logger = setup_logger() _PERMISSION_ID_PERMISSION_MAP: dict[str, dict[str, Any]] = {} @@ -29,30 +21,21 @@ def _get_slim_docs( cc_pair: ConnectorCredentialPair, + google_drive_connector: GoogleDriveConnector, ) -> tuple[list[SlimDocument], GoogleDriveConnector]: - # Get all document ids that need their permissions updated - - drive_connector = GoogleDriveConnector( - **cc_pair.connector.connector_specific_config - ) - drive_connector.load_credentials(cc_pair.credential.credential_json) - if drive_connector.service_account_creds is None: - raise ValueError("Service account credentials not found") - current_time = datetime.now(timezone.utc) start_time = ( cc_pair.last_time_perm_sync.replace(tzinfo=timezone.utc).timestamp() if cc_pair.last_time_perm_sync else 0.0 ) - cc_pair.last_time_perm_sync = current_time - doc_batch_generator = drive_connector.retrieve_all_slim_documents( + doc_batch_generator = google_drive_connector.retrieve_all_slim_documents( start=start_time, end=current_time.timestamp() ) slim_docs = [doc for doc_batch in doc_batch_generator for doc in doc_batch] - return slim_docs, drive_connector + return slim_docs def _fetch_permissions_for_permission_ids( @@ -72,7 +55,7 @@ def _fetch_permissions_for_permission_ids( return permissions # Otherwise, fetch all permissions and update cache - fetched_permissions = _execute_paginated_retrieval( + fetched_permissions = execute_paginated_retrieval( retrieval_function=admin_service.permissions().list, list_key="permissions", fileId=doc_id, @@ -98,11 +81,12 @@ def _fetch_google_permissions_for_slim_doc( permission_info = slim_doc.perm_sync_data or {} permissions_list = permission_info.get("permissions", []) + doc_id = permission_info.get("doc_id") if not permissions_list: - if permission_ids := permission_info.get("permissionIds"): + if permission_ids := permission_info.get("permission_ids") and doc_id: permissions_list = _fetch_permissions_for_permission_ids( admin_service=admin_service, - doc_id=slim_doc.id, + doc_id=doc_id, permission_ids=permission_ids, ) if not permissions_list: @@ -147,16 +131,16 @@ def gdrive_doc_sync( it in postgres so that when it gets created later, the permissions are already populated """ - sync_details = cc_pair.auto_sync_options - if sync_details is None: - logger.error("Sync details not found for 
Google Drive") - raise ValueError("Sync details not found for Google Drive") + google_drive_connector = GoogleDriveConnector( + **cc_pair.connector.connector_specific_config + ) + google_drive_connector.load_credentials(cc_pair.credential.credential_json) - slim_docs, google_drive_connector = _get_slim_docs(cc_pair) + if google_drive_connector.service_account_creds is None: + raise ValueError("Service account credentials not found") - creds = google_drive_connector.get_primary_user_credentials() - admin_creds = creds.with_subject(google_drive_connector.service_account_email) - admin_service = build("admin", "directory_v1", credentials=admin_creds) + slim_docs = _get_slim_docs(cc_pair, google_drive_connector) + admin_service = google_drive_connector.get_admin_service() for slim_doc in slim_docs: ext_access = _fetch_google_permissions_for_slim_doc( diff --git a/backend/ee/danswer/external_permissions/google_drive/group_sync.py b/backend/ee/danswer/external_permissions/google_drive/group_sync.py index 7bb919d4686..b3f133ccbfd 100644 --- a/backend/ee/danswer/external_permissions/google_drive/group_sync.py +++ b/backend/ee/danswer/external_permissions/google_drive/group_sync.py @@ -1,17 +1,7 @@ -from collections.abc import Iterator -from typing import Any - -from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore -from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore -from googleapiclient.discovery import build # type: ignore -from googleapiclient.errors import HttpError # type: ignore from sqlalchemy.orm import Session -from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder -from danswer.connectors.google_drive.connector_auth import ( - get_google_drive_creds, -) -from danswer.connectors.google_drive.constants import FETCH_GROUPS_SCOPES +from danswer.connectors.google_drive.connector import execute_paginated_retrieval +from danswer.connectors.google_drive.connector import GoogleDriveConnector from danswer.db.models import ConnectorCredentialPair from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit from danswer.utils.logger import setup_logger @@ -21,116 +11,41 @@ logger = setup_logger() -# Google Drive APIs are quite flakey and may 500 for an -# extended period of time. 
Trying to combat here by adding a very -# long retry period (~20 minutes of trying every minute) -add_retries = retry_builder(tries=5, delay=5, max_delay=30) - - -def _fetch_groups_paginated( - google_drive_creds: ServiceAccountCredentials | OAuthCredentials, - identity_source: str | None = None, - customer_id: str | None = None, -) -> Iterator[dict[str, Any]]: - # Note that Google Drive does not use of update the user_cache as the user email - # comes directly with the call to fetch the groups, therefore this is not a valid - # place to save on requests - if identity_source is None and customer_id is None: - raise ValueError( - "Either identity_source or customer_id must be provided to fetch groups" - ) - - cloud_identity_service = build( - "cloudidentity", "v1", credentials=google_drive_creds - ) - parent = ( - f"identitysources/{identity_source}" - if identity_source - else f"customers/{customer_id}" - ) - - while True: - try: - groups_resp: dict[str, Any] = add_retries( - lambda: (cloud_identity_service.groups().list(parent=parent).execute()) - )() - for group in groups_resp.get("groups", []): - yield group - - next_token = groups_resp.get("nextPageToken") - if not next_token: - break - except HttpError as e: - if e.resp.status == 404 or e.resp.status == 403: - break - logger.error(f"Error fetching groups: {e}") - raise - - -def _fetch_group_members_paginated( - google_drive_creds: ServiceAccountCredentials | OAuthCredentials, - group_name: str, -) -> Iterator[dict[str, Any]]: - cloud_identity_service = build( - "cloudidentity", "v1", credentials=google_drive_creds - ) - next_token = None - while True: - try: - membership_info = add_retries( - lambda: ( - cloud_identity_service.groups() - .memberships() - .searchTransitiveMemberships( - parent=group_name, pageToken=next_token - ) - .execute() - ) - )() - - for member in membership_info.get("memberships", []): - yield member - - next_token = membership_info.get("nextPageToken") - if not next_token: - break - except HttpError as e: - if e.resp.status == 404 or e.resp.status == 403: - break - logger.error(f"Error fetching group members: {e}") - raise - - def gdrive_group_sync( db_session: Session, cc_pair: ConnectorCredentialPair, ) -> None: - sync_details = cc_pair.auto_sync_options - if sync_details is None: - logger.error("Sync details not found for Google Drive") - raise ValueError("Sync details not found for Google Drive") - - google_drive_creds, _ = get_google_drive_creds( - cc_pair.credential.credential_json, - scopes=FETCH_GROUPS_SCOPES, + google_drive_connector = GoogleDriveConnector( + **cc_pair.connector.connector_specific_config ) + google_drive_connector.load_credentials(cc_pair.credential.credential_json) + + if google_drive_connector.service_account_creds is None: + raise ValueError("Service account credentials not found") + + admin_service = google_drive_connector.get_admin_service("admin", "directory_v1") danswer_groups: list[ExternalUserGroup] = [] - for group in _fetch_groups_paginated( - google_drive_creds, - identity_source=sync_details.get("identity_source"), - customer_id=sync_details.get("customer_id"), + for group in execute_paginated_retrieval( + admin_service.groups().list, + list_key="groups", + domain=google_drive_connector.service_account_domain, + fields="groups(email)", ): # The id is the group email - group_email = group["groupKey"]["id"] + group_email = group["email"] + # Gather group member emails group_member_emails: list[str] = [] - for member in _fetch_group_members_paginated(google_drive_creds, 
group["name"]): - member_keys = member["preferredMemberKey"] - member_emails = [member_key["id"] for member_key in member_keys] - for member_email in member_emails: - group_member_emails.append(member_email) - + for member in execute_paginated_retrieval( + admin_service.members().list, + list_key="members", + groupKey=group_email, + fields="members(email)", + ): + group_member_emails.append(member["email"]) + + # Add group members to DB and get their IDs group_members = batch_add_non_web_user_if_not_exists__no_commit( db_session=db_session, emails=group_member_emails ) diff --git a/backend/ee/danswer/external_permissions/permission_sync.py b/backend/ee/danswer/external_permissions/permission_sync.py index ba5bbbd4921..94a0b4bfa8e 100644 --- a/backend/ee/danswer/external_permissions/permission_sync.py +++ b/backend/ee/danswer/external_permissions/permission_sync.py @@ -59,6 +59,7 @@ def run_external_doc_permission_sync( source_type = cc_pair.connector.source doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type) + last_time_perm_sync = cc_pair.last_time_perm_sync if doc_sync_func is None: raise ValueError( @@ -110,4 +111,5 @@ def run_external_doc_permission_sync( logger.info(f"Successfully synced docs for {source_type}") except Exception: logger.exception("Error Syncing Document Permissions") + cc_pair.last_time_perm_sync = last_time_perm_sync db_session.rollback() diff --git a/web/src/app/admin/connectors/[connector]/pages/gdrive/Credential.tsx b/web/src/app/admin/connectors/[connector]/pages/gdrive/Credential.tsx index 371bbef6dd1..9b2ba9ef26d 100644 --- a/web/src/app/admin/connectors/[connector]/pages/gdrive/Credential.tsx +++ b/web/src/app/admin/connectors/[connector]/pages/gdrive/Credential.tsx @@ -372,7 +372,9 @@ export const DriveOAuthSection = ({ google_drive_delegated_user: "", }} validationSchema={Yup.object().shape({ - google_drive_delegated_user: Yup.string().optional(), + google_drive_delegated_user: Yup.string().required( + "User email is required" + ), })} onSubmit={async (values, formikHelpers) => { formikHelpers.setSubmitting(true); @@ -409,8 +411,8 @@ export const DriveOAuthSection = ({
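The group sync rewrite above drops the hand-rolled pageToken loops in favor of
execute_paginated_retrieval. A minimal sketch of the same pattern against the
Admin SDK Directory API, assuming admin_service was built with
build("admin", "directory_v1", credentials=...):

    # Sketch: enumerate group emails the way the new group sync does.
    from danswer.connectors.google_drive.connector import execute_paginated_retrieval

    def iter_group_emails(admin_service, domain: str):
        for group in execute_paginated_retrieval(
            admin_service.groups().list,
            list_key="groups",
            domain=domain,
            fields="nextPageToken, groups(email)",
        ):
            yield group["email"]

Note that a fields mask applies to the whole response, so nextPageToken has to
be requested along with groups(email); otherwise the API omits the token and
pagination stops after the first page.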
diff --git a/web/src/lib/connectors/AutoSyncOptionFields.tsx b/web/src/lib/connectors/AutoSyncOptionFields.tsx index f6866a16991..4a8b44868e6 100644 --- a/web/src/lib/connectors/AutoSyncOptionFields.tsx +++ b/web/src/lib/connectors/AutoSyncOptionFields.tsx @@ -12,37 +12,6 @@ export const autoSyncConfigBySource: Record< > > = { confluence: {}, - google_drive: { - customer_id: { - label: "Google Workspace Customer ID", - subtext: ( - <> - The unique identifier for your Google Workspace account. To find this, - checkout the{" "} - - guide from Google - - . - - ), - }, - company_domain: { - label: "Google Workspace Company Domain", - subtext: ( - <> - The email domain for your Google Workspace account. -
-
- For example, if your email provided through Google Workspace looks - something like chris@danswer.ai, then your company domain is{" "} - danswer.ai - - ), - }, - }, + google_drive: {}, slack: {}, }; diff --git a/web/src/lib/connectors/connectors.tsx b/web/src/lib/connectors/connectors.tsx index d722fcf9848..d4a9daab99b 100644 --- a/web/src/lib/connectors/connectors.tsx +++ b/web/src/lib/connectors/connectors.tsx @@ -202,40 +202,25 @@ export const connectorConfigs: Record< }, google_drive: { description: "Configure Google Drive connector", - values: [ + values: [], + advanced_values: [ { type: "list", - query: "Enter folder paths:", - label: "Folder Paths", - name: "folder_paths", + query: "Enter the URLs of the shared folders or drives to index:", + label: "Parent URLs To Index", + name: "parent_urls", optional: true, }, { type: "checkbox", - query: "Include shared files?", - label: "Include Shared", - name: "include_shared", - optional: false, - default: false, - }, - { - type: "checkbox", - query: "Follow shortcuts?", - label: "Follow Shortcuts", - name: "follow_shortcuts", - optional: false, - default: false, - }, - { - type: "checkbox", - query: "Only include organization public files?", - label: "Only Org Public", - name: "only_org_public", + query: + "Include personal drives? (Note: This should only be used if you use permissions sync)", + label: "Include personal", + name: "include_personal", optional: false, default: false, }, ], - advanced_values: [], }, gmail: { description: "Configure Gmail connector", @@ -1030,7 +1015,7 @@ export interface GitlabConfig { } export interface GoogleDriveConfig { - folder_paths?: string[]; + parent_urls?: string[]; include_shared?: boolean; follow_shortcuts?: boolean; only_org_public?: boolean; From 08aececf084d552e444794acb7bbd7ed62d717e1 Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Mon, 28 Oct 2024 13:13:49 -0700 Subject: [PATCH 04/23] combined scopes --- backend/danswer/connectors/google_drive/connector_auth.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/backend/danswer/connectors/google_drive/connector_auth.py b/backend/danswer/connectors/google_drive/connector_auth.py index b2a8e950e64..390c0b9e469 100644 --- a/backend/danswer/connectors/google_drive/connector_auth.py +++ b/backend/danswer/connectors/google_drive/connector_auth.py @@ -29,8 +29,6 @@ "https://www.googleapis.com/auth/drive.readonly", "https://www.googleapis.com/auth/drive.metadata.readonly", "https://www.googleapis.com/auth/admin.directory.group.readonly", -] -SERVICE_ACCOUNT_SCOPES = GOOGLE_DRIVE_SCOPES + [ "https://www.googleapis.com/auth/admin.directory.user.readonly", ] DB_CREDENTIALS_DICT_TOKEN_KEY = "google_drive_tokens" @@ -102,7 +100,7 @@ def get_google_drive_creds_for_authorized_user( def get_google_drive_creds( - credentials: dict[str, str], scopes: list[str] = SERVICE_ACCOUNT_SCOPES + credentials: dict[str, str], scopes: list[str] = GOOGLE_DRIVE_SCOPES ) -> tuple[ServiceAccountCredentials | OAuthCredentials, dict[str, str] | None]: oauth_creds = None service_creds = None @@ -159,7 +157,7 @@ def get_auth_url(credential_id: int) -> str: credential_json = json.loads(creds_str) flow = InstalledAppFlow.from_client_config( credential_json, - scopes=SERVICE_ACCOUNT_SCOPES, + scopes=GOOGLE_DRIVE_SCOPES, redirect_uri=_build_frontend_google_drive_redirect(), ) auth_url, _ = flow.authorization_url(prompt="consent") @@ -182,7 +180,7 @@ def update_credential_access_tokens( app_credentials = get_google_app_cred() flow = 
InstalledAppFlow.from_client_config( app_credentials.model_dump(), - scopes=SERVICE_ACCOUNT_SCOPES, + scopes=GOOGLE_DRIVE_SCOPES, redirect_uri=_build_frontend_google_drive_redirect(), ) flow.fetch_token(code=auth_code) From d69d4aff6a790a9613c6001189a951a314a537eb Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Mon, 28 Oct 2024 16:06:54 -0700 Subject: [PATCH 05/23] copy change --- .../[connector]/pages/gdrive/Credential.tsx | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/web/src/app/admin/connectors/[connector]/pages/gdrive/Credential.tsx b/web/src/app/admin/connectors/[connector]/pages/gdrive/Credential.tsx index 9b2ba9ef26d..4134cde4c2a 100644 --- a/web/src/app/admin/connectors/[connector]/pages/gdrive/Credential.tsx +++ b/web/src/app/admin/connectors/[connector]/pages/gdrive/Credential.tsx @@ -356,15 +356,13 @@ export const DriveOAuthSection = ({ return (
- When using a Google Drive Service Account, you can either have Danswer
- act as the service account itself OR you can specify an account for
- the service account to impersonate.
+ When using a Google Drive Service Account, you must specify the email
+ of the primary admin that you would like the service account to
+ impersonate.
- If you want to use the service account itself, leave the{" "}
- 'User email to impersonate' field blank when
- submitting. If you do choose this option, make sure you have shared
- the documents you want to index with the service account.
+ Ideally, this account should be the owner of the Google Workspace
+ organization that owns the Google Drive content you want to index.
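For reference, a minimal sketch of the impersonation described in this copy,
using google-auth's domain-wide delegation (the key file path and admin email
are placeholders):

    # Sketch: have a service account act as the workspace's primary admin.
    from google.oauth2.service_account import Credentials as ServiceAccountCredentials
    from googleapiclient.discovery import build

    sa_creds = ServiceAccountCredentials.from_service_account_file(
        "service_account_key.json",  # placeholder path
        scopes=["https://www.googleapis.com/auth/drive.readonly"],
    )
    admin_creds = sa_creds.with_subject("admin@your-domain.com")  # placeholder email
    drive = build("drive", "v3", credentials=admin_creds)

Requests made with admin_creds only succeed if domain-wide delegation has been
granted to the service account's client ID in the Google Workspace admin
console.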
Date: Tue, 29 Oct 2024 10:33:55 -0700 Subject: [PATCH 06/23] oauth prep --- .../danswer/connectors/google_drive/connector.py | 16 ++++++++-------- .../connectors/google_drive/connector_auth.py | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/backend/danswer/connectors/google_drive/connector.py b/backend/danswer/connectors/google_drive/connector.py index cfe77428735..34504ed54e1 100644 --- a/backend/danswer/connectors/google_drive/connector.py +++ b/backend/danswer/connectors/google_drive/connector.py @@ -446,6 +446,14 @@ def _fetch_docs_from_drive( yield doc_batch + def load_from_state(self) -> GenerateDocumentsOutput: + yield from self._fetch_docs_from_drive() + + def poll_source( + self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch + ) -> GenerateDocumentsOutput: + yield from self._fetch_docs_from_drive(start, end) + def _fetch_slim_docs_from_drive( self, start: SecondsSinceUnixEpoch | None = None, @@ -471,9 +479,6 @@ def _fetch_slim_docs_from_drive( slim_batch = [] yield slim_batch - def load_from_state(self) -> GenerateDocumentsOutput: - yield from self._fetch_docs_from_drive() - def retrieve_all_slim_documents( self, start: SecondsSinceUnixEpoch | None = None, @@ -481,8 +486,3 @@ def retrieve_all_slim_documents( ) -> GenerateSlimDocumentOutput: self.is_slim = True return self._fetch_slim_docs_from_drive(start, end) - - def poll_source( - self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch - ) -> GenerateDocumentsOutput: - yield from self._fetch_docs_from_drive(start, end) diff --git a/backend/danswer/connectors/google_drive/connector_auth.py b/backend/danswer/connectors/google_drive/connector_auth.py index 390c0b9e469..7a7853d31f9 100644 --- a/backend/danswer/connectors/google_drive/connector_auth.py +++ b/backend/danswer/connectors/google_drive/connector_auth.py @@ -136,7 +136,7 @@ def get_google_drive_creds( creds: ServiceAccountCredentials | OAuthCredentials | None = ( oauth_creds or service_creds ) - if service_creds is None: + if creds is None: raise PermissionError( "Unable to access Google Drive - unknown credential structure." 
) From 39c403415058b8214c2f09e4f7a77a0cb557620f Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Tue, 29 Oct 2024 14:53:23 -0700 Subject: [PATCH 07/23] Works for oauth and service account credentials --- .../connectors/google_drive/connector.py | 145 +++++++++--------- .../connectors/google_drive/connector_auth.py | 49 +----- .../google_drive/doc_sync.py | 4 +- .../google_drive/group_sync.py | 4 +- 4 files changed, 81 insertions(+), 121 deletions(-) diff --git a/backend/danswer/connectors/google_drive/connector.py b/backend/danswer/connectors/google_drive/connector.py index 34504ed54e1..a70f75fc027 100644 --- a/backend/danswer/connectors/google_drive/connector.py +++ b/backend/danswer/connectors/google_drive/connector.py @@ -16,10 +16,9 @@ from danswer.configs.app_configs import INDEX_BATCH_SIZE from danswer.configs.constants import DocumentSource from danswer.configs.constants import IGNORE_FOR_QA -from danswer.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder from danswer.connectors.google_drive.connector_auth import ( - DB_CREDENTIALS_DICT_DELEGATED_USER_KEY, + DB_CREDENTIALS_PRIMARY_ADMIN_KEY, ) from danswer.connectors.google_drive.connector_auth import get_google_drive_creds from danswer.connectors.interfaces import GenerateDocumentsOutput @@ -40,17 +39,26 @@ logger = setup_logger() -DRIVE_FOLDER_TYPE = "application/vnd.google-apps.folder" -DRIVE_SHORTCUT_TYPE = "application/vnd.google-apps.shortcut" -UNSUPPORTED_FILE_TYPE_CONTENT = "" # keep empty for now +_DRIVE_FOLDER_TYPE = "application/vnd.google-apps.folder" +_DRIVE_SHORTCUT_TYPE = "application/vnd.google-apps.shortcut" +_UNSUPPORTED_FILE_TYPE_CONTENT = "" # keep empty for now -FILE_FIELDS = "nextPageToken, files(mimeType, id, name, permissions, modifiedTime, webViewLink, shortcutDetails, owners)" -SLIM_FILE_FIELDS = "nextPageToken, files(id, permissions(emailAddress, type), permissionIds, webViewLink)" -FOLDER_FIELDS = "nextPageToken, files(id, name, permissions, modifiedTime, webViewLink, shortcutDetails)" +_FILE_FIELDS = "nextPageToken, files(mimeType, id, name, permissions, modifiedTime, webViewLink, shortcutDetails, owners)" +_SLIM_FILE_FIELDS = "nextPageToken, files(id, permissions(emailAddress, type), permissionIds, webViewLink)" +_FOLDER_FIELDS = "nextPageToken, files(id, name, permissions, modifiedTime, webViewLink, shortcutDetails)" +_USER_FIELDS = "nextPageToken, users(primaryEmail)" +# This is a substring of the error google returns when the user doesn't have the correct scopes. +_MISSING_SCOPES_ERROR_STR = "client not authorized for any of the scopes requested" + +_SCOPE_DOC_URL = "https://docs.danswer.dev/connectors/google_drive/overview" +_ONYX_SCOPE_INSTRUCTIONS = ( + "You have upgraded Danswer without updating the Google Drive scopes. 
" + f"Please refer to the documentation to learn how to update the scopes: {_SCOPE_DOC_URL}" +) # these errors don't represent a failure in the connector, but simply files # that can't / shouldn't be indexed -ERRORS_TO_CONTINUE_ON = [ +_ERRORS_TO_CONTINUE_ON = [ "cannotExportFile", "exportSizeLimitExceeded", "cannotDownloadFile", @@ -107,7 +115,7 @@ def extract_text(file: dict[str, str], service: Resource) -> str: if mime_type not in set(item.value for item in GDriveMimeType): # Unsupported file types can still have a title, finding this way is still useful - return UNSUPPORTED_FILE_TYPE_CONTENT + return _UNSUPPORTED_FILE_TYPE_CONTENT if mime_type in [ GDriveMimeType.DOC.value, @@ -149,7 +157,7 @@ def extract_text(file: dict[str, str], service: Resource) -> str: elif mime_type == GDriveMimeType.POWERPOINT.value: return pptx_to_text(file=io.BytesIO(response)) - return UNSUPPORTED_FILE_TYPE_CONTENT + return _UNSUPPORTED_FILE_TYPE_CONTENT def _convert_drive_item_to_document( @@ -157,7 +165,7 @@ def _convert_drive_item_to_document( ) -> Document | None: try: # Skip files that are shortcuts - if file.get("mimeType") == DRIVE_SHORTCUT_TYPE: + if file.get("mimeType") == _DRIVE_SHORTCUT_TYPE: logger.info("Ignoring Drive Shortcut Filetype") return None try: @@ -165,7 +173,7 @@ def _convert_drive_item_to_document( except HttpError as e: reason = e.error_details[0]["reason"] if e.error_details else e.reason message = e.error_details[0]["message"] if e.error_details else e.reason - if e.status_code == 403 and reason in ERRORS_TO_CONTINUE_ON: + if e.status_code == 403 and reason in _ERRORS_TO_CONTINUE_ON: logger.warning( f"Could not export file '{file['name']}' due to '{message}', skipping..." ) @@ -192,10 +200,6 @@ def _convert_drive_item_to_document( return None -def _extract_parent_ids_from_urls(urls: list[str]) -> list[str]: - return [url.split("/")[-1] for url in urls] - - class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector): def __init__( self, @@ -205,55 +209,40 @@ def __init__( ) -> None: self.batch_size = batch_size - self.parent_ids = ( - _extract_parent_ids_from_urls(parent_urls) if parent_urls else [] - ) + self.initial_parent_ids = [] + if parent_urls: + self.initial_parent_ids = [url.split("/")[-1] for url in parent_urls] self.include_personal = include_personal or True - self.service_account_email: str | None = None - self.service_account_domain: str | None = None - self.service_account_creds: ServiceAccountCredentials | None = None + self.primary_admin_email: str | None = None + self.google_domain: str | None = None - self.oauth_creds: OAuthCredentials | None = None + self.creds: OAuthCredentials | ServiceAccountCredentials | None = None self.is_slim: bool = False self._TRAVERSED_PARENT_IDS: set[str] = set() def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None: - """Checks for two different types of credentials. - (1) A credential which holds a token acquired via a user going thorough - the Google OAuth flow. - (2) A credential which holds a service account key JSON file, which - can then be used to impersonate any user in the workspace. 
- """ - self.credentials_json = credentials - - creds, new_creds_dict = get_google_drive_creds(credentials) - if KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY in credentials: - self.service_account_creds = creds - self.service_account_email = credentials[ - DB_CREDENTIALS_DICT_DELEGATED_USER_KEY - ] - if self.service_account_email: - self.service_account_domain = self.service_account_email.split("@")[1] - else: - self.oauth_creds = creds + self.primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] + self.google_domain = self.primary_admin_email.split("@")[1] + + self.creds, new_creds_dict = get_google_drive_creds(credentials) return new_creds_dict - def get_admin_service( + def get_google_resource( self, service_name: str = "drive", service_version: str = "v3", user_email: str | None = None, ) -> Resource: - if self.service_account_creds: - creds = self.service_account_creds.with_subject( - user_email or self.service_account_email - ) + if isinstance(self.creds, ServiceAccountCredentials): + creds = self.creds.with_subject(user_email or self.primary_admin_email) service = build(service_name, service_version, credentials=creds) + elif isinstance(self.creds, OAuthCredentials): + service = build(service_name, service_version, credentials=self.creds) else: - service = build(service_name, service_version, credentials=self.oauth_creds) + raise PermissionError("No credentials found") return service @@ -264,9 +253,7 @@ def _get_folders_in_parent( personal_drive: bool = False, ) -> Iterator[GoogleDriveFileType]: # Follow shortcuts to folders - query = ( - f"(mimeType = '{DRIVE_FOLDER_TYPE}' or mimeType = '{DRIVE_SHORTCUT_TYPE}')" - ) + query = f"(mimeType = '{_DRIVE_FOLDER_TYPE}' or mimeType = '{_DRIVE_SHORTCUT_TYPE}')" if parent_id: query += f" and '{parent_id}' in parents" @@ -277,7 +264,7 @@ def _get_folders_in_parent( corpora="user" if personal_drive else "allDrives", supportsAllDrives=not personal_drive, includeItemsFromAllDrives=not personal_drive, - fields=FOLDER_FIELDS, + fields=_FOLDER_FIELDS, q=query, ): yield file @@ -290,7 +277,7 @@ def _get_files_in_parent( time_range_start: SecondsSinceUnixEpoch | None = None, time_range_end: SecondsSinceUnixEpoch | None = None, ) -> Iterator[GoogleDriveFileType]: - query = f"mimeType != '{DRIVE_FOLDER_TYPE}' and '{parent_id}' in parents" + query = f"mimeType != '{_DRIVE_FOLDER_TYPE}' and '{parent_id}' in parents" if time_range_start is not None: time_start = datetime.utcfromtimestamp(time_range_start).isoformat() + "Z" query += f" and modifiedTime >= '{time_start}'" @@ -304,7 +291,7 @@ def _get_files_in_parent( corpora="user" if personal_drive else "allDrives", supportsAllDrives=not personal_drive, includeItemsFromAllDrives=not personal_drive, - fields=SLIM_FILE_FIELDS if self.is_slim else FILE_FIELDS, + fields=_SLIM_FILE_FIELDS if self.is_slim else _FILE_FIELDS, q=query, ): yield file @@ -349,18 +336,13 @@ def _crawl_drive_for_files( ) def _get_all_user_emails(self) -> list[str]: - # if not self.service_account_creds: - # raise PermissionError("No service account credentials found") - - admin_creds = self.service_account_creds.with_subject( - self.service_account_email - ) - admin_service = build("admin", "directory_v1", credentials=admin_creds) + admin_service = self.get_google_resource("admin", "directory_v1") emails = [] for user in execute_paginated_retrieval( retrieval_function=admin_service.users().list, list_key="users", - domain=self.service_account_domain, + fields=_USER_FIELDS, + domain=self.google_domain, ): if email := 
user.get("primaryEmail"): emails.append(email) @@ -371,10 +353,9 @@ def _fetch_drive_items( start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, ) -> Iterator[GoogleDriveFileType]: - # admin_creds = self.service_account_creds.with_subject(self.service_account_email) - admin_drive_service = self.get_admin_service() + admin_drive_service = self.get_google_resource() - parent_ids = self.parent_ids + parent_ids = self.initial_parent_ids if not parent_ids: # if no parent ids are specified, get all shared drives using the admin account for drive in execute_paginated_retrieval( @@ -394,16 +375,16 @@ def _fetch_drive_items( time_range_start=start, time_range_end=end, ): - print(file) yield file - logger.info(f"Fetching personal files: {self.include_personal}") + # get all personal docs from each users' personal drive if self.include_personal: + logger.info("Checking My Drives for documents") all_user_emails = self._get_all_user_emails() for email in all_user_emails: logger.info(f"Fetching personal files for user: {email}") - user_creds = self.service_account_creds.with_subject(email) - user_drive_service = build("drive", "v3", credentials=user_creds) + user_drive_service = self.get_google_resource(user_email=email) + # we dont paginate here because there is only one root folder per user # https://developers.google.com/drive/api/guides/v2-to-v3-reference id = ( @@ -425,16 +406,13 @@ def _fetch_docs_from_drive( start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, ) -> GenerateDocumentsOutput: - # if self.oauth_creds is None and self.service_account_creds is None: - # raise PermissionError("No credentials found") - doc_batch = [] for file in self._fetch_drive_items( start=start, end=end, ): user_email = file.get("owners", [{}])[0].get("emailAddress") - service = self.get_admin_service(user_email=user_email) + service = self.get_google_resource(user_email=user_email) if doc := _convert_drive_item_to_document( file=file, service=service, @@ -447,12 +425,22 @@ def _fetch_docs_from_drive( yield doc_batch def load_from_state(self) -> GenerateDocumentsOutput: - yield from self._fetch_docs_from_drive() + try: + yield from self._fetch_docs_from_drive() + except Exception as e: + if _MISSING_SCOPES_ERROR_STR in str(e): + raise PermissionError(_ONYX_SCOPE_INSTRUCTIONS) from e + raise e def poll_source( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch ) -> GenerateDocumentsOutput: - yield from self._fetch_docs_from_drive(start, end) + try: + yield from self._fetch_docs_from_drive(start, end) + except Exception as e: + if _MISSING_SCOPES_ERROR_STR in str(e): + raise PermissionError(_ONYX_SCOPE_INSTRUCTIONS) from e + raise e def _fetch_slim_docs_from_drive( self, @@ -485,4 +473,9 @@ def retrieve_all_slim_documents( end: SecondsSinceUnixEpoch | None = None, ) -> GenerateSlimDocumentOutput: self.is_slim = True - return self._fetch_slim_docs_from_drive(start, end) + try: + yield from self._fetch_slim_docs_from_drive(start, end) + except Exception as e: + if _MISSING_SCOPES_ERROR_STR in str(e): + raise PermissionError(_ONYX_SCOPE_INSTRUCTIONS) from e + raise e diff --git a/backend/danswer/connectors/google_drive/connector_auth.py b/backend/danswer/connectors/google_drive/connector_auth.py index 7a7853d31f9..490469e253a 100644 --- a/backend/danswer/connectors/google_drive/connector_auth.py +++ b/backend/danswer/connectors/google_drive/connector_auth.py @@ -32,7 +32,7 @@ "https://www.googleapis.com/auth/admin.directory.user.readonly", ] 
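# A sketch of what each GOOGLE_DRIVE_SCOPES entry enables in this connector:
#   drive.readonly                 -> listing and exporting file contents
#   drive.metadata.readonly        -> reading permissions / permissionIds on files
#   admin.directory.group.readonly -> groups().list / members().list for group sync
#   admin.directory.user.readonly  -> users().list for crawling each user's My Drive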
DB_CREDENTIALS_DICT_TOKEN_KEY = "google_drive_tokens" -DB_CREDENTIALS_DICT_DELEGATED_USER_KEY = "google_drive_delegated_user" +DB_CREDENTIALS_PRIMARY_ADMIN_KEY = "google_drive_delegated_user" def _build_frontend_google_drive_redirect() -> str: @@ -60,48 +60,15 @@ def get_google_drive_creds_for_authorized_user( return None -# def get_service_account_credentials( -# credentials: dict[str, str], -# scopes: list[str], -# ) -> ServiceAccountCredentials: -# service_account_key_json_str = credentials[KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY] -# service_creds = _get_google_drive_creds_for_service_account( -# service_account_key_json_str=service_account_key_json_str, -# scopes=scopes, -# ) - -# # "Impersonate" a user if one is specified -# delegated_user_email = cast( -# str | None, credentials.get(DB_CREDENTIALS_DICT_DELEGATED_USER_KEY) -# ) -# if delegated_user_email: -# service_creds = ( -# service_creds.with_subject(delegated_user_email) if service_creds else None -# ) -# return service_creds - - -# def get_oauth_credentials( -# credentials: dict[str, str], -# scopes: list[str] -# ) -> tuple[OAuthCredentials | None, dict[str, str] | None]: -# new_creds_dict = None -# access_token_json_str = cast(str, credentials[DB_CREDENTIALS_DICT_TOKEN_KEY]) -# oauth_creds = get_google_drive_creds_for_authorized_user( -# token_json_str=access_token_json_str, scopes=scopes -# ) - -# # tell caller to update token stored in DB if it has changed -# # (e.g. the token has been refreshed) -# new_creds_json_str = oauth_creds.to_json() if oauth_creds else "" -# if new_creds_json_str != access_token_json_str: -# new_creds_dict = {DB_CREDENTIALS_DICT_TOKEN_KEY: new_creds_json_str} -# return oauth_creds, new_creds_dict - - def get_google_drive_creds( credentials: dict[str, str], scopes: list[str] = GOOGLE_DRIVE_SCOPES ) -> tuple[ServiceAccountCredentials | OAuthCredentials, dict[str, str] | None]: + """Checks for two different types of credentials. + (1) A credential which holds a token acquired via a user going thorough + the Google OAuth flow. + (2) A credential which holds a service account key JSON file, which + can then be used to impersonate any user in the workspace. 
+ """ oauth_creds = None service_creds = None new_creds_dict = None @@ -203,7 +170,7 @@ def build_service_account_creds( KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY: service_account_key.json(), } if delegated_user_email: - credential_dict[DB_CREDENTIALS_DICT_DELEGATED_USER_KEY] = delegated_user_email + credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = delegated_user_email return CredentialBase( credential_json=credential_dict, diff --git a/backend/ee/danswer/external_permissions/google_drive/doc_sync.py b/backend/ee/danswer/external_permissions/google_drive/doc_sync.py index 2c1351dccda..a08f01e7a8d 100644 --- a/backend/ee/danswer/external_permissions/google_drive/doc_sync.py +++ b/backend/ee/danswer/external_permissions/google_drive/doc_sync.py @@ -140,14 +140,14 @@ def gdrive_doc_sync( raise ValueError("Service account credentials not found") slim_docs = _get_slim_docs(cc_pair, google_drive_connector) - admin_service = google_drive_connector.get_admin_service() + admin_service = google_drive_connector.get_google_resource() for slim_doc in slim_docs: ext_access = _fetch_google_permissions_for_slim_doc( db_session=db_session, admin_service=admin_service, slim_doc=slim_doc, - company_domain=google_drive_connector.service_account_domain, + company_domain=google_drive_connector.google_domain, ) upsert_document_external_perms__no_commit( db_session=db_session, diff --git a/backend/ee/danswer/external_permissions/google_drive/group_sync.py b/backend/ee/danswer/external_permissions/google_drive/group_sync.py index b3f133ccbfd..837491382c4 100644 --- a/backend/ee/danswer/external_permissions/google_drive/group_sync.py +++ b/backend/ee/danswer/external_permissions/google_drive/group_sync.py @@ -23,13 +23,13 @@ def gdrive_group_sync( if google_drive_connector.service_account_creds is None: raise ValueError("Service account credentials not found") - admin_service = google_drive_connector.get_admin_service("admin", "directory_v1") + admin_service = google_drive_connector.get_google_resource("admin", "directory_v1") danswer_groups: list[ExternalUserGroup] = [] for group in execute_paginated_retrieval( admin_service.groups().list, list_key="groups", - domain=google_drive_connector.service_account_domain, + domain=google_drive_connector.google_domain, fields="groups(email)", ): # The id is the group email From b4ac935132dfaa4c205551274de42b683e1d4638 Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Tue, 29 Oct 2024 15:15:51 -0700 Subject: [PATCH 08/23] mypy --- backend/danswer/connectors/google_drive/connector.py | 5 +++-- backend/danswer/server/documents/connector.py | 4 +++- .../ee/danswer/external_permissions/google_drive/doc_sync.py | 5 +---- .../danswer/external_permissions/google_drive/group_sync.py | 3 --- 4 files changed, 7 insertions(+), 10 deletions(-) diff --git a/backend/danswer/connectors/google_drive/connector.py b/backend/danswer/connectors/google_drive/connector.py index a70f75fc027..ac7ddbd9818 100644 --- a/backend/danswer/connectors/google_drive/connector.py +++ b/backend/danswer/connectors/google_drive/connector.py @@ -224,8 +224,9 @@ def __init__( self._TRAVERSED_PARENT_IDS: set[str] = set() def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None: - self.primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] - self.google_domain = self.primary_admin_email.split("@")[1] + primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] + self.google_domain = primary_admin_email.split("@")[1] + self.primary_admin_email = primary_admin_email 
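        # A sketch of the two credential dict shapes this method accepts
        # (key constants from connector_auth / configs.constants; values illustrative):
        #   {DB_CREDENTIALS_DICT_TOKEN_KEY: "<oauth token json>",
        #    DB_CREDENTIALS_PRIMARY_ADMIN_KEY: "admin@example.com"}
        #   {KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY: "<service account key json>",
        #    DB_CREDENTIALS_PRIMARY_ADMIN_KEY: "admin@example.com"}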
self.creds, new_creds_dict = get_google_drive_creds(credentials) return new_creds_dict diff --git a/backend/danswer/server/documents/connector.py b/backend/danswer/server/documents/connector.py index afd41a296aa..f188671376f 100644 --- a/backend/danswer/server/documents/connector.py +++ b/backend/danswer/server/documents/connector.py @@ -44,6 +44,7 @@ get_google_drive_creds_for_authorized_user, ) from danswer.connectors.google_drive.connector_auth import get_service_account_key +from danswer.connectors.google_drive.connector_auth import GOOGLE_DRIVE_SCOPES from danswer.connectors.google_drive.connector_auth import ( update_credential_access_tokens, ) @@ -348,7 +349,8 @@ def check_drive_tokens( return AuthStatus(authenticated=False) token_json_str = str(db_credentials.credential_json[DB_CREDENTIALS_DICT_TOKEN_KEY]) google_drive_creds = get_google_drive_creds_for_authorized_user( - token_json_str=token_json_str + token_json_str=token_json_str, + scopes=GOOGLE_DRIVE_SCOPES, ) if google_drive_creds is None: return AuthStatus(authenticated=False) diff --git a/backend/ee/danswer/external_permissions/google_drive/doc_sync.py b/backend/ee/danswer/external_permissions/google_drive/doc_sync.py index a08f01e7a8d..9f0f87bf2dd 100644 --- a/backend/ee/danswer/external_permissions/google_drive/doc_sync.py +++ b/backend/ee/danswer/external_permissions/google_drive/doc_sync.py @@ -22,7 +22,7 @@ def _get_slim_docs( cc_pair: ConnectorCredentialPair, google_drive_connector: GoogleDriveConnector, -) -> tuple[list[SlimDocument], GoogleDriveConnector]: +) -> list[SlimDocument]: current_time = datetime.now(timezone.utc) start_time = ( cc_pair.last_time_perm_sync.replace(tzinfo=timezone.utc).timestamp() @@ -136,9 +136,6 @@ def gdrive_doc_sync( ) google_drive_connector.load_credentials(cc_pair.credential.credential_json) - if google_drive_connector.service_account_creds is None: - raise ValueError("Service account credentials not found") - slim_docs = _get_slim_docs(cc_pair, google_drive_connector) admin_service = google_drive_connector.get_google_resource() diff --git a/backend/ee/danswer/external_permissions/google_drive/group_sync.py b/backend/ee/danswer/external_permissions/google_drive/group_sync.py index 837491382c4..254352492f7 100644 --- a/backend/ee/danswer/external_permissions/google_drive/group_sync.py +++ b/backend/ee/danswer/external_permissions/google_drive/group_sync.py @@ -20,9 +20,6 @@ def gdrive_group_sync( ) google_drive_connector.load_credentials(cc_pair.credential.credential_json) - if google_drive_connector.service_account_creds is None: - raise ValueError("Service account credentials not found") - admin_service = google_drive_connector.get_google_resource("admin", "directory_v1") danswer_groups: list[ExternalUserGroup] = [] From 2d4a6edec448da3bbd14296db8226ff363c786ac Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Tue, 29 Oct 2024 15:22:43 -0700 Subject: [PATCH 09/23] merge fixes --- backend/danswer/connectors/google_drive/connector.py | 1 - .../ee/danswer/external_permissions/google_drive/group_sync.py | 2 -- 2 files changed, 3 deletions(-) diff --git a/backend/danswer/connectors/google_drive/connector.py b/backend/danswer/connectors/google_drive/connector.py index 3b53e622b3d..03528e8255b 100644 --- a/backend/danswer/connectors/google_drive/connector.py +++ b/backend/danswer/connectors/google_drive/connector.py @@ -16,7 +16,6 @@ from danswer.configs.app_configs import INDEX_BATCH_SIZE from danswer.configs.constants import DocumentSource from danswer.configs.constants import 
IGNORE_FOR_QA -from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder from danswer.connectors.google_drive.connector_auth import ( DB_CREDENTIALS_PRIMARY_ADMIN_KEY, ) diff --git a/backend/ee/danswer/external_permissions/google_drive/group_sync.py b/backend/ee/danswer/external_permissions/google_drive/group_sync.py index 480e28495fb..254352492f7 100644 --- a/backend/ee/danswer/external_permissions/google_drive/group_sync.py +++ b/backend/ee/danswer/external_permissions/google_drive/group_sync.py @@ -2,11 +2,9 @@ from danswer.connectors.google_drive.connector import execute_paginated_retrieval from danswer.connectors.google_drive.connector import GoogleDriveConnector -from danswer.connectors.google_drive.constants import FETCH_GROUPS_SCOPES from danswer.db.models import ConnectorCredentialPair from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit from danswer.utils.logger import setup_logger -from danswer.utils.retry_wrapper import retry_builder from ee.danswer.db.external_perm import ExternalUserGroup from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair__no_commit From 0cc236100d09181f589a08be1956964165426209 Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Tue, 29 Oct 2024 19:24:47 -0700 Subject: [PATCH 10/23] Refactor Google Drive connector --- .../connectors/google_drive/connector.py | 384 ++++-------------- .../connectors/google_drive/constants.py | 16 + .../connectors/google_drive/doc_conversion.py | 115 ++++++ .../connectors/google_drive/file_retrieval.py | 166 ++++++++ .../connectors/google_drive/google_utils.py | 35 ++ .../danswer/connectors/google_drive/models.py | 18 + .../google_drive/doc_sync.py | 3 +- .../google_drive/group_sync.py | 2 +- 8 files changed, 425 insertions(+), 314 deletions(-) create mode 100644 backend/danswer/connectors/google_drive/constants.py create mode 100644 backend/danswer/connectors/google_drive/doc_conversion.py create mode 100644 backend/danswer/connectors/google_drive/file_retrieval.py create mode 100644 backend/danswer/connectors/google_drive/google_utils.py create mode 100644 backend/danswer/connectors/google_drive/models.py diff --git a/backend/danswer/connectors/google_drive/connector.py b/backend/danswer/connectors/google_drive/connector.py index 03528e8255b..6662d17e30e 100644 --- a/backend/danswer/connectors/google_drive/connector.py +++ b/backend/danswer/connectors/google_drive/connector.py @@ -1,53 +1,36 @@ -import io -from collections.abc import Callable from collections.abc import Iterator -from datetime import datetime -from datetime import timezone -from enum import Enum from typing import Any from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore from googleapiclient.discovery import build # type: ignore from googleapiclient.discovery import Resource # type: ignore -from googleapiclient.errors import HttpError # type: ignore -from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE from danswer.configs.app_configs import INDEX_BATCH_SIZE -from danswer.configs.constants import DocumentSource -from danswer.configs.constants import IGNORE_FOR_QA from danswer.connectors.google_drive.connector_auth import ( DB_CREDENTIALS_PRIMARY_ADMIN_KEY, ) from danswer.connectors.google_drive.connector_auth import get_google_drive_creds +from danswer.connectors.google_drive.constants import USER_FIELDS +from 
danswer.connectors.google_drive.doc_conversion import ( + convert_drive_item_to_document, +) +from danswer.connectors.google_drive.file_retrieval import crawl_folders_for_files +from danswer.connectors.google_drive.file_retrieval import get_files_in_my_drive +from danswer.connectors.google_drive.file_retrieval import get_files_in_shared_drive +from danswer.connectors.google_drive.google_utils import execute_paginated_retrieval +from danswer.connectors.google_drive.models import GoogleDriveFileType from danswer.connectors.interfaces import GenerateDocumentsOutput from danswer.connectors.interfaces import GenerateSlimDocumentOutput from danswer.connectors.interfaces import LoadConnector from danswer.connectors.interfaces import PollConnector from danswer.connectors.interfaces import SecondsSinceUnixEpoch from danswer.connectors.interfaces import SlimConnector -from danswer.connectors.models import Document -from danswer.connectors.models import Section from danswer.connectors.models import SlimDocument -from danswer.file_processing.extract_file_text import docx_to_text -from danswer.file_processing.extract_file_text import pptx_to_text -from danswer.file_processing.extract_file_text import read_pdf_file -from danswer.file_processing.unstructured import get_unstructured_api_key -from danswer.file_processing.unstructured import unstructured_to_text from danswer.utils.logger import setup_logger -from danswer.utils.retry_wrapper import retry_builder logger = setup_logger() -_DRIVE_FOLDER_TYPE = "application/vnd.google-apps.folder" -_DRIVE_SHORTCUT_TYPE = "application/vnd.google-apps.shortcut" -_UNSUPPORTED_FILE_TYPE_CONTENT = "" # keep empty for now - -_FILE_FIELDS = "nextPageToken, files(mimeType, id, name, permissions, modifiedTime, webViewLink, shortcutDetails, owners)" -_SLIM_FILE_FIELDS = "nextPageToken, files(id, permissions(emailAddress, type), permissionIds, webViewLink)" -_FOLDER_FIELDS = "nextPageToken, files(id, name, permissions, modifiedTime, webViewLink, shortcutDetails)" -_USER_FIELDS = "nextPageToken, users(primaryEmail)" - # This is a substring of the error google returns when the user doesn't have the correct scopes. _MISSING_SCOPES_ERROR_STR = "client not authorized for any of the scopes requested" @@ -56,171 +39,32 @@ "You have upgraded Danswer without updating the Google Drive scopes. " f"Please refer to the documentation to learn how to update the scopes: {_SCOPE_DOC_URL}" ) -# these errors don't represent a failure in the connector, but simply files -# that can't / shouldn't be indexed -_ERRORS_TO_CONTINUE_ON = [ - "cannotExportFile", - "exportSizeLimitExceeded", - "cannotDownloadFile", -] _SLIM_BATCH_SIZE = 500 -class GDriveMimeType(str, Enum): - DOC = "application/vnd.google-apps.document" - SPREADSHEET = "application/vnd.google-apps.spreadsheet" - PDF = "application/pdf" - WORD_DOC = "application/vnd.openxmlformats-officedocument.wordprocessingml.document" - PPT = "application/vnd.google-apps.presentation" - POWERPOINT = ( - "application/vnd.openxmlformats-officedocument.presentationml.presentation" - ) - PLAIN_TEXT = "text/plain" - MARKDOWN = "text/markdown" - - -GoogleDriveFileType = dict[str, Any] - -# Google Drive APIs are quite flakey and may 500 for an -# extended period of time. 
Trying to combat here by adding a very -# long retry period (~20 minutes of trying every minute) -add_retries = retry_builder(tries=50, max_delay=30) - - -def execute_paginated_retrieval( - retrieval_function: Callable[..., Any], - list_key: str, - **kwargs: Any, -) -> Iterator[GoogleDriveFileType]: - """Execute a paginated retrieval from Google Drive API - Args: - retrieval_function: The specific list function to call (e.g., service.files().list) - **kwargs: Arguments to pass to the list function - """ - next_page_token = "" - while next_page_token is not None: - request_kwargs = kwargs.copy() - if next_page_token: - request_kwargs["pageToken"] = next_page_token - - results = add_retries(lambda: retrieval_function(**request_kwargs).execute())() - - next_page_token = results.get("nextPageToken") - for item in results.get(list_key, []): - yield item - - -def extract_text(file: dict[str, str], service: Resource) -> str: - mime_type = file["mimeType"] - - if mime_type not in set(item.value for item in GDriveMimeType): - # Unsupported file types can still have a title, finding this way is still useful - return _UNSUPPORTED_FILE_TYPE_CONTENT - - if mime_type in [ - GDriveMimeType.DOC.value, - GDriveMimeType.PPT.value, - GDriveMimeType.SPREADSHEET.value, - ]: - export_mime_type = ( - "text/plain" - if mime_type != GDriveMimeType.SPREADSHEET.value - else "text/csv" - ) - return ( - service.files() - .export(fileId=file["id"], mimeType=export_mime_type) - .execute() - .decode("utf-8") - ) - elif mime_type in [ - GDriveMimeType.PLAIN_TEXT.value, - GDriveMimeType.MARKDOWN.value, - ]: - return service.files().get_media(fileId=file["id"]).execute().decode("utf-8") - if mime_type in [ - GDriveMimeType.WORD_DOC.value, - GDriveMimeType.POWERPOINT.value, - GDriveMimeType.PDF.value, - ]: - response = service.files().get_media(fileId=file["id"]).execute() - if get_unstructured_api_key(): - return unstructured_to_text( - file=io.BytesIO(response), file_name=file.get("name", file["id"]) - ) - - if mime_type == GDriveMimeType.WORD_DOC.value: - return docx_to_text(file=io.BytesIO(response)) - elif mime_type == GDriveMimeType.PDF.value: - text, _ = read_pdf_file(file=io.BytesIO(response)) - return text - elif mime_type == GDriveMimeType.POWERPOINT.value: - return pptx_to_text(file=io.BytesIO(response)) - - return _UNSUPPORTED_FILE_TYPE_CONTENT - - -def _convert_drive_item_to_document( - file: GoogleDriveFileType, service: Resource -) -> Document | None: - try: - # Skip files that are shortcuts - if file.get("mimeType") == _DRIVE_SHORTCUT_TYPE: - logger.info("Ignoring Drive Shortcut Filetype") - return None - try: - text_contents = extract_text(file, service) or "" - except HttpError as e: - reason = e.error_details[0]["reason"] if e.error_details else e.reason - message = e.error_details[0]["message"] if e.error_details else e.reason - if e.status_code == 403 and reason in _ERRORS_TO_CONTINUE_ON: - logger.warning( - f"Could not export file '{file['name']}' due to '{message}', skipping..." 
- ) - return None - - raise - - return Document( - id=file["webViewLink"], - sections=[Section(link=file["webViewLink"], text=text_contents)], - source=DocumentSource.GOOGLE_DRIVE, - semantic_identifier=file["name"], - doc_updated_at=datetime.fromisoformat(file["modifiedTime"]).astimezone( - timezone.utc - ), - metadata={} if text_contents else {IGNORE_FOR_QA: "True"}, - additional_info=file.get("id"), - ) - except Exception as e: - if not CONTINUE_ON_CONNECTOR_FAILURE: - raise e - - logger.exception("Ran into exception when pulling a file from Google Drive") - return None - - class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector): def __init__( self, - parent_urls: list[str] | None = None, - include_personal: bool | None = True, + include_shared_drives: bool = True, + include_my_drives: bool = True, + shared_drive_ids: list[str] | None = None, + my_drive_emails: list[str] | None = None, + folder_ids: list[str] | None = None, batch_size: int = INDEX_BATCH_SIZE, ) -> None: self.batch_size = batch_size - self.initial_parent_ids = [] - if parent_urls: - self.initial_parent_ids = [url.split("/")[-1] for url in parent_urls] - self.include_personal = include_personal or True + self.include_shared_drives = include_shared_drives + self.include_my_drives = include_my_drives + self.shared_drive_ids = shared_drive_ids or [] + self.my_drive_emails = my_drive_emails or [] + self.folder_ids = folder_ids or [] self.primary_admin_email: str | None = None self.google_domain: str | None = None self.creds: OAuthCredentials | ServiceAccountCredentials | None = None - self.is_slim: bool = False - self._TRAVERSED_PARENT_IDS: set[str] = set() def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None: @@ -247,102 +91,13 @@ def get_google_resource( return service - def _get_folders_in_parent( - self, - service: Resource, - parent_id: str | None = None, - personal_drive: bool = False, - ) -> Iterator[GoogleDriveFileType]: - # Follow shortcuts to folders - query = f"(mimeType = '{_DRIVE_FOLDER_TYPE}' or mimeType = '{_DRIVE_SHORTCUT_TYPE}')" - - if parent_id: - query += f" and '{parent_id}' in parents" - - for file in execute_paginated_retrieval( - retrieval_function=service.files().list, - list_key="files", - corpora="user" if personal_drive else "allDrives", - supportsAllDrives=not personal_drive, - includeItemsFromAllDrives=not personal_drive, - fields=_FOLDER_FIELDS, - q=query, - ): - yield file - - def _get_files_in_parent( - self, - service: Resource, - parent_id: str, - personal_drive: bool, - time_range_start: SecondsSinceUnixEpoch | None = None, - time_range_end: SecondsSinceUnixEpoch | None = None, - ) -> Iterator[GoogleDriveFileType]: - query = f"mimeType != '{_DRIVE_FOLDER_TYPE}' and '{parent_id}' in parents" - if time_range_start is not None: - time_start = datetime.utcfromtimestamp(time_range_start).isoformat() + "Z" - query += f" and modifiedTime >= '{time_start}'" - if time_range_end is not None: - time_stop = datetime.utcfromtimestamp(time_range_end).isoformat() + "Z" - query += f" and modifiedTime <= '{time_stop}'" - - for file in execute_paginated_retrieval( - retrieval_function=service.files().list, - list_key="files", - corpora="user" if personal_drive else "allDrives", - supportsAllDrives=not personal_drive, - includeItemsFromAllDrives=not personal_drive, - fields=_SLIM_FILE_FIELDS if self.is_slim else _FILE_FIELDS, - q=query, - ): - yield file - - def _crawl_drive_for_files( - self, - service: Resource, - parent_id: str, - personal_drive: bool, - 
time_range_start: SecondsSinceUnixEpoch | None = None, - time_range_end: SecondsSinceUnixEpoch | None = None, - ) -> Iterator[GoogleDriveFileType]: - """Gets all files matching the criteria specified by the args from Google Drive - in batches of size `batch_size`. - """ - if parent_id in self._TRAVERSED_PARENT_IDS: - logger.debug(f"Skipping subfolder since already traversed: {parent_id}") - return - - self._TRAVERSED_PARENT_IDS.add(parent_id) - - yield from self._get_files_in_parent( - service=service, - personal_drive=personal_drive, - time_range_start=time_range_start, - time_range_end=time_range_end, - parent_id=parent_id, - ) - - for subfolder in self._get_folders_in_parent( - service=service, - parent_id=parent_id, - personal_drive=personal_drive, - ): - logger.info("Fetching all files in subfolder: " + subfolder["name"]) - yield from self._crawl_drive_for_files( - service=service, - parent_id=subfolder["id"], - personal_drive=personal_drive, - time_range_start=time_range_start, - time_range_end=time_range_end, - ) - def _get_all_user_emails(self) -> list[str]: admin_service = self.get_google_resource("admin", "directory_v1") emails = [] for user in execute_paginated_retrieval( retrieval_function=admin_service.users().list, list_key="users", - fields=_USER_FIELDS, + fields=USER_FIELDS, domain=self.google_domain, ): if email := user.get("primaryEmail"): @@ -351,70 +106,77 @@ def _get_all_user_emails(self) -> list[str]: def _fetch_drive_items( self, - start: SecondsSinceUnixEpoch | None = None, - end: SecondsSinceUnixEpoch | None = None, + is_slim: bool, + time_range_start: SecondsSinceUnixEpoch | None = None, + time_range_end: SecondsSinceUnixEpoch | None = None, ) -> Iterator[GoogleDriveFileType]: admin_drive_service = self.get_google_resource() - parent_ids = self.initial_parent_ids - if not parent_ids: - # if no parent ids are specified, get all shared drives using the admin account - for drive in execute_paginated_retrieval( - retrieval_function=admin_drive_service.drives().list, - list_key="drives", - useDomainAdminAccess=True, - fields="drives(id)", - ): - parent_ids.append(drive["id"]) - - # crawl all the shared parent ids for files - for parent_id in parent_ids: - for file in self._crawl_drive_for_files( - service=admin_drive_service, - parent_id=parent_id, - personal_drive=False, - time_range_start=start, - time_range_end=end, - ): - yield file + if self.include_shared_drives: + shared_drive_ids = self.shared_drive_ids + if not shared_drive_ids: + # if no parent ids are specified, get all shared drives using the admin account + for drive in execute_paginated_retrieval( + retrieval_function=admin_drive_service.drives().list, + list_key="drives", + useDomainAdminAccess=True, + fields="drives(id)", + ): + shared_drive_ids.append(drive["id"]) + + # crawl all the shared parent ids for files + for shared_drive_id in shared_drive_ids: + for file in get_files_in_shared_drive( + service=admin_drive_service, + drive_id=shared_drive_id, + is_slim=is_slim, + time_range_start=time_range_start, + time_range_end=time_range_end, + ): + yield file + + if self.folder_ids: + for folder_id in self.folder_ids: + yield from crawl_folders_for_files( + service=admin_drive_service, + parent_id=folder_id, + personal_drive=False, + time_range_start=time_range_start, + time_range_end=time_range_end, + ) # get all personal docs from each users' personal drive - if self.include_personal: - logger.info("Checking My Drives for documents") - all_user_emails = self._get_all_user_emails() + if 
self.include_my_drives: + all_user_emails = self.my_drive_emails + if not all_user_emails: + all_user_emails = self._get_all_user_emails() + for email in all_user_emails: logger.info(f"Fetching personal files for user: {email}") user_drive_service = self.get_google_resource(user_email=email) - # we dont paginate here because there is only one root folder per user - # https://developers.google.com/drive/api/guides/v2-to-v3-reference - id = ( - user_drive_service.files() - .get(fileId="root", fields="id") - .execute()["id"] - ) - - yield from self._crawl_drive_for_files( + yield from get_files_in_my_drive( service=user_drive_service, - parent_id=id, - personal_drive=True, - time_range_start=start, - time_range_end=end, + email=email, + is_slim=is_slim, + time_range_start=time_range_start, + time_range_end=time_range_end, ) - def _fetch_docs_from_drive( + def _extract_docs_from_google_drive( self, start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, ) -> GenerateDocumentsOutput: doc_batch = [] for file in self._fetch_drive_items( + is_slim=False, start=start, end=end, ): user_email = file.get("owners", [{}])[0].get("emailAddress") service = self.get_google_resource(user_email=user_email) - if doc := _convert_drive_item_to_document( + if doc := convert_drive_item_to_document( file=file, service=service, ): @@ -427,7 +189,7 @@ def _fetch_docs_from_drive( def load_from_state(self) -> GenerateDocumentsOutput: try: - yield from self._fetch_docs_from_drive() + yield from self._extract_docs_from_google_drive() except Exception as e: if _MISSING_SCOPES_ERROR_STR in str(e): raise PermissionError(_ONYX_SCOPE_INSTRUCTIONS) from e @@ -437,19 +199,20 @@ def poll_source( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch ) -> GenerateDocumentsOutput: try: - yield from self._fetch_docs_from_drive(start, end) + yield from self._extract_docs_from_google_drive(start, end) except Exception as e: if _MISSING_SCOPES_ERROR_STR in str(e): raise PermissionError(_ONYX_SCOPE_INSTRUCTIONS) from e raise e - def _fetch_slim_docs_from_drive( + def _extract_slim_docs_from_google_drive( self, start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, ) -> GenerateSlimDocumentOutput: slim_batch = [] for file in self._fetch_drive_items( + is_slim=True, start=start, end=end, ): @@ -473,9 +236,8 @@ def retrieve_all_slim_documents( start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, ) -> GenerateSlimDocumentOutput: - self.is_slim = True try: - yield from self._fetch_slim_docs_from_drive(start, end) + yield from self._extract_slim_docs_from_google_drive(start, end) except Exception as e: if _MISSING_SCOPES_ERROR_STR in str(e): raise PermissionError(_ONYX_SCOPE_INSTRUCTIONS) from e diff --git a/backend/danswer/connectors/google_drive/constants.py b/backend/danswer/connectors/google_drive/constants.py new file mode 100644 index 00000000000..ead9a302332 --- /dev/null +++ b/backend/danswer/connectors/google_drive/constants.py @@ -0,0 +1,16 @@ +UNSUPPORTED_FILE_TYPE_CONTENT = "" # keep empty for now +DRIVE_FOLDER_TYPE = "application/vnd.google-apps.folder" +DRIVE_SHORTCUT_TYPE = "application/vnd.google-apps.shortcut" + +FILE_FIELDS = "nextPageToken, files(mimeType, id, name, permissions, modifiedTime, webViewLink, shortcutDetails, owners)" +SLIM_FILE_FIELDS = "nextPageToken, files(id, permissions(emailAddress, type), permissionIds, webViewLink)" +FOLDER_FIELDS = "nextPageToken, files(id, name, permissions, modifiedTime, 
webViewLink, shortcutDetails)" +USER_FIELDS = "nextPageToken, users(primaryEmail)" + +# these errors don't represent a failure in the connector, but simply files +# that can't / shouldn't be indexed +ERRORS_TO_CONTINUE_ON = [ + "cannotExportFile", + "exportSizeLimitExceeded", + "cannotDownloadFile", +] diff --git a/backend/danswer/connectors/google_drive/doc_conversion.py b/backend/danswer/connectors/google_drive/doc_conversion.py new file mode 100644 index 00000000000..688190c2267 --- /dev/null +++ b/backend/danswer/connectors/google_drive/doc_conversion.py @@ -0,0 +1,115 @@ +import io +from datetime import datetime +from datetime import timezone + +from googleapiclient.discovery import Resource # type: ignore +from googleapiclient.errors import HttpError # type: ignore + +from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE +from danswer.configs.constants import DocumentSource +from danswer.configs.constants import IGNORE_FOR_QA +from danswer.connectors.google_drive.constants import DRIVE_SHORTCUT_TYPE +from danswer.connectors.google_drive.constants import ERRORS_TO_CONTINUE_ON +from danswer.connectors.google_drive.constants import UNSUPPORTED_FILE_TYPE_CONTENT +from danswer.connectors.google_drive.models import GDriveMimeType +from danswer.connectors.google_drive.models import GoogleDriveFileType +from danswer.connectors.models import Document +from danswer.connectors.models import Section +from danswer.file_processing.extract_file_text import docx_to_text +from danswer.file_processing.extract_file_text import pptx_to_text +from danswer.file_processing.extract_file_text import read_pdf_file +from danswer.file_processing.unstructured import get_unstructured_api_key +from danswer.file_processing.unstructured import unstructured_to_text +from danswer.utils.logger import setup_logger + +logger = setup_logger() + + +def _extract_text(file: dict[str, str], service: Resource) -> str: + mime_type = file["mimeType"] + + if mime_type not in set(item.value for item in GDriveMimeType): + # Unsupported file types can still have a title, finding this way is still useful + return UNSUPPORTED_FILE_TYPE_CONTENT + + if mime_type in [ + GDriveMimeType.DOC.value, + GDriveMimeType.PPT.value, + GDriveMimeType.SPREADSHEET.value, + ]: + export_mime_type = ( + "text/plain" + if mime_type != GDriveMimeType.SPREADSHEET.value + else "text/csv" + ) + return ( + service.files() + .export(fileId=file["id"], mimeType=export_mime_type) + .execute() + .decode("utf-8") + ) + elif mime_type in [ + GDriveMimeType.PLAIN_TEXT.value, + GDriveMimeType.MARKDOWN.value, + ]: + return service.files().get_media(fileId=file["id"]).execute().decode("utf-8") + if mime_type in [ + GDriveMimeType.WORD_DOC.value, + GDriveMimeType.POWERPOINT.value, + GDriveMimeType.PDF.value, + ]: + response = service.files().get_media(fileId=file["id"]).execute() + if get_unstructured_api_key(): + return unstructured_to_text( + file=io.BytesIO(response), file_name=file.get("name", file["id"]) + ) + + if mime_type == GDriveMimeType.WORD_DOC.value: + return docx_to_text(file=io.BytesIO(response)) + elif mime_type == GDriveMimeType.PDF.value: + text, _ = read_pdf_file(file=io.BytesIO(response)) + return text + elif mime_type == GDriveMimeType.POWERPOINT.value: + return pptx_to_text(file=io.BytesIO(response)) + + return UNSUPPORTED_FILE_TYPE_CONTENT + + +def convert_drive_item_to_document( + file: GoogleDriveFileType, service: Resource +) -> Document | None: + try: + # Skip files that are shortcuts + if file.get("mimeType") == 
DRIVE_SHORTCUT_TYPE: + logger.info("Ignoring Drive Shortcut Filetype") + return None + try: + text_contents = _extract_text(file, service) or "" + except HttpError as e: + reason = e.error_details[0]["reason"] if e.error_details else e.reason + message = e.error_details[0]["message"] if e.error_details else e.reason + if e.status_code == 403 and reason in ERRORS_TO_CONTINUE_ON: + logger.warning( + f"Could not export file '{file['name']}' due to '{message}', skipping..." + ) + return None + + raise + + return Document( + id=file["webViewLink"], + sections=[Section(link=file["webViewLink"], text=text_contents)], + source=DocumentSource.GOOGLE_DRIVE, + semantic_identifier=file["name"], + doc_updated_at=datetime.fromisoformat(file["modifiedTime"]).astimezone( + timezone.utc + ), + metadata={} if text_contents else {IGNORE_FOR_QA: "True"}, + additional_info=file.get("id"), + ) + except Exception as e: + if not CONTINUE_ON_CONNECTOR_FAILURE: + raise e + + logger.exception("Ran into exception when pulling a file from Google Drive") + return None diff --git a/backend/danswer/connectors/google_drive/file_retrieval.py b/backend/danswer/connectors/google_drive/file_retrieval.py new file mode 100644 index 00000000000..e9ad142b7f7 --- /dev/null +++ b/backend/danswer/connectors/google_drive/file_retrieval.py @@ -0,0 +1,166 @@ +from collections.abc import Iterator +from datetime import datetime + +from googleapiclient.discovery import Resource # type: ignore + +from danswer.connectors.google_drive.constants import DRIVE_FOLDER_TYPE +from danswer.connectors.google_drive.constants import DRIVE_SHORTCUT_TYPE +from danswer.connectors.google_drive.constants import FILE_FIELDS +from danswer.connectors.google_drive.constants import FOLDER_FIELDS +from danswer.connectors.google_drive.constants import SLIM_FILE_FIELDS +from danswer.connectors.google_drive.google_utils import execute_paginated_retrieval +from danswer.connectors.google_drive.models import GoogleDriveFileType +from danswer.connectors.interfaces import SecondsSinceUnixEpoch +from danswer.utils.logger import setup_logger + +logger = setup_logger() + + +def _generate_time_range_filter( + time_range_start: SecondsSinceUnixEpoch | None = None, + time_range_end: SecondsSinceUnixEpoch | None = None, +) -> str: + time_range_filter = "" + if time_range_start is not None: + time_start = datetime.utcfromtimestamp(time_range_start).isoformat() + "Z" + time_range_filter += f" and modifiedTime >= '{time_start}'" + if time_range_end is not None: + time_stop = datetime.utcfromtimestamp(time_range_end).isoformat() + "Z" + time_range_filter += f" and modifiedTime <= '{time_stop}'" + return time_range_filter + + +def _get_folders_in_parent( + service: Resource, + parent_id: str | None = None, + personal_drive: bool = False, +) -> Iterator[GoogleDriveFileType]: + # Follow shortcuts to folders + query = f"(mimeType = '{DRIVE_FOLDER_TYPE}' or mimeType = '{DRIVE_SHORTCUT_TYPE}')" + + if parent_id: + query += f" and '{parent_id}' in parents" + + for file in execute_paginated_retrieval( + retrieval_function=service.files().list, + list_key="files", + corpora="user" if personal_drive else "allDrives", + supportsAllDrives=not personal_drive, + includeItemsFromAllDrives=not personal_drive, + fields=FOLDER_FIELDS, + q=query, + ): + yield file + + +def _get_files_in_parent( + service: Resource, + parent_id: str, + personal_drive: bool, + time_range_start: SecondsSinceUnixEpoch | None = None, + time_range_end: SecondsSinceUnixEpoch | None = None, + is_slim: bool = False, +) -> 
Iterator[GoogleDriveFileType]:
+ query = f"mimeType != '{DRIVE_FOLDER_TYPE}' and '{parent_id}' in parents"
+ query += _generate_time_range_filter(time_range_start, time_range_end)
+
+ for file in execute_paginated_retrieval(
+ retrieval_function=service.files().list,
+ list_key="files",
+ corpora="user" if personal_drive else "allDrives",
+ supportsAllDrives=not personal_drive,
+ includeItemsFromAllDrives=not personal_drive,
+ fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
+ q=query,
+ ):
+ yield file
+
+
+_TRAVERSED_PARENT_IDS: set[str] = set()  # module-level: folders crawled once are skipped on later calls
+
+
+def crawl_folders_for_files(
+ service: Resource,
+ parent_id: str,
+ personal_drive: bool,
+ time_range_start: SecondsSinceUnixEpoch | None = None,
+ time_range_end: SecondsSinceUnixEpoch | None = None,
+) -> Iterator[GoogleDriveFileType]:
+ """
+ Crawls for files starting from any folder; slower than drive-wide listing because every subfolder needs its own queries.
+ """
+ if parent_id in _TRAVERSED_PARENT_IDS:
+ logger.debug(f"Skipping subfolder since already traversed: {parent_id}")
+ return
+
+ _TRAVERSED_PARENT_IDS.add(parent_id)
+
+ yield from _get_files_in_parent(
+ service=service,
+ personal_drive=personal_drive,
+ time_range_start=time_range_start,
+ time_range_end=time_range_end,
+ parent_id=parent_id,
+ )
+
+ for subfolder in _get_folders_in_parent(
+ service=service,
+ parent_id=parent_id,
+ personal_drive=personal_drive,
+ ):
+ logger.info("Fetching all files in subfolder: " + subfolder["name"])
+ yield from crawl_folders_for_files(
+ service=service,
+ parent_id=subfolder["id"],
+ personal_drive=personal_drive,
+ time_range_start=time_range_start,
+ time_range_end=time_range_end,
+ )
+
+
+def get_files_in_shared_drive(
+ service: Resource,
+ drive_id: str,
+ is_slim: bool = False,
+ time_range_start: SecondsSinceUnixEpoch | None = None,
+ time_range_end: SecondsSinceUnixEpoch | None = None,
+) -> Iterator[GoogleDriveFileType]:
+ query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
+ query += _generate_time_range_filter(time_range_start, time_range_end)
+ for file in execute_paginated_retrieval(
+ retrieval_function=service.files().list,
+ list_key="files",
+ corpora="drive",
+ drive_id=drive_id,
+ supportsAllDrives=True,
+ includeItemsFromAllDrives=True,
+ fields=FILE_FIELDS if is_slim else SLIM_FILE_FIELDS,
+ q=query,
+ ):
+ yield file
+
+
+def get_files_in_my_drive(
+ service: Resource,
+ email: str,
+ is_slim: bool = False,
+ time_range_start: SecondsSinceUnixEpoch | None = None,
+ time_range_end: SecondsSinceUnixEpoch | None = None,
+) -> Iterator[GoogleDriveFileType]:
+ query = f"mimeType != '{DRIVE_FOLDER_TYPE}' and '{email}' in owners"
+ query += _generate_time_range_filter(time_range_start, time_range_end)
+ for file in execute_paginated_retrieval(
+ retrieval_function=service.files().list,
+ list_key="files",
+ corpora="user",
+ fields=FILE_FIELDS if is_slim else SLIM_FILE_FIELDS,
+ q=query,
+ ):
+ yield file
+
+
+# Just in case we need to get the root folder id
+def get_root_folder_id(service: Resource) -> str:
+ # we don't paginate here because there is only one root folder per user
+ # https://developers.google.com/drive/api/guides/v2-to-v3-reference
+ return service.files().get(fileId="root", fields="id").execute()["id"]
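For orientation, a minimal sketch of how the three retrieval helpers above compose (illustrative only: `creds` is assumed to be an already-authorized credential object, and the drive ID, folder ID, and email below are placeholders):

from googleapiclient.discovery import build  # type: ignore

from danswer.connectors.google_drive.file_retrieval import crawl_folders_for_files
from danswer.connectors.google_drive.file_retrieval import get_files_in_my_drive
from danswer.connectors.google_drive.file_retrieval import get_files_in_shared_drive

service = build("drive", "v3", credentials=creds)

# Shared drives are listed with a single flat query scoped to the drive ID.
for file in get_files_in_shared_drive(service=service, drive_id="0ABCdeFGhIJK"):
    print(file["webViewLink"])

# Arbitrary folders are crawled recursively; the traversed-ID set prevents re-visits.
for file in crawl_folders_for_files(
    service=service, parent_id="1a2B3c4D5e", personal_drive=False
):
    print(file["webViewLink"])

# My Drive files are matched by owner email against the "user" corpus.
for file in get_files_in_my_drive(service=service, email="user@example.com"):
    print(file["webViewLink"])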
diff --git a/backend/danswer/connectors/google_drive/google_utils.py b/backend/danswer/connectors/google_drive/google_utils.py
new file mode 100644
index 00000000000..5f772e5ad63
--- /dev/null
+++ b/backend/danswer/connectors/google_drive/google_utils.py
@@ -0,0 +1,35 @@
+from collections.abc import Callable
+from collections.abc import Iterator
+from typing import Any
+
+from danswer.connectors.google_drive.models import GoogleDriveFileType
+from danswer.utils.retry_wrapper import retry_builder
+
+
+# Google Drive APIs are quite flaky and may 500 for an
+# extended period of time. Trying to combat that here by adding a very
+# long retry period (~25 minutes of retries, waiting up to 30 seconds between tries)
+add_retries = retry_builder(tries=50, max_delay=30)
+
+
+def execute_paginated_retrieval(
+ retrieval_function: Callable[..., Any],
+ list_key: str,
+ **kwargs: Any,
+) -> Iterator[GoogleDriveFileType]:
+ """Execute a paginated retrieval from the Google Drive API.
+ Args:
+ retrieval_function: The specific list function to call (e.g., service.files().list)
+ **kwargs: Arguments to pass to the list function
+ """
+ next_page_token = ""
+ while next_page_token is not None:
+ request_kwargs = kwargs.copy()
+ if next_page_token:
+ request_kwargs["pageToken"] = next_page_token
+
+ results = add_retries(lambda: retrieval_function(**request_kwargs).execute())()
+
+ next_page_token = results.get("nextPageToken")
+ for item in results.get(list_key, []):
+ yield item
diff --git a/backend/danswer/connectors/google_drive/models.py b/backend/danswer/connectors/google_drive/models.py
new file mode 100644
index 00000000000..5bb06f3c206
--- /dev/null
+++ b/backend/danswer/connectors/google_drive/models.py
@@ -0,0 +1,18 @@
+from enum import Enum
+from typing import Any
+
+
+class GDriveMimeType(str, Enum):
+ DOC = "application/vnd.google-apps.document"
+ SPREADSHEET = "application/vnd.google-apps.spreadsheet"
+ PDF = "application/pdf"
+ WORD_DOC = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
+ PPT = "application/vnd.google-apps.presentation"
+ POWERPOINT = (
+ "application/vnd.openxmlformats-officedocument.presentationml.presentation"
+ )
+ PLAIN_TEXT = "text/plain"
+ MARKDOWN = "text/markdown"
+
+
+GoogleDriveFileType = dict[str, Any]
diff --git a/backend/ee/danswer/external_permissions/google_drive/doc_sync.py b/backend/ee/danswer/external_permissions/google_drive/doc_sync.py
index 5ac6c31f2b4..2b2a83064a7 100644
--- a/backend/ee/danswer/external_permissions/google_drive/doc_sync.py
+++ b/backend/ee/danswer/external_permissions/google_drive/doc_sync.py
@@ -6,13 +6,12 @@
 from sqlalchemy.orm import Session
 
 from danswer.access.models import ExternalAccess
-from danswer.connectors.google_drive.connector import execute_paginated_retrieval
 from danswer.connectors.google_drive.connector import GoogleDriveConnector
+from danswer.connectors.google_drive.google_utils import execute_paginated_retrieval
 from danswer.connectors.models import SlimDocument
 from danswer.db.models import ConnectorCredentialPair
 from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
 from danswer.utils.logger import setup_logger
-from danswer.utils.retry_wrapper import retry_builder
 from ee.danswer.db.document import upsert_document_external_perms__no_commit
 
 logger = setup_logger()
diff --git a/backend/ee/danswer/external_permissions/google_drive/group_sync.py b/backend/ee/danswer/external_permissions/google_drive/group_sync.py
index 254352492f7..c3afa962392 100644
--- a/backend/ee/danswer/external_permissions/google_drive/group_sync.py
+++ b/backend/ee/danswer/external_permissions/google_drive/group_sync.py
@@ -1,7 +1,7 @@
 from sqlalchemy.orm import Session
 
-from danswer.connectors.google_drive.connector import execute_paginated_retrieval
 from danswer.connectors.google_drive.connector import GoogleDriveConnector
+from danswer.connectors.google_drive.google_utils import execute_paginated_retrieval
 from danswer.db.models import ConnectorCredentialPair
 from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
 from danswer.utils.logger import setup_logger
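A usage sketch for the pagination helper introduced above: it keeps requesting pages until the API stops returning a nextPageToken, so callers simply iterate. This mirrors _get_all_user_emails in connector.py; anything not in the patch (the admin_service resource, the domain) is a placeholder:

from danswer.connectors.google_drive.constants import USER_FIELDS
from danswer.connectors.google_drive.google_utils import execute_paginated_retrieval

# admin_service is assumed to be build("admin", "directory_v1", credentials=creds)
emails = [
    user["primaryEmail"]
    for user in execute_paginated_retrieval(
        retrieval_function=admin_service.users().list,
        list_key="users",
        fields=USER_FIELDS,
        domain="example.com",  # placeholder domain
    )
]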

From 4ce88097a41d7c7b0be299eef64fbe000f1ec2c8 Mon Sep 17 00:00:00 2001
From: hagen-danswer <hagen@danswer.ai>
Date: Tue, 29 Oct 2024 19:45:45 -0700
Subject: [PATCH 11/23] finished backend

---
 .../connectors/google_drive/connector.py      | 16 +++----
 .../connectors/google_drive/file_retrieval.py | 48 +++++++++----------
 2 files changed, 32 insertions(+), 32 deletions(-)

diff --git a/backend/danswer/connectors/google_drive/connector.py b/backend/danswer/connectors/google_drive/connector.py
index 6662d17e30e..e8e00712396 100644
--- a/backend/danswer/connectors/google_drive/connector.py
+++ b/backend/danswer/connectors/google_drive/connector.py
@@ -107,8 +107,8 @@ def _get_all_user_emails(self) -> list[str]:
 def _fetch_drive_items(
 self,
 is_slim: bool,
- time_range_start: SecondsSinceUnixEpoch | None = None,
- time_range_end: SecondsSinceUnixEpoch | None = None,
+ start: SecondsSinceUnixEpoch | None = None,
+ end: SecondsSinceUnixEpoch | None = None,
 ) -> Iterator[GoogleDriveFileType]:
 admin_drive_service = self.get_google_resource()
 
@@ -130,8 +130,8 @@ def _fetch_drive_items(
 service=admin_drive_service,
 drive_id=shared_drive_id,
 is_slim=is_slim,
- time_range_start=time_range_start,
- time_range_end=time_range_end,
+ start=start,
+ end=end,
 ):
 yield file
 
@@ -141,8 +141,8 @@ def _fetch_drive_items(
 service=admin_drive_service,
 parent_id=folder_id,
 personal_drive=False,
- time_range_start=time_range_start,
- time_range_end=time_range_end,
+ start=start,
+ end=end,
 )
 
 # get all personal docs from each users' personal drive
@@ -159,8 +159,8 @@ def _fetch_drive_items(
 service=user_drive_service,
 email=email,
 is_slim=is_slim,
- time_range_start=time_range_start,
- time_range_end=time_range_end,
+ start=start,
+ end=end,
 )
 
 def _extract_docs_from_google_drive(
diff --git a/backend/danswer/connectors/google_drive/file_retrieval.py b/backend/danswer/connectors/google_drive/file_retrieval.py
index e9ad142b7f7..db296e49c38 100644
--- a/backend/danswer/connectors/google_drive/file_retrieval.py
+++ b/backend/danswer/connectors/google_drive/file_retrieval.py
@@ -17,15 +17,15 @@
 
 
 def _generate_time_range_filter(
- time_range_start: SecondsSinceUnixEpoch | None = None,
- time_range_end: SecondsSinceUnixEpoch | None = None,
+ start: SecondsSinceUnixEpoch | None = None,
+ end: SecondsSinceUnixEpoch | None = None,
 ) -> str:
 time_range_filter = ""
- if time_range_start is not None:
- time_start = datetime.utcfromtimestamp(time_range_start).isoformat() + "Z"
+ if start is not None:
+ time_start = datetime.utcfromtimestamp(start).isoformat() + "Z"
 time_range_filter += f" and modifiedTime >= '{time_start}'"
- if time_range_end is not None:
- time_stop = datetime.utcfromtimestamp(time_range_end).isoformat() + "Z"
+ if end is not None:
+ time_stop = datetime.utcfromtimestamp(end).isoformat() + "Z"
 time_range_filter += f" and modifiedTime <= '{time_stop}'"
 return time_range_filter
 
@@ -57,12 +57,12 @@ def _get_files_in_parent(
 service: Resource,
 parent_id: str,
 personal_drive: bool,
- time_range_start: SecondsSinceUnixEpoch | None = None,
- time_range_end: SecondsSinceUnixEpoch | None = None,
+ start: SecondsSinceUnixEpoch | None = None,
+ end: SecondsSinceUnixEpoch | None = None,
 is_slim: bool = False,
 ) -> Iterator[GoogleDriveFileType]:
 query = f"mimeType != '{DRIVE_FOLDER_TYPE}' and '{parent_id}' in parents"
- query += 
_generate_time_range_filter(time_range_start, time_range_end) + query += _generate_time_range_filter(start, end) for file in execute_paginated_retrieval( retrieval_function=service.files().list, @@ -83,8 +83,8 @@ def crawl_folders_for_files( service: Resource, parent_id: str, personal_drive: bool, - time_range_start: SecondsSinceUnixEpoch | None = None, - time_range_end: SecondsSinceUnixEpoch | None = None, + start: SecondsSinceUnixEpoch | None = None, + end: SecondsSinceUnixEpoch | None = None, ) -> Iterator[GoogleDriveFileType]: """ This one can start crawling from any folder. It is slower though. @@ -98,8 +98,8 @@ def crawl_folders_for_files( yield from _get_files_in_parent( service=service, personal_drive=personal_drive, - time_range_start=time_range_start, - time_range_end=time_range_end, + start=start, + end=end, parent_id=parent_id, ) @@ -113,8 +113,8 @@ def crawl_folders_for_files( service=service, parent_id=subfolder["id"], personal_drive=personal_drive, - time_range_start=time_range_start, - time_range_end=time_range_end, + start=start, + end=end, ) @@ -122,19 +122,19 @@ def get_files_in_shared_drive( service: Resource, drive_id: str, is_slim: bool = False, - time_range_start: SecondsSinceUnixEpoch | None = None, - time_range_end: SecondsSinceUnixEpoch | None = None, + start: SecondsSinceUnixEpoch | None = None, + end: SecondsSinceUnixEpoch | None = None, ) -> Iterator[GoogleDriveFileType]: query = f"mimeType != '{DRIVE_FOLDER_TYPE}'" - query += _generate_time_range_filter(time_range_start, time_range_end) + query += _generate_time_range_filter(start, end) for file in execute_paginated_retrieval( retrieval_function=service.files().list, list_key="files", corpora="drive", - drive_id=drive_id, + driveId=drive_id, supportsAllDrives=True, includeItemsFromAllDrives=True, - fields=FILE_FIELDS if is_slim else SLIM_FILE_FIELDS, + fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS, q=query, ): yield file @@ -144,16 +144,16 @@ def get_files_in_my_drive( service: Resource, email: str, is_slim: bool = False, - time_range_start: SecondsSinceUnixEpoch | None = None, - time_range_end: SecondsSinceUnixEpoch | None = None, + start: SecondsSinceUnixEpoch | None = None, + end: SecondsSinceUnixEpoch | None = None, ) -> Iterator[GoogleDriveFileType]: query = f"mimeType != '{DRIVE_FOLDER_TYPE}' and '{email}' in owners" - query += _generate_time_range_filter(time_range_start, time_range_end) + query += _generate_time_range_filter(start, end) for file in execute_paginated_retrieval( retrieval_function=service.files().list, list_key="files", corpora="user", - fields=FILE_FIELDS if is_slim else SLIM_FILE_FIELDS, + fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS, q=query, ): yield file From 8e3623acf997b62673f548ed717ae484ef095f8f Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Wed, 30 Oct 2024 07:48:18 -0700 Subject: [PATCH 12/23] auth changes --- .../connectors/google_drive/connector.py | 33 +++++++--------- .../connectors/google_drive/connector_auth.py | 34 ++++++++++++++--- .../connectors/google_drive/constants.py | 13 +++++++ backend/danswer/server/documents/connector.py | 38 ++++++++++--------- backend/danswer/server/documents/models.py | 6 +-- web/src/lib/connectors/credentials.ts | 5 ++- 6 files changed, 82 insertions(+), 47 deletions(-) diff --git a/backend/danswer/connectors/google_drive/connector.py b/backend/danswer/connectors/google_drive/connector.py index e8e00712396..801435606b8 100644 --- a/backend/danswer/connectors/google_drive/connector.py +++ 
b/backend/danswer/connectors/google_drive/connector.py @@ -11,6 +11,9 @@ DB_CREDENTIALS_PRIMARY_ADMIN_KEY, ) from danswer.connectors.google_drive.connector_auth import get_google_drive_creds +from danswer.connectors.google_drive.constants import MISSING_SCOPES_ERROR_STR +from danswer.connectors.google_drive.constants import ONYX_SCOPE_INSTRUCTIONS +from danswer.connectors.google_drive.constants import SLIM_BATCH_SIZE from danswer.connectors.google_drive.constants import USER_FIELDS from danswer.connectors.google_drive.doc_conversion import ( convert_drive_item_to_document, @@ -31,23 +34,13 @@ logger = setup_logger() -# This is a substring of the error google returns when the user doesn't have the correct scopes. -_MISSING_SCOPES_ERROR_STR = "client not authorized for any of the scopes requested" - -_SCOPE_DOC_URL = "https://docs.danswer.dev/connectors/google_drive/overview" -_ONYX_SCOPE_INSTRUCTIONS = ( - "You have upgraded Danswer without updating the Google Drive scopes. " - f"Please refer to the documentation to learn how to update the scopes: {_SCOPE_DOC_URL}" -) -_SLIM_BATCH_SIZE = 500 - class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector): def __init__( self, include_shared_drives: bool = True, - include_my_drives: bool = True, shared_drive_ids: list[str] | None = None, + include_my_drives: bool = True, my_drive_emails: list[str] | None = None, folder_ids: list[str] | None = None, batch_size: int = INDEX_BATCH_SIZE, @@ -55,9 +48,11 @@ def __init__( self.batch_size = batch_size self.include_shared_drives = include_shared_drives - self.include_my_drives = include_my_drives self.shared_drive_ids = shared_drive_ids or [] + + self.include_my_drives = include_my_drives self.my_drive_emails = my_drive_emails or [] + self.folder_ids = folder_ids or [] self.primary_admin_email: str | None = None @@ -191,8 +186,8 @@ def load_from_state(self) -> GenerateDocumentsOutput: try: yield from self._extract_docs_from_google_drive() except Exception as e: - if _MISSING_SCOPES_ERROR_STR in str(e): - raise PermissionError(_ONYX_SCOPE_INSTRUCTIONS) from e + if MISSING_SCOPES_ERROR_STR in str(e): + raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e raise e def poll_source( @@ -201,8 +196,8 @@ def poll_source( try: yield from self._extract_docs_from_google_drive(start, end) except Exception as e: - if _MISSING_SCOPES_ERROR_STR in str(e): - raise PermissionError(_ONYX_SCOPE_INSTRUCTIONS) from e + if MISSING_SCOPES_ERROR_STR in str(e): + raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e raise e def _extract_slim_docs_from_google_drive( @@ -226,7 +221,7 @@ def _extract_slim_docs_from_google_drive( }, ) ) - if len(slim_batch) >= _SLIM_BATCH_SIZE: + if len(slim_batch) >= SLIM_BATCH_SIZE: yield slim_batch slim_batch = [] yield slim_batch @@ -239,6 +234,6 @@ def retrieve_all_slim_documents( try: yield from self._extract_slim_docs_from_google_drive(start, end) except Exception as e: - if _MISSING_SCOPES_ERROR_STR in str(e): - raise PermissionError(_ONYX_SCOPE_INSTRUCTIONS) from e + if MISSING_SCOPES_ERROR_STR in str(e): + raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e raise e diff --git a/backend/danswer/connectors/google_drive/connector_auth.py b/backend/danswer/connectors/google_drive/connector_auth.py index 490469e253a..cb42f5e09aa 100644 --- a/backend/danswer/connectors/google_drive/connector_auth.py +++ b/backend/danswer/connectors/google_drive/connector_auth.py @@ -8,6 +8,7 @@ from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore from 
google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore +from googleapiclient.discovery import build from sqlalchemy.orm import Session from danswer.configs.app_configs import WEB_DOMAIN @@ -15,6 +16,8 @@ from danswer.configs.constants import KV_CRED_KEY from danswer.configs.constants import KV_GOOGLE_DRIVE_CRED_KEY from danswer.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY +from danswer.connectors.google_drive.constants import MISSING_SCOPES_ERROR_STR +from danswer.connectors.google_drive.constants import ONYX_SCOPE_INSTRUCTIONS from danswer.db.credentials import update_credential_json from danswer.db.models import User from danswer.key_value_store.factory import get_kv_store @@ -32,7 +35,7 @@ "https://www.googleapis.com/auth/admin.directory.user.readonly", ] DB_CREDENTIALS_DICT_TOKEN_KEY = "google_drive_tokens" -DB_CREDENTIALS_PRIMARY_ADMIN_KEY = "google_drive_delegated_user" +DB_CREDENTIALS_PRIMARY_ADMIN_KEY = "google_drive_primary_admin" def _build_frontend_google_drive_redirect() -> str: @@ -153,7 +156,28 @@ def update_credential_access_tokens( flow.fetch_token(code=auth_code) creds = flow.credentials token_json_str = creds.to_json() - new_creds_dict = {DB_CREDENTIALS_DICT_TOKEN_KEY: token_json_str} + + # Get user email from Google API so we know who + # the primary admin is for this connector + try: + admin_service = build("drive", "v3", credentials=creds) + user_info = ( + admin_service.about() + .get( + fields="user(emailAddress)", + ) + .execute() + ) + email = user_info.get("user", {}).get("emailAddress") + except Exception as e: + if MISSING_SCOPES_ERROR_STR in str(e): + raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e + raise e + + new_creds_dict = { + DB_CREDENTIALS_DICT_TOKEN_KEY: token_json_str, + DB_CREDENTIALS_PRIMARY_ADMIN_KEY: email, + } if not update_credential_json(credential_id, new_creds_dict, user, db_session): return None @@ -162,15 +186,15 @@ def update_credential_access_tokens( def build_service_account_creds( source: DocumentSource, - delegated_user_email: str | None = None, + primary_admin_email: str | None = None, ) -> CredentialBase: service_account_key = get_service_account_key() credential_dict = { KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY: service_account_key.json(), } - if delegated_user_email: - credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = delegated_user_email + if primary_admin_email: + credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = primary_admin_email return CredentialBase( credential_json=credential_dict, diff --git a/backend/danswer/connectors/google_drive/constants.py b/backend/danswer/connectors/google_drive/constants.py index ead9a302332..1b087f28c5a 100644 --- a/backend/danswer/connectors/google_drive/constants.py +++ b/backend/danswer/connectors/google_drive/constants.py @@ -14,3 +14,16 @@ "exportSizeLimitExceeded", "cannotDownloadFile", ] + +# Error message substrings +MISSING_SCOPES_ERROR_STR = "client not authorized for any of the scopes requested" + +# Documentation and error messages +SCOPE_DOC_URL = "https://docs.danswer.dev/connectors/google_drive/overview" +ONYX_SCOPE_INSTRUCTIONS = ( + "You have upgraded Danswer without updating the Google Drive scopes. 
" + f"Please refer to the documentation to learn how to update the scopes: {SCOPE_DOC_URL}" +) + +# Batch sizes +SLIM_BATCH_SIZE = 500 diff --git a/backend/danswer/server/documents/connector.py b/backend/danswer/server/documents/connector.py index f188671376f..154cae2e2b5 100644 --- a/backend/danswer/server/documents/connector.py +++ b/backend/danswer/server/documents/connector.py @@ -9,6 +9,7 @@ from fastapi import Request from fastapi import Response from fastapi import UploadFile +from google.oauth2.credentials import Credentials from pydantic import BaseModel from sqlalchemy.orm import Session @@ -295,7 +296,7 @@ def upsert_service_account_credential( try: credential_base = build_service_account_creds( DocumentSource.GOOGLE_DRIVE, - delegated_user_email=service_account_credential_request.google_drive_delegated_user, + primary_admin_email=service_account_credential_request.google_drive_primary_admin, ) except KvKeyNotFoundError as e: raise HTTPException(status_code=400, detail=str(e)) @@ -321,7 +322,7 @@ def upsert_gmail_service_account_credential( try: credential_base = build_service_account_creds( DocumentSource.GMAIL, - delegated_user_email=service_account_credential_request.gmail_delegated_user, + primary_admin_email=service_account_credential_request.gmail_delegated_user, ) except KvKeyNotFoundError as e: raise HTTPException(status_code=400, detail=str(e)) @@ -357,18 +358,18 @@ def check_drive_tokens( return AuthStatus(authenticated=True) -@router.get("/admin/connector/google-drive/authorize/{credential_id}") -def admin_google_drive_auth( - response: Response, credential_id: str, _: User = Depends(current_admin_user) -) -> AuthUrl: - # set a cookie that we can read in the callback (used for `verify_csrf`) - response.set_cookie( - key=_GOOGLE_DRIVE_CREDENTIAL_ID_COOKIE_NAME, - value=credential_id, - httponly=True, - max_age=600, - ) - return AuthUrl(auth_url=get_auth_url(credential_id=int(credential_id))) +# @router.get("/admin/connector/google-drive/authorize/{credential_id}") +# def admin_google_drive_auth( +# response: Response, credential_id: str, _: User = Depends(current_admin_user) +# ) -> AuthUrl: +# # set a cookie that we can read in the callback (used for `verify_csrf`) +# response.set_cookie( +# key=_GOOGLE_DRIVE_CREDENTIAL_ID_COOKIE_NAME, +# value=credential_id, +# httponly=True, +# max_age=600, +# ) +# return AuthUrl(auth_url=get_auth_url(credential_id=int(credential_id))) @router.post("/admin/connector/file/upload") @@ -953,10 +954,11 @@ def google_drive_callback( ) credential_id = int(credential_id_cookie) verify_csrf(credential_id, callback.state) - if ( - update_credential_access_tokens(callback.code, credential_id, user, db_session) - is None - ): + + credentials: Credentials | None = update_credential_access_tokens( + callback.code, credential_id, user, db_session + ) + if credentials is None: raise HTTPException( status_code=500, detail="Unable to fetch Google Drive access tokens" ) diff --git a/backend/danswer/server/documents/models.py b/backend/danswer/server/documents/models.py index fcbc0a76a12..e45d6eabff0 100644 --- a/backend/danswer/server/documents/models.py +++ b/backend/danswer/server/documents/models.py @@ -377,16 +377,16 @@ class GoogleServiceAccountKey(BaseModel): class GoogleServiceAccountCredentialRequest(BaseModel): - google_drive_delegated_user: str | None = None # email of user to impersonate + google_drive_primary_admin: str | None = None # email of user to impersonate gmail_delegated_user: str | None = None # email of user to impersonate 
@model_validator(mode="after")
 def check_user_delegation(self) -> "GoogleServiceAccountCredentialRequest":
- if (self.google_drive_delegated_user is None) == (
+ if (self.google_drive_primary_admin is None) == (
 self.gmail_delegated_user is None
 ):
 raise ValueError(
- "Exactly one of google_drive_delegated_user or gmail_delegated_user must be set"
+ "Exactly one of google_drive_primary_admin or gmail_delegated_user must be set"
 )
 return self
diff --git a/web/src/lib/connectors/credentials.ts b/web/src/lib/connectors/credentials.ts
index 532f8f6de76..73c788d3a96 100644
--- a/web/src/lib/connectors/credentials.ts
+++ b/web/src/lib/connectors/credentials.ts
@@ -58,6 +58,7 @@ export interface GmailCredentialJson {
 export interface GoogleDriveCredentialJson {
 google_drive_tokens: string;
+ google_drive_primary_admin: string;
 }
 
 export interface GmailServiceAccountCredentialJson {
@@ -67,7 +68,7 @@ export interface GmailServiceAccountCredentialJson {
 
 export interface GoogleDriveServiceAccountCredentialJson {
 google_drive_service_account_key: string;
- google_drive_delegated_user: string;
+ google_drive_primary_admin: string;
 }
 
 export interface SlabCredentialJson {
@@ -331,7 +332,7 @@ export const credentialDisplayNames: Record = {
 
 // Google Drive Service Account
 google_drive_service_account_key: "Google Drive Service Account Key",
- google_drive_delegated_user: "Google Drive Delegated User",
+ google_drive_primary_admin: "Google Drive Primary Admin",
 
 // Slab
 slab_bot_token: "Slab Bot Token",
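A behavioral sketch of the renamed validator above (the emails are placeholders; pydantic surfaces the ValueError as a ValidationError):

from danswer.server.documents.models import GoogleServiceAccountCredentialRequest

# Valid: exactly one impersonation target is set.
GoogleServiceAccountCredentialRequest(google_drive_primary_admin="admin@example.com")

# Fails validation: both targets are set.
GoogleServiceAccountCredentialRequest(
    google_drive_primary_admin="admin@example.com",
    gmail_delegated_user="user@example.com",
)

# Also fails validation: neither target is set.
GoogleServiceAccountCredentialRequest()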

From b7bacd0b89b601643bb91f85e0f5677393734255 Mon Sep 17 00:00:00 2001
From: hagen-danswer <hagen@danswer.ai>
Date: Wed, 30 Oct 2024 11:38:30 -0700
Subject: [PATCH 13/23] if it's stupid but it works, it's not stupid

---
 .../connectors/google_drive/connector.py      | 34 ++++--
 .../connectors/google_drive/connector_auth.py |  9 +-
 backend/danswer/server/documents/connector.py |  2 +-
 .../[connector]/AddConnectorPage.tsx          |  6 +
 .../pages/DynamicConnectorCreationForm.tsx    | 115 +++++++++++++++---
 .../[connector]/pages/gdrive/Credential.tsx   | 11 +-
 .../pages/gdrive/GoogleDrivePage.tsx          |  4 +-
 .../admin/connectors/AccessTypeForm.tsx       |  2 +-
 web/src/lib/connectors/connectors.tsx         | 70 ++++++++---
 9 files changed, 198 insertions(+), 55 deletions(-)

diff --git a/backend/danswer/connectors/google_drive/connector.py b/backend/danswer/connectors/google_drive/connector.py
index 801435606b8..39394456c81 100644
--- a/backend/danswer/connectors/google_drive/connector.py
+++ b/backend/danswer/connectors/google_drive/connector.py
@@ -35,25 +35,33 @@
 logger = setup_logger()
 
 
+def _get_string_list_from_comma_separated_string(string: str | None) -> list[str]:
+ return string.split(",") if string else []
+
+
 class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
 def __init__(
 self,
 include_shared_drives: bool = True,
- shared_drive_ids: list[str] | None = None,
+ shared_drive_ids: str | None = None,
 include_my_drives: bool = True,
- my_drive_emails: list[str] | None = None,
- folder_ids: list[str] | None = None,
+ my_drive_emails: str | None = None,
+ folder_ids: str | None = None,
 batch_size: int = INDEX_BATCH_SIZE,
 ) -> None:
 self.batch_size = batch_size
 
 self.include_shared_drives = include_shared_drives
- self.shared_drive_ids = shared_drive_ids or []
+ self.shared_drive_ids = _get_string_list_from_comma_separated_string(
+ shared_drive_ids
+ )
 
 self.include_my_drives = include_my_drives
- self.my_drive_emails = my_drive_emails or []
+ self.my_drive_emails = _get_string_list_from_comma_separated_string(
+ my_drive_emails
+ )
 
- self.folder_ids = folder_ids or []
+ self.folder_ids = _get_string_list_from_comma_separated_string(folder_ids)
 
 self.primary_admin_email: str | None = None
@@ -142,9 +150,17 @@ def _fetch_drive_items(
 # get all personal docs from each users' personal drive
 if self.include_my_drives:
- all_user_emails = self.my_drive_emails
- if not all_user_emails:
- all_user_emails = self._get_all_user_emails()
+ all_user_emails: set[str] = set(self.my_drive_emails or [])
+
+ # If using service account and no emails specified, fetch all users
+ if not all_user_emails and isinstance(
+ self.creds, ServiceAccountCredentials
+ ):
+ all_user_emails = set(self._get_all_user_emails())
+
+ # Always include the primary admin email
+ if self.primary_admin_email:
+ all_user_emails.add(self.primary_admin_email)
 
 for email in all_user_emails:
 logger.info(f"Fetching personal files for user: {email}")
diff --git a/backend/danswer/connectors/google_drive/connector_auth.py b/backend/danswer/connectors/google_drive/connector_auth.py
index cb42f5e09aa..80cbda6772a 100644
--- a/backend/danswer/connectors/google_drive/connector_auth.py
+++ b/backend/danswer/connectors/google_drive/connector_auth.py
@@ -8,7 +8,7 @@
 from google.oauth2.credentials import Credentials as OAuthCredentials  # type: ignore
 from google.oauth2.service_account import Credentials as ServiceAccountCredentials  # type: ignore
 from google_auth_oauthlib.flow import InstalledAppFlow  # type: ignore
-from googleapiclient.discovery import build
+from googleapiclient.discovery import build  # type: ignore
 from sqlalchemy.orm import Session
 
 from danswer.configs.app_configs import WEB_DOMAIN
@@ -85,7 +85,12 @@ def get_google_drive_creds(
 # (e.g. the token has been refreshed)
 new_creds_json_str = oauth_creds.to_json() if oauth_creds else ""
 if new_creds_json_str != access_token_json_str:
- new_creds_dict = {DB_CREDENTIALS_DICT_TOKEN_KEY: new_creds_json_str}
+ new_creds_dict = {
+ DB_CREDENTIALS_DICT_TOKEN_KEY: new_creds_json_str,
+ DB_CREDENTIALS_PRIMARY_ADMIN_KEY: credentials[
+ DB_CREDENTIALS_PRIMARY_ADMIN_KEY
+ ],
+ }
 
 elif KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY in credentials:
 service_account_key_json_str = credentials[KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY]
diff --git a/backend/danswer/server/documents/connector.py b/backend/danswer/server/documents/connector.py
index 154cae2e2b5..60e973c36f3 100644
--- a/backend/danswer/server/documents/connector.py
+++ b/backend/danswer/server/documents/connector.py
@@ -9,7 +9,7 @@
 from fastapi import Request
 from fastapi import Response
 from fastapi import UploadFile
-from google.oauth2.credentials import Credentials
+from google.oauth2.credentials import Credentials  # type: ignore
 from pydantic import BaseModel
 from sqlalchemy.orm import Session
 
diff --git a/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx b/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx
index c2e903e2776..2a1f3303885 100644
--- a/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx
+++ b/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx
@@ -431,6 +431,12 @@ export default function AddConnector({
 setSelectedFiles={setSelectedFiles}
 selectedFiles={selectedFiles}
 connector={connector}
+ currentCredential={
+ currentCredential ||
+ liveGDriveCredential ||
+ liveGmailCredential ||
+ null
+ }
 />
 )}
 
diff --git a/web/src/app/admin/connectors/[connector]/pages/DynamicConnectorCreationForm.tsx b/web/src/app/admin/connectors/[connector]/pages/DynamicConnectorCreationForm.tsx
index 85237df2c7d..32157267a05 
100644 --- a/web/src/app/admin/connectors/[connector]/pages/DynamicConnectorCreationForm.tsx +++ b/web/src/app/admin/connectors/[connector]/pages/DynamicConnectorCreationForm.tsx @@ -13,6 +13,8 @@ import { AdvancedOptionsToggle } from "@/components/AdvancedOptionsToggle"; import { AccessTypeForm } from "@/components/admin/connectors/AccessTypeForm"; import { AccessTypeGroupSelector } from "@/components/admin/connectors/AccessTypeGroupSelector"; import { ConfigurableSources } from "@/lib/types"; +import { Credential } from "@/lib/connectors/credentials"; +import CollapsibleSection from "@/app/admin/assistants/CollapsibleSection"; export interface DynamicConnectionFormProps { config: ConnectionConfiguration; @@ -20,19 +22,44 @@ export interface DynamicConnectionFormProps { setSelectedFiles: Dispatch>; values: any; connector: ConfigurableSources; + currentCredential: Credential | null; } -const DynamicConnectionForm: FC = ({ - config, +interface RenderFieldProps { + field: any; + values: any; + selectedFiles: File[]; + setSelectedFiles: Dispatch>; + connector: ConfigurableSources; + currentCredential: Credential | null; +} + +const RenderField: FC = ({ + field, + values, selectedFiles, setSelectedFiles, - values, connector, + currentCredential, }) => { - const [showAdvancedOptions, setShowAdvancedOptions] = useState(false); + if ( + field.visibleCondition && + !field.visibleCondition(values, currentCredential) + ) { + return null; + } + + const label = + typeof field.label === "function" + ? field.label(currentCredential) + : field.label; + const description = + typeof field.description === "function" + ? field.description(currentCredential) + : field.description; - const renderField = (field: any) => ( -
+ const fieldContent = ( + <> {field.type === "file" ? ( = ({ ) : field.type === "zip" ? ( ) : field.type === "list" ? ( - + ) : field.type === "select" ? ( ) : field.type === "number" ? ( ) : field.type === "checkbox" ? ( ) : ( )} -
+ ); + if ( + field.visibleCondition && + field.visibleCondition(values, currentCredential) + ) { + return ( + + {fieldContent} + + ); + } else { + return
{fieldContent}
; + } +}; + +const DynamicConnectionForm: FC = ({ + config, + selectedFiles, + setSelectedFiles, + values, + connector, + currentCredential, +}) => { + const [showAdvancedOptions, setShowAdvancedOptions] = useState(false); + return ( <>

{config.description}

@@ -97,7 +148,20 @@ const DynamicConnectionForm: FC = ({ name={"name"} /> - {config.values.map((field) => !field.hidden && renderField(field))} + {config.values.map( + (field) => + !field.hidden && ( + + ) + )} @@ -108,7 +172,18 @@ const DynamicConnectionForm: FC = ({ showAdvancedOptions={showAdvancedOptions} setShowAdvancedOptions={setShowAdvancedOptions} /> - {showAdvancedOptions && config.advanced_values.map(renderField)} + {showAdvancedOptions && + config.advanced_values.map((field) => ( + + ))} )} diff --git a/web/src/app/admin/connectors/[connector]/pages/gdrive/Credential.tsx b/web/src/app/admin/connectors/[connector]/pages/gdrive/Credential.tsx index 4134cde4c2a..b45bd9211ab 100644 --- a/web/src/app/admin/connectors/[connector]/pages/gdrive/Credential.tsx +++ b/web/src/app/admin/connectors/[connector]/pages/gdrive/Credential.tsx @@ -309,7 +309,7 @@ interface DriveCredentialSectionProps { connectorExists: boolean; } -export const DriveOAuthSection = ({ +export const DriveAuthSection = ({ googleDrivePublicCredential, googleDriveServiceAccountCredential, serviceAccountKeyData, @@ -367,10 +367,10 @@ export const DriveOAuthSection = ({ ( diff --git a/web/src/app/admin/connectors/[connector]/pages/gdrive/GoogleDrivePage.tsx b/web/src/app/admin/connectors/[connector]/pages/gdrive/GoogleDrivePage.tsx index d8a14db03a1..f7a553f8303 100644 --- a/web/src/app/admin/connectors/[connector]/pages/gdrive/GoogleDrivePage.tsx +++ b/web/src/app/admin/connectors/[connector]/pages/gdrive/GoogleDrivePage.tsx @@ -12,7 +12,7 @@ import { useConnectorCredentialIndexingStatus, } from "@/lib/hooks"; import { Title } from "@tremor/react"; -import { DriveJsonUploadSection, DriveOAuthSection } from "./Credential"; +import { DriveJsonUploadSection, DriveAuthSection } from "./Credential"; import { Credential, GoogleDriveCredentialJson, @@ -135,7 +135,7 @@ const GDriveMain = ({}: {}) => { Step 2: Authenticate with Danswer - | null) => string); name: string; - description?: string; + description?: + | string + | ((currentCredential: Credential | null) => string); query?: string; optional?: boolean; hidden?: boolean; + visibleCondition?: ( + values: any, + currentCredential: Credential | null + ) => boolean; } export interface SelectOption extends Option { @@ -202,23 +209,58 @@ export const connectorConfigs: Record< }, google_drive: { description: "Configure Google Drive connector", - values: [], - advanced_values: [ + values: [ { - type: "list", - query: "Enter the URLs of the shared folders or drives to index:", - label: "Parent URLs To Index", - name: "parent_urls", + type: "checkbox", + description: "Include shared drives?", + label: "Include Shared Drives", + name: "include_shared_drives", + optional: true, + default: true, + }, + { + type: "text", + description: + "Enter a comma separated list of the IDs of the shared drives to index. Leave blank to index all shared drives.", + label: "Shared Drive IDs", + name: "shared_drive_ids", + visibleCondition: (values) => values.include_shared_drives, optional: true, }, { type: "checkbox", - query: - "Include personal drives? (Note: This should only be used if you use permissions sync)", - label: "Include personal", - name: "include_personal", - optional: false, - default: false, + label: (currentCredential) => + currentCredential?.credential_json?.google_drive_tokens + ? "Include My Drive?" + : "Include Everyone's My Drive?", + description: (currentCredential) => + currentCredential?.credential_json?.google_drive_tokens + ? 
"This will let Danswer index everything in your My Drive." + : "This will let Danswer index everything in everyone's My Drives.", + name: "include_my_drives", + optional: true, + default: true, + }, + { + type: "text", + description: + "Enter a comma separated list of the emails of the users whose MyDrive you want to index. Leave blank to index all MyDrives.", + label: "My Drive Emails", + name: "my_drive_emails", + visibleCondition: (values, currentCredential) => + values.include_my_drives && + !currentCredential?.credential_json?.google_drive_tokens, + optional: true, + }, + ], + advanced_values: [ + { + type: "text", + description: + "Enter a comma separated list of the IDs of the folders to index:", + label: "Folder IDs", + name: "folder_ids", + optional: true, }, ], }, From 16c353091b32b223485bc3fc1a55fbef154f343d Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Wed, 30 Oct 2024 11:56:21 -0700 Subject: [PATCH 14/23] npm run dev fixes --- .../pages/ConnectorInput/ListInput.tsx | 16 ++++++++++------ .../pages/DynamicConnectorCreationForm.tsx | 2 +- .../[connector]/pages/gdrive/Credential.tsx | 2 +- .../admin/connectors/ConnectorTitle.tsx | 15 --------------- 4 files changed, 12 insertions(+), 23 deletions(-) diff --git a/web/src/app/admin/connectors/[connector]/pages/ConnectorInput/ListInput.tsx b/web/src/app/admin/connectors/[connector]/pages/ConnectorInput/ListInput.tsx index 956e0c24597..05deec472a6 100644 --- a/web/src/app/admin/connectors/[connector]/pages/ConnectorInput/ListInput.tsx +++ b/web/src/app/admin/connectors/[connector]/pages/ConnectorInput/ListInput.tsx @@ -4,18 +4,22 @@ import { TextArrayField } from "@/components/admin/connectors/Field"; import { useFormikContext } from "formik"; interface ListInputProps { - field: ListOption; + name: string; + label: string | ((credential: any) => string); + description: string | ((credential: any) => string); } -const ListInput: React.FC = ({ field }) => { +const ListInput: React.FC = ({ name, label, description }) => { const { values } = useFormikContext(); return ( ); }; diff --git a/web/src/app/admin/connectors/[connector]/pages/DynamicConnectorCreationForm.tsx b/web/src/app/admin/connectors/[connector]/pages/DynamicConnectorCreationForm.tsx index 32157267a05..dc7f75e0632 100644 --- a/web/src/app/admin/connectors/[connector]/pages/DynamicConnectorCreationForm.tsx +++ b/web/src/app/admin/connectors/[connector]/pages/DynamicConnectorCreationForm.tsx @@ -74,7 +74,7 @@ const RenderField: FC = ({ description={description} /> ) : field.type === "list" ? ( - + ) : field.type === "select" ? (

-            When using a Google Drive Service Account, you must speicify the email
+            When using a Google Drive Service Account, you must specify the email
             of the primary admin that you would like the service account to impersonate.
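
Note on the connectors.tsx changes in the patch above: an Option's label, description, and
visibility can now be functions of the currently selected credential. A minimal sketch of how a
renderer can resolve these (helper names are illustrative, not from the patch; the actual
RenderField body is only partially visible in this diff):

    // Resolve a static-or-callable label/description against the selected credential.
    // resolveText is an illustrative name, not code from the patch.
    const resolveText = (
      text: string | ((cred: Credential<any> | null) => string) | undefined,
      cred: Credential<any> | null
    ): string | undefined => (typeof text === "function" ? text(cred) : text);

    // A field renders only if it is not hidden and either has no visibleCondition
    // or the condition holds for the current form values and credential.
    const shouldRender = (field: Option, values: any, cred: Credential<any> | null) =>
      !field.hidden &&
      (!field.visibleCondition || field.visibleCondition(values, cred));
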
diff --git a/web/src/components/admin/connectors/ConnectorTitle.tsx b/web/src/components/admin/connectors/ConnectorTitle.tsx index 269c72e905f..6e2da252aec 100644 --- a/web/src/components/admin/connectors/ConnectorTitle.tsx +++ b/web/src/components/admin/connectors/ConnectorTitle.tsx @@ -64,21 +64,6 @@ export const ConnectorTitle = ({ "Jira Project URL", typedConnector.connector_specific_config.jira_project_url ); - } else if (connector.source === "google_drive") { - const typedConnector = connector as Connector; - if ( - typedConnector.connector_specific_config?.folder_paths && - typedConnector.connector_specific_config?.folder_paths.length > 0 - ) { - additionalMetadata.set( - "Folders", - typedConnector.connector_specific_config.folder_paths.join(", ") - ); - } - - if (!isPublic && owner) { - additionalMetadata.set("Owner", owner); - } } else if (connector.source === "slack") { const typedConnector = connector as Connector; if ( From ff82ee5e01f5a5adc60bbfe66044cee0a99d3b72 Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Wed, 30 Oct 2024 14:10:55 -0700 Subject: [PATCH 15/23] addressed change requests --- .../connectors/google_drive/connector.py | 84 ++++++++++++------- .../connectors/google_drive/constants.py | 2 +- .../connectors/google_drive/file_retrieval.py | 18 ++++ .../pages/DynamicConnectorCreationForm.tsx | 1 + .../[connector]/pages/gdrive/Credential.tsx | 11 ++- .../pages/gdrive/GoogleDrivePage.tsx | 3 +- web/src/lib/connectors/connectors.tsx | 11 +-- 7 files changed, 91 insertions(+), 39 deletions(-) diff --git a/backend/danswer/connectors/google_drive/connector.py b/backend/danswer/connectors/google_drive/connector.py index 39394456c81..0585904ec60 100644 --- a/backend/danswer/connectors/google_drive/connector.py +++ b/backend/danswer/connectors/google_drive/connector.py @@ -13,6 +13,7 @@ from danswer.connectors.google_drive.connector_auth import get_google_drive_creds from danswer.connectors.google_drive.constants import MISSING_SCOPES_ERROR_STR from danswer.connectors.google_drive.constants import ONYX_SCOPE_INSTRUCTIONS +from danswer.connectors.google_drive.constants import SCOPE_DOC_URL from danswer.connectors.google_drive.constants import SLIM_BATCH_SIZE from danswer.connectors.google_drive.constants import USER_FIELDS from danswer.connectors.google_drive.doc_conversion import ( @@ -35,41 +36,66 @@ logger = setup_logger() -def _get_string_list_from_comma_separated_string(string: str | None) -> list[str]: +def _extract_str_list_from_comma_str(string: str | None) -> list[str]: return string.split(",") if string else [] +def _extract_ids_from_urls(urls: list[str]) -> list[str]: + return [url.split("/")[-1] for url in urls] + + class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector): def __init__( self, include_shared_drives: bool = True, - shared_drive_ids: str | None = None, + shared_drive_urls: str | None = None, include_my_drives: bool = True, my_drive_emails: str | None = None, - folder_ids: str | None = None, + shared_folder_urls: str | None = None, batch_size: int = INDEX_BATCH_SIZE, + # OLD PARAMETERS + folder_paths: list[str] | None = None, + include_shared: bool | None = None, + follow_shortcuts: bool | None = None, + only_org_public: bool | None = None, + continue_on_failure: bool | None = None, ) -> None: + # Check for old input parameters + if ( + folder_paths is not None + or include_shared is not None + or follow_shortcuts is not None + or only_org_public is not None + or continue_on_failure is not None + ): + logger.exception( + "Google 
Drive connector received old input parameters. " + "Please visit the docs for help with the new setup: " + f"{SCOPE_DOC_URL}" + ) + raise ValueError( + "Google Drive connector received old input parameters. " + "Please visit the docs for help with the new setup: " + f"{SCOPE_DOC_URL}" + ) + self.batch_size = batch_size self.include_shared_drives = include_shared_drives - self.shared_drive_ids = _get_string_list_from_comma_separated_string( - shared_drive_ids - ) + shared_drive_urls = _extract_str_list_from_comma_str(shared_drive_urls) + self.shared_drive_ids = _extract_ids_from_urls(shared_drive_urls) self.include_my_drives = include_my_drives - self.my_drive_emails = _get_string_list_from_comma_separated_string( - my_drive_emails - ) + self.my_drive_emails = _extract_str_list_from_comma_str(my_drive_emails) - self.folder_ids = _get_string_list_from_comma_separated_string(folder_ids) + shared_folder_urls = _extract_str_list_from_comma_str(shared_folder_urls) + self.shared_folder_ids = _extract_ids_from_urls(shared_folder_urls) self.primary_admin_email: str | None = None self.google_domain: str | None = None self.creds: OAuthCredentials | ServiceAccountCredentials | None = None - self._TRAVERSED_PARENT_IDS: set[str] = set() - def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None: primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] self.google_domain = primary_admin_email.split("@")[1] @@ -113,26 +139,27 @@ def _fetch_drive_items( start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, ) -> Iterator[GoogleDriveFileType]: - admin_drive_service = self.get_google_resource() + primary_drive_service = self.get_google_resource() if self.include_shared_drives: - shared_drive_ids = self.shared_drive_ids - if not shared_drive_ids: + shared_drive_urls = self.shared_drive_ids + if not shared_drive_urls: # if no parent ids are specified, get all shared drives using the admin account for drive in execute_paginated_retrieval( - retrieval_function=admin_drive_service.drives().list, + retrieval_function=primary_drive_service.drives().list, list_key="drives", useDomainAdminAccess=True, fields="drives(id)", ): - shared_drive_ids.append(drive["id"]) + shared_drive_urls.append(drive["id"]) # crawl all the shared parent ids for files - for shared_drive_id in shared_drive_ids: + for shared_drive_id in shared_drive_urls: for file in get_files_in_shared_drive( - service=admin_drive_service, + service=primary_drive_service, drive_id=shared_drive_id, is_slim=is_slim, + cache_folders=bool(self.folder_ids), start=start, end=end, ): @@ -141,7 +168,7 @@ def _fetch_drive_items( if self.folder_ids: for folder_id in self.folder_ids: yield from crawl_folders_for_files( - service=admin_drive_service, + service=primary_drive_service, parent_id=folder_id, personal_drive=False, start=start, @@ -150,17 +177,16 @@ def _fetch_drive_items( # get all personal docs from each users' personal drive if self.include_my_drives: - all_user_emails: set[str] = set(self.my_drive_emails or []) + if isinstance(self.creds, ServiceAccountCredentials): + all_user_emails = self.my_drive_emails or [] - # If using service account and no emails specified, fetch all users - if not all_user_emails and isinstance( - self.creds, ServiceAccountCredentials - ): - all_user_emails = set(self._get_all_user_emails()) + # If using service account and no emails specified, fetch all users + if not all_user_emails: + all_user_emails = self._get_all_user_emails() - # Always include the primary 
admin email - if self.primary_admin_email: - all_user_emails.add(self.primary_admin_email) + else: + # If using OAuth, only fetch the primary admin email + all_user_emails = self.primary_admin_email or [] for email in all_user_emails: logger.info(f"Fetching personal files for user: {email}") diff --git a/backend/danswer/connectors/google_drive/constants.py b/backend/danswer/connectors/google_drive/constants.py index 1b087f28c5a..cb19151c12d 100644 --- a/backend/danswer/connectors/google_drive/constants.py +++ b/backend/danswer/connectors/google_drive/constants.py @@ -3,7 +3,7 @@ DRIVE_SHORTCUT_TYPE = "application/vnd.google-apps.shortcut" FILE_FIELDS = "nextPageToken, files(mimeType, id, name, permissions, modifiedTime, webViewLink, shortcutDetails, owners)" -SLIM_FILE_FIELDS = "nextPageToken, files(id, permissions(emailAddress, type), permissionIds, webViewLink)" +SLIM_FILE_FIELDS = "nextPageToken, files(id, mimeType, permissions(emailAddress, type), permissionIds, webViewLink)" FOLDER_FIELDS = "nextPageToken, files(id, name, permissions, modifiedTime, webViewLink, shortcutDetails)" USER_FIELDS = "nextPageToken, users(primaryEmail)" diff --git a/backend/danswer/connectors/google_drive/file_retrieval.py b/backend/danswer/connectors/google_drive/file_retrieval.py index db296e49c38..cb911a0cc0c 100644 --- a/backend/danswer/connectors/google_drive/file_retrieval.py +++ b/backend/danswer/connectors/google_drive/file_retrieval.py @@ -122,9 +122,27 @@ def get_files_in_shared_drive( service: Resource, drive_id: str, is_slim: bool = False, + cache_folders: bool = True, start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, ) -> Iterator[GoogleDriveFileType]: + # If we know we are going to folder crawl later, we can cache the folders here + if cache_folders: + # Get all folders being queried and add them to the traversed set + query = f"mimeType == '{DRIVE_FOLDER_TYPE}'" + for file in execute_paginated_retrieval( + retrieval_function=service.files().list, + list_key="files", + corpora="drive", + driveId=drive_id, + supportsAllDrives=True, + includeItemsFromAllDrives=True, + fields="nextPageToken, files(id)", + q=query, + ): + _TRAVERSED_PARENT_IDS.add(file["id"]) + + # Get all files in the shared drive query = f"mimeType != '{DRIVE_FOLDER_TYPE}'" query += _generate_time_range_filter(start, end) for file in execute_paginated_retrieval( diff --git a/web/src/app/admin/connectors/[connector]/pages/DynamicConnectorCreationForm.tsx b/web/src/app/admin/connectors/[connector]/pages/DynamicConnectorCreationForm.tsx index dc7f75e0632..a6ac93441e6 100644 --- a/web/src/app/admin/connectors/[connector]/pages/DynamicConnectorCreationForm.tsx +++ b/web/src/app/admin/connectors/[connector]/pages/DynamicConnectorCreationForm.tsx @@ -104,6 +104,7 @@ const RenderField: FC = ({ type={field.type} label={label} name={field.name} + isTextArea={true} /> )} diff --git a/web/src/app/admin/connectors/[connector]/pages/gdrive/Credential.tsx b/web/src/app/admin/connectors/[connector]/pages/gdrive/Credential.tsx index f8ec588a325..03a73fe23e6 100644 --- a/web/src/app/admin/connectors/[connector]/pages/gdrive/Credential.tsx +++ b/web/src/app/admin/connectors/[connector]/pages/gdrive/Credential.tsx @@ -10,6 +10,7 @@ import { GOOGLE_DRIVE_AUTH_IS_ADMIN_COOKIE_NAME } from "@/lib/constants"; import Cookies from "js-cookie"; import { TextFormField } from "@/components/admin/connectors/Field"; import { Form, Formik } from "formik"; +import { User } from "@/lib/types"; import { Button as TremorButton } 
from "@tremor/react"; import { Credential, @@ -157,6 +158,7 @@ export const DriveJsonUploadSection = ({ isAdmin, }: DriveJsonUploadSectionProps) => { const { mutate } = useSWRConfig(); + const router = useRouter(); if (serviceAccountCredentialData?.service_account_email) { return ( @@ -190,6 +192,7 @@ export const DriveJsonUploadSection = ({ message: "Successfully deleted service account key", type: "success", }); + router.refresh(); } else { const errorMsg = await response.text(); setPopup({ @@ -307,6 +310,7 @@ interface DriveCredentialSectionProps { setPopup: (popupSpec: PopupSpec | null) => void; refreshCredentials: () => void; connectorExists: boolean; + user: User | null; } export const DriveAuthSection = ({ @@ -317,6 +321,7 @@ export const DriveAuthSection = ({ setPopup, refreshCredentials, connectorExists, + user, }: DriveCredentialSectionProps) => { const router = useRouter(); @@ -361,13 +366,13 @@ export const DriveAuthSection = ({ impersonate.

-            Ideally, this account should be the owner of the Google Organization
-            that owns the Google Drive you want to index.
+            Ideally, this account should be an owner/admin of the Google
+            Organization that owns the Google Drive(s) you want to index.
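
For reference, the two connector.py helpers introduced earlier in this patch compose as follows
(a standalone sketch mirroring _extract_str_list_from_comma_str and _extract_ids_from_urls; the
drive URLs are made-up placeholders):

    # Illustrative only: how the comma-separated URL string from the UI becomes IDs.
    # The URLs below are placeholders, not real drives or folders.
    raw = "https://drive.google.com/drive/folders/abc123,https://drive.google.com/drive/folders/def456"
    url_list = raw.split(",") if raw else []        # _extract_str_list_from_comma_str
    ids = [u.split("/")[-1] for u in url_list]      # _extract_ids_from_urls -> ["abc123", "def456"]
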

{ - const { isLoadingUser, isAdmin } = useUser(); + const { isLoadingUser, isAdmin, user } = useUser(); const { data: appCredentialData, @@ -145,6 +145,7 @@ const GDriveMain = ({}: {}) => { appCredentialData={appCredentialData} serviceAccountKeyData={serviceAccountKeyData} connectorExists={googleDriveConnectorIndexingStatuses.length > 0} + user={user} /> )} diff --git a/web/src/lib/connectors/connectors.tsx b/web/src/lib/connectors/connectors.tsx index b668d84323a..eef0d5f5ca9 100644 --- a/web/src/lib/connectors/connectors.tsx +++ b/web/src/lib/connectors/connectors.tsx @@ -212,8 +212,9 @@ export const connectorConfigs: Record< values: [ { type: "checkbox", - description: "Include shared drives?", - label: "Include Shared Drives", + label: "Include shared drives?", + description: + "This will allow Danswer to index everything in your shared drives.", name: "include_shared_drives", optional: true, default: true, @@ -223,7 +224,7 @@ export const connectorConfigs: Record< description: "Enter a comma separated list of the IDs of the shared drives to index. Leave blank to index all shared drives.", label: "Shared Drive IDs", - name: "shared_drive_ids", + name: "shared_drive_urls", visibleCondition: (values) => values.include_shared_drives, optional: true, }, @@ -235,8 +236,8 @@ export const connectorConfigs: Record< : "Include Everyone's My Drive?", description: (currentCredential) => currentCredential?.credential_json?.google_drive_tokens - ? "This will let Danswer index everything in your My Drive." - : "This will let Danswer index everything in everyone's My Drives.", + ? "This will allow Danswer to index everything in your My Drive." + : "This will allow Danswer to index everything in everyone's My Drives.", name: "include_my_drives", optional: true, default: true, From 7f0c251e18acc9855bb65a01331aa62a265497bc Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Wed, 30 Oct 2024 14:56:08 -0700 Subject: [PATCH 16/23] string fix --- web/src/lib/connectors/connectors.tsx | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/web/src/lib/connectors/connectors.tsx b/web/src/lib/connectors/connectors.tsx index eef0d5f5ca9..333642de820 100644 --- a/web/src/lib/connectors/connectors.tsx +++ b/web/src/lib/connectors/connectors.tsx @@ -222,8 +222,8 @@ export const connectorConfigs: Record< { type: "text", description: - "Enter a comma separated list of the IDs of the shared drives to index. Leave blank to index all shared drives.", - label: "Shared Drive IDs", + "Enter a comma separated list of the URLs of the shared drives to index. Leave blank to index all shared drives.", + label: "Shared Drive URLs", name: "shared_drive_urls", visibleCondition: (values) => values.include_shared_drives, optional: true, @@ -258,9 +258,9 @@ export const connectorConfigs: Record< { type: "text", description: - "Enter a comma separated list of the IDs of the folders to index:", - label: "Folder IDs", - name: "folder_ids", + "Enter a comma separated list of the URLs of the folders located in Shared Drives to index. 
The files located in these folders (and all subfolders) will be indexed.", + label: "Folder URLs", + name: "shared_folder_urls", optional: true, }, ], From f85a32bf83cfea23f16d4ce570843bb3bb1ff4f8 Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Wed, 30 Oct 2024 16:23:17 -0700 Subject: [PATCH 17/23] minor fixes and cleanup --- .../connectors/google_drive/connector.py | 10 ++++++---- .../connectors/google_drive/constants.py | 1 + .../connectors/google_drive/file_retrieval.py | 2 +- .../admin/connectors/AccessTypeForm.tsx | 18 +++++++++++++++++- 4 files changed, 25 insertions(+), 6 deletions(-) diff --git a/backend/danswer/connectors/google_drive/connector.py b/backend/danswer/connectors/google_drive/connector.py index 0585904ec60..31b8b1cabb6 100644 --- a/backend/danswer/connectors/google_drive/connector.py +++ b/backend/danswer/connectors/google_drive/connector.py @@ -37,7 +37,9 @@ def _extract_str_list_from_comma_str(string: str | None) -> list[str]: - return string.split(",") if string else [] + if not string: + return [] + return [s.strip() for s in string.split(",") if s.strip()] def _extract_ids_from_urls(urls: list[str]) -> list[str]: @@ -159,14 +161,14 @@ def _fetch_drive_items( service=primary_drive_service, drive_id=shared_drive_id, is_slim=is_slim, - cache_folders=bool(self.folder_ids), + cache_folders=bool(self.shared_folder_ids), start=start, end=end, ): yield file - if self.folder_ids: - for folder_id in self.folder_ids: + if self.shared_folder_ids: + for folder_id in self.shared_folder_ids: yield from crawl_folders_for_files( service=primary_drive_service, parent_id=folder_id, diff --git a/backend/danswer/connectors/google_drive/constants.py b/backend/danswer/connectors/google_drive/constants.py index cb19151c12d..66b4ebf59b9 100644 --- a/backend/danswer/connectors/google_drive/constants.py +++ b/backend/danswer/connectors/google_drive/constants.py @@ -1,6 +1,7 @@ UNSUPPORTED_FILE_TYPE_CONTENT = "" # keep empty for now DRIVE_FOLDER_TYPE = "application/vnd.google-apps.folder" DRIVE_SHORTCUT_TYPE = "application/vnd.google-apps.shortcut" +DRIVE_FILE_TYPE = "application/vnd.google-apps.file" FILE_FIELDS = "nextPageToken, files(mimeType, id, name, permissions, modifiedTime, webViewLink, shortcutDetails, owners)" SLIM_FILE_FIELDS = "nextPageToken, files(id, mimeType, permissions(emailAddress, type), permissionIds, webViewLink)" diff --git a/backend/danswer/connectors/google_drive/file_retrieval.py b/backend/danswer/connectors/google_drive/file_retrieval.py index cb911a0cc0c..d99574ea8fd 100644 --- a/backend/danswer/connectors/google_drive/file_retrieval.py +++ b/backend/danswer/connectors/google_drive/file_retrieval.py @@ -129,7 +129,7 @@ def get_files_in_shared_drive( # If we know we are going to folder crawl later, we can cache the folders here if cache_folders: # Get all folders being queried and add them to the traversed set - query = f"mimeType == '{DRIVE_FOLDER_TYPE}'" + query = f"mimeType = '{DRIVE_FOLDER_TYPE}'" for file in execute_paginated_retrieval( retrieval_function=service.files().list, list_key="files", diff --git a/web/src/components/admin/connectors/AccessTypeForm.tsx b/web/src/components/admin/connectors/AccessTypeForm.tsx index 1d3e93774d1..bbf7a2f4501 100644 --- a/web/src/components/admin/connectors/AccessTypeForm.tsx +++ b/web/src/components/admin/connectors/AccessTypeForm.tsx @@ -9,6 +9,7 @@ import { useUser } from "@/components/user/UserProvider"; import { useField } from "formik"; import { AutoSyncOptions } from "./AutoSyncOptions"; import { 
usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled"; +import { useEffect } from "react"; function isValidAutoSyncSource( value: ConfigurableSources @@ -28,6 +29,21 @@ export function AccessTypeForm({ const isAutoSyncSupported = isValidAutoSyncSource(connector); const { isLoadingUser, isAdmin } = useUser(); + useEffect(() => { + if (!isPaidEnterpriseEnabled) { + access_type_helpers.setValue("public"); + } else if (isAutoSyncSupported) { + access_type_helpers.setValue("sync"); + } else { + access_type_helpers.setValue("private"); + } + }, [ + isAutoSyncSupported, + isAdmin, + isPaidEnterpriseEnabled, + access_type_helpers, + ]); + const options = [ { name: "Private", @@ -46,7 +62,7 @@ export function AccessTypeForm({ }); } - if (isAutoSyncSupported && isAdmin) { + if (isAutoSyncSupported && isAdmin && isPaidEnterpriseEnabled) { options.push({ name: "Auto Sync Permissions", value: "sync", From f69df8a7b82b7c528c74c09ed4a7699d37fd4ba3 Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Wed, 30 Oct 2024 16:26:07 -0700 Subject: [PATCH 18/23] spacing cleanup --- .../admin/connectors/AccessTypeForm.tsx | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/web/src/components/admin/connectors/AccessTypeForm.tsx b/web/src/components/admin/connectors/AccessTypeForm.tsx index bbf7a2f4501..8993e28cdb3 100644 --- a/web/src/components/admin/connectors/AccessTypeForm.tsx +++ b/web/src/components/admin/connectors/AccessTypeForm.tsx @@ -75,12 +75,13 @@ export function AccessTypeForm({ <> {isPaidEnterpriseEnabled && isAdmin && ( <> -
+
+

+ Control who has access to the documents indexed by this connector. +

-

- Control who has access to the documents indexed by this connector. -

+ {access_type.value === "sync" && isAutoSyncSupported && ( -
- -
+ )} )} From 8f4f214b4427f230a9c1233a4ba61835872aa9c7 Mon Sep 17 00:00:00 2001 From: Chris Weaver <25087905+Weves@users.noreply.github.com> Date: Wed, 30 Oct 2024 16:33:33 -0700 Subject: [PATCH 19/23] Update connector.py --- backend/danswer/server/documents/connector.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/backend/danswer/server/documents/connector.py b/backend/danswer/server/documents/connector.py index 60e973c36f3..b0866a826c1 100644 --- a/backend/danswer/server/documents/connector.py +++ b/backend/danswer/server/documents/connector.py @@ -358,20 +358,6 @@ def check_drive_tokens( return AuthStatus(authenticated=True) -# @router.get("/admin/connector/google-drive/authorize/{credential_id}") -# def admin_google_drive_auth( -# response: Response, credential_id: str, _: User = Depends(current_admin_user) -# ) -> AuthUrl: -# # set a cookie that we can read in the callback (used for `verify_csrf`) -# response.set_cookie( -# key=_GOOGLE_DRIVE_CREDENTIAL_ID_COOKIE_NAME, -# value=credential_id, -# httponly=True, -# max_age=600, -# ) -# return AuthUrl(auth_url=get_auth_url(credential_id=int(credential_id))) - - @router.post("/admin/connector/file/upload") def upload_files( files: list[UploadFile], From 5dd71ac48a7691df8f4bc14ad9930578db902485 Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Thu, 31 Oct 2024 15:42:28 -0700 Subject: [PATCH 20/23] everything done --- .../connectors/google_drive/connector.py | 36 ++- .../connectors/google_drive/constants.py | 10 +- .../connectors/google_drive/file_retrieval.py | 24 +- .../google_drive/doc_sync.py | 74 +++--- .../daily/connectors/google_drive/conftest.py | 98 +++++++ .../connectors/google_drive/file_generator.py | 11 + .../daily/connectors/google_drive/helpers.py | 172 +++++++++++++ .../google_drive/test_google_drive_oauth.py | 231 +++++++++++++++++ .../test_google_drive_service_acct.py | 241 ++++++++++++++++++ .../test_google_drive_slim_docs.py | 156 ++++++++++++ web/src/lib/connectors/connectors.tsx | 2 +- 11 files changed, 1001 insertions(+), 54 deletions(-) create mode 100644 backend/tests/daily/connectors/google_drive/conftest.py create mode 100644 backend/tests/daily/connectors/google_drive/file_generator.py create mode 100644 backend/tests/daily/connectors/google_drive/helpers.py create mode 100644 backend/tests/daily/connectors/google_drive/test_google_drive_oauth.py create mode 100644 backend/tests/daily/connectors/google_drive/test_google_drive_service_acct.py create mode 100644 backend/tests/daily/connectors/google_drive/test_google_drive_slim_docs.py diff --git a/backend/danswer/connectors/google_drive/connector.py b/backend/danswer/connectors/google_drive/connector.py index 31b8b1cabb6..4ddd51f749f 100644 --- a/backend/danswer/connectors/google_drive/connector.py +++ b/backend/danswer/connectors/google_drive/connector.py @@ -81,23 +81,38 @@ def __init__( f"{SCOPE_DOC_URL}" ) + if ( + not include_shared_drives + and not include_my_drives + and not shared_folder_urls + ): + raise ValueError( + "At least one of include_shared_drives, include_my_drives," + " or shared_folder_urls must be true" + ) + self.batch_size = batch_size self.include_shared_drives = include_shared_drives - shared_drive_urls = _extract_str_list_from_comma_str(shared_drive_urls) - self.shared_drive_ids = _extract_ids_from_urls(shared_drive_urls) + shared_drive_url_list = _extract_str_list_from_comma_str(shared_drive_urls) + self.shared_drive_ids = _extract_ids_from_urls(shared_drive_url_list) self.include_my_drives = include_my_drives 
self.my_drive_emails = _extract_str_list_from_comma_str(my_drive_emails) - shared_folder_urls = _extract_str_list_from_comma_str(shared_folder_urls) - self.shared_folder_ids = _extract_ids_from_urls(shared_folder_urls) + shared_folder_url_list = _extract_str_list_from_comma_str(shared_folder_urls) + self.shared_folder_ids = _extract_ids_from_urls(shared_folder_url_list) self.primary_admin_email: str | None = None self.google_domain: str | None = None self.creds: OAuthCredentials | ServiceAccountCredentials | None = None + self._TRAVERSED_PARENT_IDS: set[str] = set() + + def _update_traversed_parent_ids(self, folder_id: str) -> None: + self._TRAVERSED_PARENT_IDS.add(folder_id) + def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None: primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] self.google_domain = primary_admin_email.split("@")[1] @@ -155,28 +170,33 @@ def _fetch_drive_items( ): shared_drive_urls.append(drive["id"]) - # crawl all the shared parent ids for files + # For each shared drive, retrieve all files for shared_drive_id in shared_drive_urls: for file in get_files_in_shared_drive( service=primary_drive_service, drive_id=shared_drive_id, is_slim=is_slim, cache_folders=bool(self.shared_folder_ids), + update_traversed_ids_func=self._update_traversed_parent_ids, start=start, end=end, ): yield file if self.shared_folder_ids: + # Crawl all the shared parent ids for files for folder_id in self.shared_folder_ids: yield from crawl_folders_for_files( service=primary_drive_service, parent_id=folder_id, personal_drive=False, + traversed_parent_ids=self._TRAVERSED_PARENT_IDS, + update_traversed_ids_func=self._update_traversed_parent_ids, start=start, end=end, ) + all_user_emails = [] # get all personal docs from each users' personal drive if self.include_my_drives: if isinstance(self.creds, ServiceAccountCredentials): @@ -186,9 +206,9 @@ def _fetch_drive_items( if not all_user_emails: all_user_emails = self._get_all_user_emails() - else: + elif self.primary_admin_email: # If using OAuth, only fetch the primary admin email - all_user_emails = self.primary_admin_email or [] + all_user_emails = [self.primary_admin_email] for email in all_user_emails: logger.info(f"Fetching personal files for user: {email}") @@ -262,6 +282,8 @@ def _extract_slim_docs_from_google_drive( "doc_id": file.get("id"), "permissions": file.get("permissions", []), "permission_ids": file.get("permissionIds", []), + "name": file.get("name"), + "owner_email": file.get("owners", [{}])[0].get("emailAddress"), }, ) ) diff --git a/backend/danswer/connectors/google_drive/constants.py b/backend/danswer/connectors/google_drive/constants.py index 66b4ebf59b9..848a21fffe6 100644 --- a/backend/danswer/connectors/google_drive/constants.py +++ b/backend/danswer/connectors/google_drive/constants.py @@ -3,8 +3,14 @@ DRIVE_SHORTCUT_TYPE = "application/vnd.google-apps.shortcut" DRIVE_FILE_TYPE = "application/vnd.google-apps.file" -FILE_FIELDS = "nextPageToken, files(mimeType, id, name, permissions, modifiedTime, webViewLink, shortcutDetails, owners)" -SLIM_FILE_FIELDS = "nextPageToken, files(id, mimeType, permissions(emailAddress, type), permissionIds, webViewLink)" +FILE_FIELDS = ( + "nextPageToken, files(mimeType, id, name, permissions, modifiedTime, webViewLink, " + "shortcutDetails, owners(emailAddress))" +) +SLIM_FILE_FIELDS = ( + "nextPageToken, files(mimeType, id, name, permissions(emailAddress, type), " + "permissionIds, webViewLink, owners(emailAddress))" +) FOLDER_FIELDS = "nextPageToken, 
files(id, name, permissions, modifiedTime, webViewLink, shortcutDetails)" USER_FIELDS = "nextPageToken, users(primaryEmail)" diff --git a/backend/danswer/connectors/google_drive/file_retrieval.py b/backend/danswer/connectors/google_drive/file_retrieval.py index d99574ea8fd..ea4e7d49466 100644 --- a/backend/danswer/connectors/google_drive/file_retrieval.py +++ b/backend/danswer/connectors/google_drive/file_retrieval.py @@ -1,3 +1,4 @@ +from collections.abc import Callable from collections.abc import Iterator from datetime import datetime @@ -37,6 +38,7 @@ def _get_folders_in_parent( ) -> Iterator[GoogleDriveFileType]: # Follow shortcuts to folders query = f"(mimeType = '{DRIVE_FOLDER_TYPE}' or mimeType = '{DRIVE_SHORTCUT_TYPE}')" + query += " and trashed = false" if parent_id: query += f" and '{parent_id}' in parents" @@ -62,6 +64,7 @@ def _get_files_in_parent( is_slim: bool = False, ) -> Iterator[GoogleDriveFileType]: query = f"mimeType != '{DRIVE_FOLDER_TYPE}' and '{parent_id}' in parents" + query += " and trashed = false" query += _generate_time_range_filter(start, end) for file in execute_paginated_retrieval( @@ -76,24 +79,23 @@ def _get_files_in_parent( yield file -_TRAVERSED_PARENT_IDS: set[str] = set() - - def crawl_folders_for_files( service: Resource, parent_id: str, personal_drive: bool, + traversed_parent_ids: set[str], + update_traversed_ids_func: Callable[[str], None], start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, ) -> Iterator[GoogleDriveFileType]: """ - This one can start crawling from any folder. It is slower though. + This function starts crawling from any folder. It is slower though. """ - if parent_id in _TRAVERSED_PARENT_IDS: - logger.debug(f"Skipping subfolder since already traversed: {parent_id}") + if parent_id in traversed_parent_ids: + print(f"Skipping subfolder since already traversed: {parent_id}") return - _TRAVERSED_PARENT_IDS.add(parent_id) + update_traversed_ids_func(parent_id) yield from _get_files_in_parent( service=service, @@ -113,6 +115,8 @@ def crawl_folders_for_files( service=service, parent_id=subfolder["id"], personal_drive=personal_drive, + traversed_parent_ids=traversed_parent_ids, + update_traversed_ids_func=update_traversed_ids_func, start=start, end=end, ) @@ -123,6 +127,7 @@ def get_files_in_shared_drive( drive_id: str, is_slim: bool = False, cache_folders: bool = True, + update_traversed_ids_func: Callable[[str], None] = lambda _: None, start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, ) -> Iterator[GoogleDriveFileType]: @@ -130,6 +135,7 @@ def get_files_in_shared_drive( if cache_folders: # Get all folders being queried and add them to the traversed set query = f"mimeType = '{DRIVE_FOLDER_TYPE}'" + query += " and trashed = false" for file in execute_paginated_retrieval( retrieval_function=service.files().list, list_key="files", @@ -140,10 +146,11 @@ def get_files_in_shared_drive( fields="nextPageToken, files(id)", q=query, ): - _TRAVERSED_PARENT_IDS.add(file["id"]) + update_traversed_ids_func(file["id"]) # Get all files in the shared drive query = f"mimeType != '{DRIVE_FOLDER_TYPE}'" + query += " and trashed = false" query += _generate_time_range_filter(start, end) for file in execute_paginated_retrieval( retrieval_function=service.files().list, @@ -166,6 +173,7 @@ def get_files_in_my_drive( end: SecondsSinceUnixEpoch | None = None, ) -> Iterator[GoogleDriveFileType]: query = f"mimeType != '{DRIVE_FOLDER_TYPE}' and '{email}' in owners" + query += " and trashed = false" 
query += _generate_time_range_filter(start, end) for file in execute_paginated_retrieval( retrieval_function=service.files().list, diff --git a/backend/ee/danswer/external_permissions/google_drive/doc_sync.py b/backend/ee/danswer/external_permissions/google_drive/doc_sync.py index 2b2a83064a7..d1df0cb0846 100644 --- a/backend/ee/danswer/external_permissions/google_drive/doc_sync.py +++ b/backend/ee/danswer/external_permissions/google_drive/doc_sync.py @@ -2,12 +2,12 @@ from datetime import timezone from typing import Any -from googleapiclient.discovery import Resource # type: ignore from sqlalchemy.orm import Session from danswer.access.models import ExternalAccess from danswer.connectors.google_drive.connector import GoogleDriveConnector from danswer.connectors.google_drive.google_utils import execute_paginated_retrieval +from danswer.connectors.interfaces import GenerateSlimDocumentOutput from danswer.connectors.models import SlimDocument from danswer.db.models import ConnectorCredentialPair from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit @@ -19,10 +19,10 @@ _PERMISSION_ID_PERMISSION_MAP: dict[str, dict[str, Any]] = {} -def _get_slim_docs( +def _get_slim_doc_generator( cc_pair: ConnectorCredentialPair, google_drive_connector: GoogleDriveConnector, -) -> list[SlimDocument]: +) -> GenerateSlimDocumentOutput: current_time = datetime.now(timezone.utc) start_time = ( cc_pair.last_time_perm_sync.replace(tzinfo=timezone.utc).timestamp() @@ -30,19 +30,20 @@ def _get_slim_docs( else 0.0 ) - doc_batch_generator = google_drive_connector.retrieve_all_slim_documents( + return google_drive_connector.retrieve_all_slim_documents( start=start_time, end=current_time.timestamp() ) - slim_docs = [doc for doc_batch in doc_batch_generator for doc in doc_batch] - - return slim_docs def _fetch_permissions_for_permission_ids( - admin_service: Resource, - doc_id: str, + google_drive_connector: GoogleDriveConnector, permission_ids: list[str], + permission_info: dict[str, Any], ) -> list[dict[str, Any]]: + doc_id = permission_info.get("doc_id") + if not permission_info or not doc_id: + return [] + # Check cache first for all permission IDs permissions = [ _PERMISSION_ID_PERMISSION_MAP[pid] @@ -54,9 +55,12 @@ def _fetch_permissions_for_permission_ids( if len(permissions) == len(permission_ids): return permissions + owner_email = permission_info.get("owner_email") + drive_service = google_drive_connector.get_google_resource(user_email=owner_email) + # Otherwise, fetch all permissions and update cache fetched_permissions = execute_paginated_retrieval( - retrieval_function=admin_service.permissions().list, + retrieval_function=drive_service.permissions().list, list_key="permissions", fileId=doc_id, fields="permissions(id, emailAddress, type, domain)", @@ -72,22 +76,19 @@ def _fetch_permissions_for_permission_ids( return permissions_for_doc_id -def _fetch_google_permissions_for_slim_doc( - db_session: Session, - admin_service: Resource, +def _get_permissions_from_slim_doc( + google_drive_connector: GoogleDriveConnector, slim_doc: SlimDocument, - company_domain: str | None, ) -> ExternalAccess: permission_info = slim_doc.perm_sync_data or {} permissions_list = permission_info.get("permissions", []) - doc_id = permission_info.get("doc_id") if not permissions_list: - if permission_ids := permission_info.get("permission_ids") and doc_id: + if permission_ids := permission_info.get("permission_ids"): permissions_list = _fetch_permissions_for_permission_ids( - admin_service=admin_service, - 
doc_id=doc_id, + google_drive_connector=google_drive_connector, permission_ids=permission_ids, + permission_info=permission_info, ) if not permissions_list: logger.warning(f"No permissions found for document {slim_doc.id}") @@ -97,6 +98,7 @@ def _fetch_google_permissions_for_slim_doc( is_public=False, ) + company_domain = google_drive_connector.google_domain user_emails: set[str] = set() group_emails: set[str] = set() public = False @@ -112,8 +114,6 @@ def _fetch_google_permissions_for_slim_doc( elif permission_type == "anyone": public = True - batch_add_non_web_user_if_not_exists__no_commit(db_session, list(user_emails)) - return ExternalAccess( external_user_emails=user_emails, external_user_group_ids=group_emails, @@ -136,19 +136,21 @@ def gdrive_doc_sync( ) google_drive_connector.load_credentials(cc_pair.credential.credential_json) - slim_docs = _get_slim_docs(cc_pair, google_drive_connector) - admin_service = google_drive_connector.get_google_resource() - - for slim_doc in slim_docs: - ext_access = _fetch_google_permissions_for_slim_doc( - db_session=db_session, - admin_service=admin_service, - slim_doc=slim_doc, - company_domain=google_drive_connector.google_domain, - ) - upsert_document_external_perms__no_commit( - db_session=db_session, - doc_id=slim_doc.id, - external_access=ext_access, - source_type=cc_pair.connector.source, - ) + slim_doc_generator = _get_slim_doc_generator(cc_pair, google_drive_connector) + + for slim_doc_batch in slim_doc_generator: + for slim_doc in slim_doc_batch: + ext_access = _get_permissions_from_slim_doc( + google_drive_connector=google_drive_connector, + slim_doc=slim_doc, + ) + batch_add_non_web_user_if_not_exists__no_commit( + db_session=db_session, + emails=list(ext_access.external_user_emails), + ) + upsert_document_external_perms__no_commit( + db_session=db_session, + doc_id=slim_doc.id, + external_access=ext_access, + source_type=cc_pair.connector.source, + ) diff --git a/backend/tests/daily/connectors/google_drive/conftest.py b/backend/tests/daily/connectors/google_drive/conftest.py new file mode 100644 index 00000000000..0b516d0359c --- /dev/null +++ b/backend/tests/daily/connectors/google_drive/conftest.py @@ -0,0 +1,98 @@ +import json +import os +from collections.abc import Callable + +import pytest + +from danswer.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY +from danswer.connectors.google_drive.connector import GoogleDriveConnector +from danswer.connectors.google_drive.connector_auth import DB_CREDENTIALS_DICT_TOKEN_KEY +from danswer.connectors.google_drive.connector_auth import ( + DB_CREDENTIALS_PRIMARY_ADMIN_KEY, +) + + +def load_env_vars(env_file: str = ".env") -> None: + current_dir = os.path.dirname(os.path.abspath(__file__)) + env_path = os.path.join(current_dir, env_file) + try: + with open(env_path, "r") as f: + for line in f: + line = line.strip() + if line and not line.startswith("#"): + key, value = line.split("=", 1) + os.environ[key] = value.strip() + print("Successfully loaded environment variables") + except FileNotFoundError: + print(f"File {env_file} not found") + + +# Load environment variables at the module level +load_env_vars() + + +@pytest.fixture +def google_drive_oauth_connector_factory() -> Callable[..., GoogleDriveConnector]: + def _connector_factory( + primary_admin_email: str = "admin@onyx-test.com", + include_shared_drives: bool = True, + shared_drive_urls: str | None = None, + include_my_drives: bool = True, + my_drive_emails: str | None = None, + shared_folder_urls: str | None = None, + ) -> 
GoogleDriveConnector: + connector = GoogleDriveConnector( + include_shared_drives=include_shared_drives, + shared_drive_urls=shared_drive_urls, + include_my_drives=include_my_drives, + my_drive_emails=my_drive_emails, + shared_folder_urls=shared_folder_urls, + ) + + json_string = os.environ["GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR"] + refried_json_string = json.loads(json_string) + + credentials_json = { + DB_CREDENTIALS_DICT_TOKEN_KEY: refried_json_string, + DB_CREDENTIALS_PRIMARY_ADMIN_KEY: primary_admin_email, + } + connector.load_credentials(credentials_json) + return connector + + return _connector_factory + + +@pytest.fixture +def google_drive_service_acct_connector_factory() -> ( + Callable[..., GoogleDriveConnector] +): + def _connector_factory( + primary_admin_email: str = "admin@onyx-test.com", + include_shared_drives: bool = True, + shared_drive_urls: str | None = None, + include_my_drives: bool = True, + my_drive_emails: str | None = None, + shared_folder_urls: str | None = None, + ) -> GoogleDriveConnector: + print("Creating GoogleDriveConnector with service account credentials") + connector = GoogleDriveConnector( + include_shared_drives=include_shared_drives, + shared_drive_urls=shared_drive_urls, + include_my_drives=include_my_drives, + my_drive_emails=my_drive_emails, + shared_folder_urls=shared_folder_urls, + ) + + json_string = os.environ["GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR"] + refried_json_string = json.loads(json_string) + + # Load Service Account Credentials + connector.load_credentials( + { + KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY: refried_json_string, + DB_CREDENTIALS_PRIMARY_ADMIN_KEY: primary_admin_email, + } + ) + return connector + + return _connector_factory diff --git a/backend/tests/daily/connectors/google_drive/file_generator.py b/backend/tests/daily/connectors/google_drive/file_generator.py new file mode 100644 index 00000000000..14264900047 --- /dev/null +++ b/backend/tests/daily/connectors/google_drive/file_generator.py @@ -0,0 +1,11 @@ +import os + +documents_folder = os.path.expanduser("~/Documents") +for i in range(0, 60): + file_name = f"file_{i}.txt" + file_text = f"This is file {i}" + file_path = os.path.join(documents_folder, file_name) + if os.path.exists(file_path): + os.remove(file_path) + with open(file_path, "w") as file: + file.write(file_text) diff --git a/backend/tests/daily/connectors/google_drive/helpers.py b/backend/tests/daily/connectors/google_drive/helpers.py new file mode 100644 index 00000000000..c78dc446798 --- /dev/null +++ b/backend/tests/daily/connectors/google_drive/helpers.py @@ -0,0 +1,172 @@ +from collections.abc import Sequence + +from danswer.connectors.models import Document + +_ADMIN_FILE_RANGE = list(range(0, 5)) +_TEST_USER_1_FILE_RANGE = list(range(5, 10)) +_TEST_USER_2_FILE_RANGE = list(range(10, 15)) +_TEST_USER_3_FILE_RANGE = list(range(15, 20)) +_SHARED_DRIVE_1_FILE_RANGE = list(range(20, 25)) +_FOLDER_1_FILE_RANGE = list(range(25, 30)) +_FOLDER_1_1_FILE_RANGE = list(range(30, 35)) +_FOLDER_1_2_FILE_RANGE = list(range(35, 40)) +_SHARED_DRIVE_2_FILE_RANGE = list(range(40, 45)) +_FOLDER_2_FILE_RANGE = list(range(45, 50)) +_FOLDER_2_1_FILE_RANGE = list(range(50, 55)) +_FOLDER_2_2_FILE_RANGE = list(range(55, 60)) + +_PUBLIC_FOLDER_RANGE = _FOLDER_1_2_FILE_RANGE +_PUBLIC_FILE_RANGE = list(range(55, 57)) +PUBLIC_RANGE = _PUBLIC_FOLDER_RANGE + _PUBLIC_FILE_RANGE + +_SHARED_DRIVE_1_URL = "https://drive.google.com/drive/folders/0AC_OJ4BkMd4kUk9PVA" +# Group 1 is given access to this folder +_FOLDER_1_URL = ( + 
"https://drive.google.com/drive/folders/1d3I7U3vUZMDziF1OQqYRkB8Jp2s_GWUn" +) +_FOLDER_1_1_URL = ( + "https://drive.google.com/drive/folders/1aR33-zwzl_mnRAwH55GgtWTE-4A4yWWI" +) +_FOLDER_1_2_URL = ( + "https://drive.google.com/drive/folders/1IO0X55VhvLXf4mdxzHxuKf4wxrDBB6jq" +) +_SHARED_DRIVE_2_URL = "https://drive.google.com/drive/folders/0ABKspIh7P4f4Uk9PVA" +_FOLDER_2_URL = ( + "https://drive.google.com/drive/folders/1lNpCJ1teu8Se0louwL0oOHK9nEalskof" +) +_FOLDER_2_1_URL = ( + "https://drive.google.com/drive/folders/1XeDOMWwxTDiVr9Ig2gKum3Zq_Wivv6zY" +) +_FOLDER_2_2_URL = ( + "https://drive.google.com/drive/folders/1RKlsexA8h7NHvBAWRbU27MJotic7KXe3" +) + +_ADMIN_EMAIL = "admin@onyx-test.com" +_TEST_USER_1_EMAIL = "test_user_1@onyx-test.com" +_TEST_USER_2_EMAIL = "test_user_2@onyx-test.com" +_TEST_USER_3_EMAIL = "test_user_3@onyx-test.com" + +# All users have access to their own My Drive +DRIVE_MAPPING = { + "ADMIN": { + "range": _ADMIN_FILE_RANGE, + "email": _ADMIN_EMAIL, + # Admin has access to everything in shared + "access": ( + _ADMIN_FILE_RANGE + + _SHARED_DRIVE_1_FILE_RANGE + + _FOLDER_1_FILE_RANGE + + _FOLDER_1_1_FILE_RANGE + + _FOLDER_1_2_FILE_RANGE + + _SHARED_DRIVE_2_FILE_RANGE + + _FOLDER_2_FILE_RANGE + + _FOLDER_2_1_FILE_RANGE + + _FOLDER_2_2_FILE_RANGE + ), + }, + "TEST_USER_1": { + "range": _TEST_USER_1_FILE_RANGE, + "email": _TEST_USER_1_EMAIL, + # This user has access to drive 1 + # This user has redundant access to folder 1 because of group access + # This user has been given individual access to files in Admin's My Drive + "access": ( + _TEST_USER_1_FILE_RANGE + + _SHARED_DRIVE_1_FILE_RANGE + + _FOLDER_1_FILE_RANGE + + _FOLDER_1_1_FILE_RANGE + + _FOLDER_1_2_FILE_RANGE + + list(range(0, 2)) + ), + }, + "TEST_USER_2": { + "range": _TEST_USER_2_FILE_RANGE, + "email": _TEST_USER_2_EMAIL, + # Group 1 includes this user, giving access to folder 1 + # This user has also been given access to folder 2-1 + # This user has also been given individual access to files in folder 2 + "access": ( + _TEST_USER_2_FILE_RANGE + + _FOLDER_1_FILE_RANGE + + _FOLDER_1_1_FILE_RANGE + + _FOLDER_1_2_FILE_RANGE + + _FOLDER_2_1_FILE_RANGE + + list(range(45, 47)) + ), + }, + "TEST_USER_3": { + "range": _TEST_USER_3_FILE_RANGE, + "email": _TEST_USER_3_EMAIL, + # This user can only see his own files and public files + "access": (_TEST_USER_3_FILE_RANGE), + }, + "SHARED_DRIVE_1": {"range": _SHARED_DRIVE_1_FILE_RANGE, "url": _SHARED_DRIVE_1_URL}, + "FOLDER_1": {"range": _FOLDER_1_FILE_RANGE, "url": _FOLDER_1_URL}, + "FOLDER_1_1": {"range": _FOLDER_1_1_FILE_RANGE, "url": _FOLDER_1_1_URL}, + "FOLDER_1_2": {"range": _FOLDER_1_2_FILE_RANGE, "url": _FOLDER_1_2_URL}, + "SHARED_DRIVE_2": {"range": _SHARED_DRIVE_2_FILE_RANGE, "url": _SHARED_DRIVE_2_URL}, + "FOLDER_2": {"range": _FOLDER_2_FILE_RANGE, "url": _FOLDER_2_URL}, + "FOLDER_2_1": {"range": _FOLDER_2_1_FILE_RANGE, "url": _FOLDER_2_1_URL}, + "FOLDER_2_2": {"range": _FOLDER_2_2_FILE_RANGE, "url": _FOLDER_2_2_URL}, +} + + +file_name_template = "file_{}.txt" +file_text_template = "This is file {}" + + +def get_expected_file_names_and_texts( + expected_file_range: list[int], +) -> tuple[set[str], set[str]]: + file_names = [file_name_template.format(i) for i in expected_file_range] + file_texts = [file_text_template.format(i) for i in expected_file_range] + return set(file_names), set(file_texts) + + +def validate_file_names_and_texts( + docs: list[Document], expected_file_range: list[int] +) -> None: + expected_file_names, expected_file_texts = 
get_expected_file_names_and_texts( + expected_file_range + ) + + retrieved_file_names = set([doc.semantic_identifier for doc in docs]) + retrieved_texts = set([doc.sections[0].text for doc in docs]) + + # Check file names + if expected_file_names != retrieved_file_names: + print(expected_file_names) + print(retrieved_file_names) + print("Extra:") + print(retrieved_file_names - expected_file_names) + print("Missing:") + print(expected_file_names - retrieved_file_names) + assert ( + expected_file_names == retrieved_file_names + ), "Not all expected file names were found" + + # Check file texts + if expected_file_texts != retrieved_texts: + print(expected_file_texts) + print(retrieved_texts) + print("Extra:") + print(retrieved_texts - expected_file_texts) + print("Missing:") + print(expected_file_texts - retrieved_texts) + assert ( + expected_file_texts == retrieved_texts + ), "Not all expected file texts were found" + + +def flatten_file_ranges(file_ranges: list[Sequence[object]]) -> list[int]: + expected_file_range = [] + for range in file_ranges: + if isinstance(range, list): + for i in range: + if isinstance(i, int): + expected_file_range.append(i) + else: + raise ValueError(f"Expected int, got {type(i)}") + else: + raise ValueError(f"Expected list, got {type(range)}") + return expected_file_range diff --git a/backend/tests/daily/connectors/google_drive/test_google_drive_oauth.py b/backend/tests/daily/connectors/google_drive/test_google_drive_oauth.py new file mode 100644 index 00000000000..70d5ebbee36 --- /dev/null +++ b/backend/tests/daily/connectors/google_drive/test_google_drive_oauth.py @@ -0,0 +1,231 @@ +import time +from collections.abc import Callable +from unittest.mock import MagicMock +from unittest.mock import patch + +from danswer.connectors.google_drive.connector import GoogleDriveConnector +from danswer.connectors.models import Document +from tests.daily.connectors.google_drive.helpers import DRIVE_MAPPING +from tests.daily.connectors.google_drive.helpers import flatten_file_ranges +from tests.daily.connectors.google_drive.helpers import validate_file_names_and_texts + + +@patch( + "danswer.file_processing.extract_file_text.get_unstructured_api_key", + return_value=None, +) +def test_include_all( + mock_get_api_key: MagicMock, + google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector], +) -> None: + print("\n\nRunning test_include_all") + connector = google_drive_oauth_connector_factory( + include_shared_drives=True, + include_my_drives=True, + ) + docs: list[Document] = [] + for doc_batch in connector.poll_source(0, time.time()): + docs.extend(doc_batch) + + # Should get everything + expected_file_ranges = [ + DRIVE_MAPPING["ADMIN"]["range"], + DRIVE_MAPPING["SHARED_DRIVE_1"]["range"], + DRIVE_MAPPING["FOLDER_1"]["range"], + DRIVE_MAPPING["FOLDER_1_1"]["range"], + DRIVE_MAPPING["FOLDER_1_2"]["range"], + DRIVE_MAPPING["SHARED_DRIVE_2"]["range"], + DRIVE_MAPPING["FOLDER_2"]["range"], + DRIVE_MAPPING["FOLDER_2_1"]["range"], + DRIVE_MAPPING["FOLDER_2_2"]["range"], + ] + expected_file_range = flatten_file_ranges(expected_file_ranges) + validate_file_names_and_texts(docs, expected_file_range) + + +@patch( + "danswer.file_processing.extract_file_text.get_unstructured_api_key", + return_value=None, +) +def test_include_shared_drives_only( + mock_get_api_key: MagicMock, + google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector], +) -> None: + print("\n\nRunning test_include_shared_drives_only") + connector = google_drive_oauth_connector_factory( + 
include_shared_drives=True, + include_my_drives=False, + ) + docs: list[Document] = [] + for doc_batch in connector.poll_source(0, time.time()): + docs.extend(doc_batch) + + # Should only get shared drives + expected_file_ranges = [ + DRIVE_MAPPING["SHARED_DRIVE_1"]["range"], + DRIVE_MAPPING["FOLDER_1"]["range"], + DRIVE_MAPPING["FOLDER_1_1"]["range"], + DRIVE_MAPPING["FOLDER_1_2"]["range"], + DRIVE_MAPPING["SHARED_DRIVE_2"]["range"], + DRIVE_MAPPING["FOLDER_2"]["range"], + DRIVE_MAPPING["FOLDER_2_1"]["range"], + DRIVE_MAPPING["FOLDER_2_2"]["range"], + ] + expected_file_range = flatten_file_ranges(expected_file_ranges) + validate_file_names_and_texts(docs, expected_file_range) + + +@patch( + "danswer.file_processing.extract_file_text.get_unstructured_api_key", + return_value=None, +) +def test_include_my_drives_only( + mock_get_api_key: MagicMock, + google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector], +) -> None: + print("\n\nRunning test_include_my_drives_only") + connector = google_drive_oauth_connector_factory( + include_shared_drives=False, + include_my_drives=True, + ) + docs: list[Document] = [] + for doc_batch in connector.poll_source(0, time.time()): + docs.extend(doc_batch) + + # Should only get everyone's My Drives + expected_file_ranges = [ + DRIVE_MAPPING["ADMIN"]["range"], + ] + expected_file_range = flatten_file_ranges(expected_file_ranges) + validate_file_names_and_texts(docs, expected_file_range) + + +@patch( + "danswer.file_processing.extract_file_text.get_unstructured_api_key", + return_value=None, +) +def test_drive_one_only( + mock_get_api_key: MagicMock, + google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector], +) -> None: + print("\n\nRunning test_drive_one_only") + urls = [DRIVE_MAPPING["SHARED_DRIVE_1"]["url"]] + connector = google_drive_oauth_connector_factory( + include_shared_drives=True, + include_my_drives=False, + shared_drive_urls=",".join([str(url) for url in urls]), + ) + docs: list[Document] = [] + for doc_batch in connector.poll_source(0, time.time()): + docs.extend(doc_batch) + + # We ignore shared_drive_urls if include_shared_drives is False + expected_file_ranges = [ + DRIVE_MAPPING["SHARED_DRIVE_1"]["range"], + DRIVE_MAPPING["FOLDER_1"]["range"], + DRIVE_MAPPING["FOLDER_1_1"]["range"], + DRIVE_MAPPING["FOLDER_1_2"]["range"], + ] + expected_file_range = flatten_file_ranges(expected_file_ranges) + validate_file_names_and_texts(docs, expected_file_range) + + +@patch( + "danswer.file_processing.extract_file_text.get_unstructured_api_key", + return_value=None, +) +def test_folder_and_shared_drive( + mock_get_api_key: MagicMock, + google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector], +) -> None: + print("\n\nRunning test_folder_and_shared_drive") + drive_urls = [ + DRIVE_MAPPING["SHARED_DRIVE_1"]["url"], + ] + folder_urls = [DRIVE_MAPPING["FOLDER_2"]["url"]] + connector = google_drive_oauth_connector_factory( + include_shared_drives=True, + include_my_drives=True, + shared_drive_urls=",".join([str(url) for url in drive_urls]), + shared_folder_urls=",".join([str(url) for url in folder_urls]), + ) + docs: list[Document] = [] + for doc_batch in connector.poll_source(0, time.time()): + docs.extend(doc_batch) + + # Should + expected_file_ranges = [ + DRIVE_MAPPING["ADMIN"]["range"], + DRIVE_MAPPING["SHARED_DRIVE_1"]["range"], + DRIVE_MAPPING["FOLDER_1"]["range"], + DRIVE_MAPPING["FOLDER_1_1"]["range"], + DRIVE_MAPPING["FOLDER_1_2"]["range"], + DRIVE_MAPPING["FOLDER_2"]["range"], + 
DRIVE_MAPPING["FOLDER_2_1"]["range"], + DRIVE_MAPPING["FOLDER_2_2"]["range"], + ] + expected_file_range = flatten_file_ranges(expected_file_ranges) + validate_file_names_and_texts(docs, expected_file_range) + + +@patch( + "danswer.file_processing.extract_file_text.get_unstructured_api_key", + return_value=None, +) +def test_folders_only( + mock_get_api_key: MagicMock, + google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector], +) -> None: + print("\n\nRunning test_folders_only") + folder_urls = [ + DRIVE_MAPPING["FOLDER_1_1"]["url"], + DRIVE_MAPPING["FOLDER_1_2"]["url"], + DRIVE_MAPPING["FOLDER_2_1"]["url"], + DRIVE_MAPPING["FOLDER_2_2"]["url"], + ] + connector = google_drive_oauth_connector_factory( + include_shared_drives=False, + include_my_drives=False, + shared_folder_urls=",".join([str(url) for url in folder_urls]), + ) + docs: list[Document] = [] + for doc_batch in connector.poll_source(0, time.time()): + docs.extend(doc_batch) + + expected_file_ranges = [ + DRIVE_MAPPING["FOLDER_1_1"]["range"], + DRIVE_MAPPING["FOLDER_1_2"]["range"], + DRIVE_MAPPING["FOLDER_2_1"]["range"], + DRIVE_MAPPING["FOLDER_2_2"]["range"], + ] + expected_file_range = flatten_file_ranges(expected_file_ranges) + validate_file_names_and_texts(docs, expected_file_range) + + +@patch( + "danswer.file_processing.extract_file_text.get_unstructured_api_key", + return_value=None, +) +def test_specific_emails( + mock_get_api_key: MagicMock, + google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector], +) -> None: + print("\n\nRunning test_specific_emails") + my_drive_emails = [ + DRIVE_MAPPING["TEST_USER_1"]["email"], + DRIVE_MAPPING["TEST_USER_3"]["email"], + ] + connector = google_drive_oauth_connector_factory( + include_shared_drives=False, + include_my_drives=True, + my_drive_emails=",".join([str(email) for email in my_drive_emails]), + ) + docs: list[Document] = [] + for doc_batch in connector.poll_source(0, time.time()): + docs.extend(doc_batch) + + # No matter who is specified, when using oauth, if include_my_drives is True, + # we will get all the files from the admin's My Drive + expected_file_ranges = [DRIVE_MAPPING["ADMIN"]["range"]] + expected_file_range = flatten_file_ranges(expected_file_ranges) + validate_file_names_and_texts(docs, expected_file_range) diff --git a/backend/tests/daily/connectors/google_drive/test_google_drive_service_acct.py b/backend/tests/daily/connectors/google_drive/test_google_drive_service_acct.py new file mode 100644 index 00000000000..ecdd5b2e149 --- /dev/null +++ b/backend/tests/daily/connectors/google_drive/test_google_drive_service_acct.py @@ -0,0 +1,241 @@ +import time +from collections.abc import Callable +from unittest.mock import MagicMock +from unittest.mock import patch + +from danswer.connectors.google_drive.connector import GoogleDriveConnector +from danswer.connectors.models import Document +from tests.daily.connectors.google_drive.helpers import DRIVE_MAPPING +from tests.daily.connectors.google_drive.helpers import flatten_file_ranges +from tests.daily.connectors.google_drive.helpers import validate_file_names_and_texts + + +@patch( + "danswer.file_processing.extract_file_text.get_unstructured_api_key", + return_value=None, +) +def test_include_all( + mock_get_api_key: MagicMock, + google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector], +) -> None: + print("\n\nRunning test_include_all") + connector = google_drive_service_acct_connector_factory( + include_shared_drives=True, + include_my_drives=True, + ) + 
docs: list[Document] = [] + for doc_batch in connector.poll_source(0, time.time()): + docs.extend(doc_batch) + + # Should get everything + expected_file_ranges = [ + DRIVE_MAPPING["ADMIN"]["range"], + DRIVE_MAPPING["TEST_USER_1"]["range"], + DRIVE_MAPPING["TEST_USER_2"]["range"], + DRIVE_MAPPING["TEST_USER_3"]["range"], + DRIVE_MAPPING["SHARED_DRIVE_1"]["range"], + DRIVE_MAPPING["FOLDER_1"]["range"], + DRIVE_MAPPING["FOLDER_1_1"]["range"], + DRIVE_MAPPING["FOLDER_1_2"]["range"], + DRIVE_MAPPING["SHARED_DRIVE_2"]["range"], + DRIVE_MAPPING["FOLDER_2"]["range"], + DRIVE_MAPPING["FOLDER_2_1"]["range"], + DRIVE_MAPPING["FOLDER_2_2"]["range"], + ] + expected_file_range = flatten_file_ranges(expected_file_ranges) + validate_file_names_and_texts(docs, expected_file_range) + + +@patch( + "danswer.file_processing.extract_file_text.get_unstructured_api_key", + return_value=None, +) +def test_include_shared_drives_only( + mock_get_api_key: MagicMock, + google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector], +) -> None: + print("\n\nRunning test_include_shared_drives_only") + connector = google_drive_service_acct_connector_factory( + include_shared_drives=True, + include_my_drives=False, + ) + docs: list[Document] = [] + for doc_batch in connector.poll_source(0, time.time()): + docs.extend(doc_batch) + + # Should only get shared drives + expected_file_ranges = [ + DRIVE_MAPPING["SHARED_DRIVE_1"]["range"], + DRIVE_MAPPING["FOLDER_1"]["range"], + DRIVE_MAPPING["FOLDER_1_1"]["range"], + DRIVE_MAPPING["FOLDER_1_2"]["range"], + DRIVE_MAPPING["SHARED_DRIVE_2"]["range"], + DRIVE_MAPPING["FOLDER_2"]["range"], + DRIVE_MAPPING["FOLDER_2_1"]["range"], + DRIVE_MAPPING["FOLDER_2_2"]["range"], + ] + expected_file_range = flatten_file_ranges(expected_file_ranges) + validate_file_names_and_texts(docs, expected_file_range) + + +@patch( + "danswer.file_processing.extract_file_text.get_unstructured_api_key", + return_value=None, +) +def test_include_my_drives_only( + mock_get_api_key: MagicMock, + google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector], +) -> None: + print("\n\nRunning test_include_my_drives_only") + connector = google_drive_service_acct_connector_factory( + include_shared_drives=False, + include_my_drives=True, + ) + docs: list[Document] = [] + for doc_batch in connector.poll_source(0, time.time()): + docs.extend(doc_batch) + + # Should only get everyone's My Drives + expected_file_ranges = [ + DRIVE_MAPPING["ADMIN"]["range"], + DRIVE_MAPPING["TEST_USER_1"]["range"], + DRIVE_MAPPING["TEST_USER_2"]["range"], + DRIVE_MAPPING["TEST_USER_3"]["range"], + ] + expected_file_range = flatten_file_ranges(expected_file_ranges) + validate_file_names_and_texts(docs, expected_file_range) + + +@patch( + "danswer.file_processing.extract_file_text.get_unstructured_api_key", + return_value=None, +) +def test_drive_one_only( + mock_get_api_key: MagicMock, + google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector], +) -> None: + print("\n\nRunning test_drive_one_only") + urls = [DRIVE_MAPPING["SHARED_DRIVE_1"]["url"]] + connector = google_drive_service_acct_connector_factory( + include_shared_drives=True, + include_my_drives=False, + shared_drive_urls=",".join([str(url) for url in urls]), + ) + docs: list[Document] = [] + for doc_batch in connector.poll_source(0, time.time()): + docs.extend(doc_batch) + + # We ignore shared_drive_urls if include_shared_drives is False + expected_file_ranges = [ + DRIVE_MAPPING["SHARED_DRIVE_1"]["range"], + 
DRIVE_MAPPING["FOLDER_1"]["range"], + DRIVE_MAPPING["FOLDER_1_1"]["range"], + DRIVE_MAPPING["FOLDER_1_2"]["range"], + ] + expected_file_range = flatten_file_ranges(expected_file_ranges) + validate_file_names_and_texts(docs, expected_file_range) + + +@patch( + "danswer.file_processing.extract_file_text.get_unstructured_api_key", + return_value=None, +) +def test_folder_and_shared_drive( + mock_get_api_key: MagicMock, + google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector], +) -> None: + print("\n\nRunning test_folder_and_shared_drive") + drive_urls = [ + DRIVE_MAPPING["SHARED_DRIVE_1"]["url"], + ] + folder_urls = [DRIVE_MAPPING["FOLDER_2"]["url"]] + connector = google_drive_service_acct_connector_factory( + include_shared_drives=True, + include_my_drives=True, + shared_drive_urls=",".join([str(url) for url in drive_urls]), + shared_folder_urls=",".join([str(url) for url in folder_urls]), + ) + docs: list[Document] = [] + for doc_batch in connector.poll_source(0, time.time()): + docs.extend(doc_batch) + + # Should + expected_file_ranges = [ + DRIVE_MAPPING["ADMIN"]["range"], + DRIVE_MAPPING["TEST_USER_1"]["range"], + DRIVE_MAPPING["TEST_USER_2"]["range"], + DRIVE_MAPPING["TEST_USER_3"]["range"], + DRIVE_MAPPING["SHARED_DRIVE_1"]["range"], + DRIVE_MAPPING["FOLDER_1"]["range"], + DRIVE_MAPPING["FOLDER_1_1"]["range"], + DRIVE_MAPPING["FOLDER_1_2"]["range"], + DRIVE_MAPPING["FOLDER_2"]["range"], + DRIVE_MAPPING["FOLDER_2_1"]["range"], + DRIVE_MAPPING["FOLDER_2_2"]["range"], + ] + expected_file_range = flatten_file_ranges(expected_file_ranges) + validate_file_names_and_texts(docs, expected_file_range) + + +@patch( + "danswer.file_processing.extract_file_text.get_unstructured_api_key", + return_value=None, +) +def test_folders_only( + mock_get_api_key: MagicMock, + google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector], +) -> None: + print("\n\nRunning test_folders_only") + folder_urls = [ + DRIVE_MAPPING["FOLDER_1_1"]["url"], + DRIVE_MAPPING["FOLDER_1_2"]["url"], + DRIVE_MAPPING["FOLDER_2_1"]["url"], + DRIVE_MAPPING["FOLDER_2_2"]["url"], + ] + connector = google_drive_service_acct_connector_factory( + include_shared_drives=False, + include_my_drives=False, + shared_folder_urls=",".join([str(url) for url in folder_urls]), + ) + docs: list[Document] = [] + for doc_batch in connector.poll_source(0, time.time()): + docs.extend(doc_batch) + + expected_file_ranges = [ + DRIVE_MAPPING["FOLDER_1_1"]["range"], + DRIVE_MAPPING["FOLDER_1_2"]["range"], + DRIVE_MAPPING["FOLDER_2_1"]["range"], + DRIVE_MAPPING["FOLDER_2_2"]["range"], + ] + expected_file_range = flatten_file_ranges(expected_file_ranges) + validate_file_names_and_texts(docs, expected_file_range) + + +@patch( + "danswer.file_processing.extract_file_text.get_unstructured_api_key", + return_value=None, +) +def test_specific_emails( + mock_get_api_key: MagicMock, + google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector], +) -> None: + print("\n\nRunning test_specific_emails") + my_drive_emails = [ + DRIVE_MAPPING["TEST_USER_1"]["email"], + DRIVE_MAPPING["TEST_USER_3"]["email"], + ] + connector = google_drive_service_acct_connector_factory( + include_shared_drives=False, + include_my_drives=True, + my_drive_emails=",".join([str(email) for email in my_drive_emails]), + ) + docs: list[Document] = [] + for doc_batch in connector.poll_source(0, time.time()): + docs.extend(doc_batch) + + expected_file_ranges = [ + DRIVE_MAPPING["TEST_USER_1"]["range"], + 
DRIVE_MAPPING["TEST_USER_3"]["range"], + ] + expected_file_range = flatten_file_ranges(expected_file_ranges) + validate_file_names_and_texts(docs, expected_file_range) diff --git a/backend/tests/daily/connectors/google_drive/test_google_drive_slim_docs.py b/backend/tests/daily/connectors/google_drive/test_google_drive_slim_docs.py new file mode 100644 index 00000000000..e762427ec06 --- /dev/null +++ b/backend/tests/daily/connectors/google_drive/test_google_drive_slim_docs.py @@ -0,0 +1,156 @@ +import time +from collections.abc import Callable +from unittest.mock import MagicMock +from unittest.mock import patch + +from danswer.access.models import ExternalAccess +from danswer.connectors.google_drive.connector import GoogleDriveConnector +from danswer.connectors.google_drive.google_utils import execute_paginated_retrieval +from ee.danswer.external_permissions.google_drive.doc_sync import ( + _get_permissions_from_slim_doc, +) +from tests.daily.connectors.google_drive.helpers import DRIVE_MAPPING +from tests.daily.connectors.google_drive.helpers import flatten_file_ranges +from tests.daily.connectors.google_drive.helpers import ( + get_expected_file_names_and_texts, +) +from tests.daily.connectors.google_drive.helpers import PUBLIC_RANGE + + +def get_keys_available_to_user_from_access_map( + user_email: str, + group_map: dict[str, list[str]], + access_map: dict[str, ExternalAccess], +) -> list[str]: + """ + Extracts the names of the files available to the user from the access map + through their own email or group memberships or public access + """ + group_emails_for_user = [] + for group_email, user_in_group_email_list in group_map.items(): + if user_email in user_in_group_email_list: + group_emails_for_user.append(group_email) + + accessible_file_names_for_user = [] + for file_name, external_access in access_map.items(): + if external_access.is_public: + accessible_file_names_for_user.append(file_name) + elif user_email in external_access.external_user_emails: + accessible_file_names_for_user.append(file_name) + elif any( + group_email in external_access.external_user_group_ids + for group_email in group_emails_for_user + ): + accessible_file_names_for_user.append(file_name) + return accessible_file_names_for_user + + +def check_access_for_user( + user_dict: dict, + group_map: dict[str, list[str]], + retrieved_access_map: dict[str, ExternalAccess], +) -> None: + """ + compares the expected access range of the user to the keys available to the user + retrieved from the source + """ + retrieved_keys_available_to_user = get_keys_available_to_user_from_access_map( + user_dict["email"], group_map, retrieved_access_map + ) + + expected_access_range = list(set(user_dict["access"] + PUBLIC_RANGE)) + + expected_file_names, _ = get_expected_file_names_and_texts(expected_access_range) + + retrieved_file_names = set(retrieved_keys_available_to_user) + if expected_file_names != retrieved_file_names: + print(user_dict["email"]) + print(expected_file_names) + print(retrieved_file_names) + + assert expected_file_names == retrieved_file_names + + +# This function is supposed to map to the group_sync.py file for the google drive connector +# TODO: Call it directly +def get_group_map(google_drive_connector: GoogleDriveConnector) -> dict[str, list[str]]: + admin_service = google_drive_connector.get_google_resource("admin", "directory_v1") + + group_map: dict[str, list[str]] = {} + for group in execute_paginated_retrieval( + admin_service.groups().list, + list_key="groups", + 
domain=google_drive_connector.google_domain, + fields="groups(email)", + ): + # The id is the group email + group_email = group["email"] + + # Gather group member emails + group_member_emails: list[str] = [] + for member in execute_paginated_retrieval( + admin_service.members().list, + list_key="members", + groupKey=group_email, + fields="members(email)", + ): + group_member_emails.append(member["email"]) + group_map[group_email] = group_member_emails + return group_map + + +@patch( + "danswer.file_processing.extract_file_text.get_unstructured_api_key", + return_value=None, +) +def test_all_permissions( + mock_get_api_key: MagicMock, + google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector], +) -> None: + google_drive_connector = google_drive_service_acct_connector_factory( + include_shared_drives=True, + include_my_drives=True, + ) + + access_map: dict[str, ExternalAccess] = {} + for slim_doc_batch in google_drive_connector.retrieve_all_slim_documents( + 0, time.time() + ): + for slim_doc in slim_doc_batch: + access_map[ + (slim_doc.perm_sync_data or {})["name"] + ] = _get_permissions_from_slim_doc( + google_drive_connector=google_drive_connector, + slim_doc=slim_doc, + ) + + for file_name, external_access in access_map.items(): + print(file_name, external_access) + + expected_file_ranges = [ + DRIVE_MAPPING["ADMIN"]["range"], + DRIVE_MAPPING["TEST_USER_1"]["range"], + DRIVE_MAPPING["TEST_USER_2"]["range"], + DRIVE_MAPPING["TEST_USER_3"]["range"], + DRIVE_MAPPING["SHARED_DRIVE_1"]["range"], + DRIVE_MAPPING["FOLDER_1"]["range"], + DRIVE_MAPPING["FOLDER_1_1"]["range"], + DRIVE_MAPPING["FOLDER_1_2"]["range"], + DRIVE_MAPPING["SHARED_DRIVE_2"]["range"], + DRIVE_MAPPING["FOLDER_2"]["range"], + DRIVE_MAPPING["FOLDER_2_1"]["range"], + DRIVE_MAPPING["FOLDER_2_2"]["range"], + ] + expected_file_range = flatten_file_ranges(expected_file_ranges) + + # Should get everything + assert len(access_map) == len(expected_file_range) + + group_map = get_group_map(google_drive_connector) + + print("groups:\n", group_map) + + check_access_for_user(DRIVE_MAPPING["ADMIN"], group_map, access_map) + check_access_for_user(DRIVE_MAPPING["TEST_USER_1"], group_map, access_map) + check_access_for_user(DRIVE_MAPPING["TEST_USER_2"], group_map, access_map) + check_access_for_user(DRIVE_MAPPING["TEST_USER_3"], group_map, access_map) diff --git a/web/src/lib/connectors/connectors.tsx b/web/src/lib/connectors/connectors.tsx index 333642de820..6e7ef1ad97f 100644 --- a/web/src/lib/connectors/connectors.tsx +++ b/web/src/lib/connectors/connectors.tsx @@ -258,7 +258,7 @@ export const connectorConfigs: Record< { type: "text", description: - "Enter a comma separated list of the URLs of the folders located in Shared Drives to index. The files located in these folders (and all subfolders) will be indexed.", + "Enter a comma separated list of the URLs of the folders located in Shared Drives to index. The files located in these folders (and all subfolders) will be indexed. Note: This will be in addition to the 'Include Shared Drives' and 'Shared Drive URLs' settings, so leave those blank if you only want to index the folders specified here.", label: "Folder URLs", name: "shared_folder_urls", optional: true, From c3193740305ff21f314a0dfa7676995ab5b1f00a Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Thu, 31 Oct 2024 16:06:06 -0700 Subject: [PATCH 21/23] testing! 
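
This change makes the Google Drive service account and OAuth test credentials
available to the connector test jobs as environment variables. The pytest
fixtures that consume these secrets are not part of this commit; purely as an
illustration, a fixture could parse the JSON secret and hand it to the
connector through the connector's load_credentials() hook. The factory name
and the exact credential-dict shape below are assumptions, not code from this
series:

    # Hypothetical sketch of a fixture helper that turns the CI secret into
    # a ready-to-use connector.
    import json
    import os

    from danswer.connectors.google_drive.connector import GoogleDriveConnector

    def build_service_acct_connector(**connector_kwargs) -> GoogleDriveConnector:
        # The CI secret holds the whole service account key as one JSON string.
        raw_json = os.environ["GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR"]
        credentials = json.loads(raw_json)

        connector = GoogleDriveConnector(**connector_kwargs)
        # load_credentials() is the usual connector credential entry point;
        # the key names it expects inside `credentials` are connector-specific
        # and assumed here.
        connector.load_credentials(credentials)
        return connector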
--- .github/workflows/pr-python-connector-tests.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/pr-python-connector-tests.yml b/.github/workflows/pr-python-connector-tests.yml index 108012100b3..fa7df201b5e 100644 --- a/.github/workflows/pr-python-connector-tests.yml +++ b/.github/workflows/pr-python-connector-tests.yml @@ -18,6 +18,9 @@ env: # Jira JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }} JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }} + # Google + GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR: ${{ secrets.GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR }} + GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR }} jobs: connectors-check: From 0cc5d49ceda11bd9ace859219140bfdc1e9b883f Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Thu, 31 Oct 2024 17:02:37 -0700 Subject: [PATCH 22/23] Delete backend/tests/daily/connectors/google_drive/file_generator.py --- .../daily/connectors/google_drive/file_generator.py | 11 ----------- 1 file changed, 11 deletions(-) delete mode 100644 backend/tests/daily/connectors/google_drive/file_generator.py diff --git a/backend/tests/daily/connectors/google_drive/file_generator.py b/backend/tests/daily/connectors/google_drive/file_generator.py deleted file mode 100644 index 14264900047..00000000000 --- a/backend/tests/daily/connectors/google_drive/file_generator.py +++ /dev/null @@ -1,11 +0,0 @@ -import os - -documents_folder = os.path.expanduser("~/Documents") -for i in range(0, 60): - file_name = f"file_{i}.txt" - file_text = f"This is file {i}" - file_path = os.path.join(documents_folder, file_name) - if os.path.exists(file_path): - os.remove(file_path) - with open(file_path, "w") as file: - file.write(file_text) From 1149ddb66236f793fe7319ec8dfe5cbff9005746 Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Thu, 31 Oct 2024 19:02:45 -0700 Subject: [PATCH 23/23] cleaned up --- .../daily/connectors/google_drive/helpers.py | 242 +++++++++--------- .../google_drive/test_google_drive_oauth.py | 199 +++++++------- .../test_google_drive_service_acct.py | 214 +++++++++------- .../test_google_drive_slim_docs.py | 88 ++++--- 4 files changed, 392 insertions(+), 351 deletions(-) diff --git a/backend/tests/daily/connectors/google_drive/helpers.py b/backend/tests/daily/connectors/google_drive/helpers.py index c78dc446798..a1bc8feec38 100644 --- a/backend/tests/daily/connectors/google_drive/helpers.py +++ b/backend/tests/daily/connectors/google_drive/helpers.py @@ -2,22 +2,26 @@ from danswer.connectors.models import Document -_ADMIN_FILE_RANGE = list(range(0, 5)) -_TEST_USER_1_FILE_RANGE = list(range(5, 10)) -_TEST_USER_2_FILE_RANGE = list(range(10, 15)) -_TEST_USER_3_FILE_RANGE = list(range(15, 20)) -_SHARED_DRIVE_1_FILE_RANGE = list(range(20, 25)) -_FOLDER_1_FILE_RANGE = list(range(25, 30)) -_FOLDER_1_1_FILE_RANGE = list(range(30, 35)) -_FOLDER_1_2_FILE_RANGE = list(range(35, 40)) -_SHARED_DRIVE_2_FILE_RANGE = list(range(40, 45)) -_FOLDER_2_FILE_RANGE = list(range(45, 50)) -_FOLDER_2_1_FILE_RANGE = list(range(50, 55)) -_FOLDER_2_2_FILE_RANGE = list(range(55, 60)) - -_PUBLIC_FOLDER_RANGE = _FOLDER_1_2_FILE_RANGE -_PUBLIC_FILE_RANGE = list(range(55, 57)) -PUBLIC_RANGE = _PUBLIC_FOLDER_RANGE + _PUBLIC_FILE_RANGE +ALL_FILES = list(range(0, 60)) +SHARED_DRIVE_FILES = list(range(20, 25)) + + +_ADMIN_FILE_IDS = list(range(0, 5)) +_TEST_USER_1_FILE_IDS = list(range(5, 10)) +_TEST_USER_2_FILE_IDS = list(range(10, 15)) +_TEST_USER_3_FILE_IDS = list(range(15, 20)) +_SHARED_DRIVE_1_FILE_IDS = list(range(20, 25)) 
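+# NOTE (illustrative comment): every drive/folder in the test workspace owns a
+# contiguous block of five file ids out of file_0..file_59, so a test's
+# expected set is just the concatenation of the relevant blocks.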
+_FOLDER_1_FILE_IDS = list(range(25, 30)) +_FOLDER_1_1_FILE_IDS = list(range(30, 35)) +_FOLDER_1_2_FILE_IDS = list(range(35, 40)) +_SHARED_DRIVE_2_FILE_IDS = list(range(40, 45)) +_FOLDER_2_FILE_IDS = list(range(45, 50)) +_FOLDER_2_1_FILE_IDS = list(range(50, 55)) +_FOLDER_2_2_FILE_IDS = list(range(55, 60)) + +_PUBLIC_FOLDER_RANGE = _FOLDER_1_2_FILE_IDS +_PUBLIC_FILE_IDS = list(range(55, 57)) +PUBLIC_RANGE = _PUBLIC_FOLDER_RANGE + _PUBLIC_FILE_IDS _SHARED_DRIVE_1_URL = "https://drive.google.com/drive/folders/0AC_OJ4BkMd4kUk9PVA" # Group 1 is given access to this folder @@ -46,68 +50,81 @@ _TEST_USER_2_EMAIL = "test_user_2@onyx-test.com" _TEST_USER_3_EMAIL = "test_user_3@onyx-test.com" -# All users have access to their own My Drive -DRIVE_MAPPING = { - "ADMIN": { - "range": _ADMIN_FILE_RANGE, - "email": _ADMIN_EMAIL, - # Admin has access to everything in shared - "access": ( - _ADMIN_FILE_RANGE - + _SHARED_DRIVE_1_FILE_RANGE - + _FOLDER_1_FILE_RANGE - + _FOLDER_1_1_FILE_RANGE - + _FOLDER_1_2_FILE_RANGE - + _SHARED_DRIVE_2_FILE_RANGE - + _FOLDER_2_FILE_RANGE - + _FOLDER_2_1_FILE_RANGE - + _FOLDER_2_2_FILE_RANGE - ), - }, - "TEST_USER_1": { - "range": _TEST_USER_1_FILE_RANGE, - "email": _TEST_USER_1_EMAIL, - # This user has access to drive 1 - # This user has redundant access to folder 1 because of group access - # This user has been given individual access to files in Admin's My Drive - "access": ( - _TEST_USER_1_FILE_RANGE - + _SHARED_DRIVE_1_FILE_RANGE - + _FOLDER_1_FILE_RANGE - + _FOLDER_1_1_FILE_RANGE - + _FOLDER_1_2_FILE_RANGE - + list(range(0, 2)) - ), - }, - "TEST_USER_2": { - "range": _TEST_USER_2_FILE_RANGE, - "email": _TEST_USER_2_EMAIL, - # Group 1 includes this user, giving access to folder 1 - # This user has also been given access to folder 2-1 - # This user has also been given individual access to files in folder 2 - "access": ( - _TEST_USER_2_FILE_RANGE - + _FOLDER_1_FILE_RANGE - + _FOLDER_1_1_FILE_RANGE - + _FOLDER_1_2_FILE_RANGE - + _FOLDER_2_1_FILE_RANGE - + list(range(45, 47)) - ), - }, - "TEST_USER_3": { - "range": _TEST_USER_3_FILE_RANGE, - "email": _TEST_USER_3_EMAIL, - # This user can only see his own files and public files - "access": (_TEST_USER_3_FILE_RANGE), - }, - "SHARED_DRIVE_1": {"range": _SHARED_DRIVE_1_FILE_RANGE, "url": _SHARED_DRIVE_1_URL}, - "FOLDER_1": {"range": _FOLDER_1_FILE_RANGE, "url": _FOLDER_1_URL}, - "FOLDER_1_1": {"range": _FOLDER_1_1_FILE_RANGE, "url": _FOLDER_1_1_URL}, - "FOLDER_1_2": {"range": _FOLDER_1_2_FILE_RANGE, "url": _FOLDER_1_2_URL}, - "SHARED_DRIVE_2": {"range": _SHARED_DRIVE_2_FILE_RANGE, "url": _SHARED_DRIVE_2_URL}, - "FOLDER_2": {"range": _FOLDER_2_FILE_RANGE, "url": _FOLDER_2_URL}, - "FOLDER_2_1": {"range": _FOLDER_2_1_FILE_RANGE, "url": _FOLDER_2_1_URL}, - "FOLDER_2_2": {"range": _FOLDER_2_2_FILE_RANGE, "url": _FOLDER_2_2_URL}, +# Dictionary for ranges +DRIVE_ID_MAPPING: dict[str, list[int]] = { + "ADMIN": _ADMIN_FILE_IDS, + "TEST_USER_1": _TEST_USER_1_FILE_IDS, + "TEST_USER_2": _TEST_USER_2_FILE_IDS, + "TEST_USER_3": _TEST_USER_3_FILE_IDS, + "SHARED_DRIVE_1": _SHARED_DRIVE_1_FILE_IDS, + "FOLDER_1": _FOLDER_1_FILE_IDS, + "FOLDER_1_1": _FOLDER_1_1_FILE_IDS, + "FOLDER_1_2": _FOLDER_1_2_FILE_IDS, + "SHARED_DRIVE_2": _SHARED_DRIVE_2_FILE_IDS, + "FOLDER_2": _FOLDER_2_FILE_IDS, + "FOLDER_2_1": _FOLDER_2_1_FILE_IDS, + "FOLDER_2_2": _FOLDER_2_2_FILE_IDS, +} + +# Dictionary for emails +EMAIL_MAPPING: dict[str, str] = { + "ADMIN": _ADMIN_EMAIL, + "TEST_USER_1": _TEST_USER_1_EMAIL, + "TEST_USER_2": _TEST_USER_2_EMAIL, + "TEST_USER_3": 
_TEST_USER_3_EMAIL, +} + +# Dictionary for URLs +URL_MAPPING: dict[str, str] = { + "SHARED_DRIVE_1": _SHARED_DRIVE_1_URL, + "FOLDER_1": _FOLDER_1_URL, + "FOLDER_1_1": _FOLDER_1_1_URL, + "FOLDER_1_2": _FOLDER_1_2_URL, + "SHARED_DRIVE_2": _SHARED_DRIVE_2_URL, + "FOLDER_2": _FOLDER_2_URL, + "FOLDER_2_1": _FOLDER_2_1_URL, + "FOLDER_2_2": _FOLDER_2_2_URL, +} + +# Dictionary for access permissions +# All users have access to their own My Drive as well as public files +ACCESS_MAPPING: dict[str, list[int]] = { + # Admin has access to everything in shared + "ADMIN": ( + _ADMIN_FILE_IDS + + _SHARED_DRIVE_1_FILE_IDS + + _FOLDER_1_FILE_IDS + + _FOLDER_1_1_FILE_IDS + + _FOLDER_1_2_FILE_IDS + + _SHARED_DRIVE_2_FILE_IDS + + _FOLDER_2_FILE_IDS + + _FOLDER_2_1_FILE_IDS + + _FOLDER_2_2_FILE_IDS + ), + # This user has access to drive 1 + # This user has redundant access to folder 1 because of group access + # This user has been given individual access to files in Admin's My Drive + "TEST_USER_1": ( + _TEST_USER_1_FILE_IDS + + _SHARED_DRIVE_1_FILE_IDS + + _FOLDER_1_FILE_IDS + + _FOLDER_1_1_FILE_IDS + + _FOLDER_1_2_FILE_IDS + + list(range(0, 2)) + ), + # Group 1 includes this user, giving access to folder 1 + # This user has also been given access to folder 2-1 + # This user has also been given individual access to files in folder 2 + "TEST_USER_2": ( + _TEST_USER_2_FILE_IDS + + _FOLDER_1_FILE_IDS + + _FOLDER_1_1_FILE_IDS + + _FOLDER_1_2_FILE_IDS + + _FOLDER_2_1_FILE_IDS + + list(range(45, 47)) + ), + # This user can only see his own files and public files + "TEST_USER_3": _TEST_USER_3_FILE_IDS, } @@ -115,58 +132,33 @@ file_text_template = "This is file {}" -def get_expected_file_names_and_texts( - expected_file_range: list[int], -) -> tuple[set[str], set[str]]: - file_names = [file_name_template.format(i) for i in expected_file_range] - file_texts = [file_text_template.format(i) for i in expected_file_range] - return set(file_names), set(file_texts) +def print_discrepencies(expected: set[str], retrieved: set[str]) -> None: + if expected != retrieved: + print(expected) + print(retrieved) + print("Extra:") + print(retrieved - expected) + print("Missing:") + print(expected - retrieved) -def validate_file_names_and_texts( - docs: list[Document], expected_file_range: list[int] +def assert_retrieved_docs_match_expected( + retrieved_docs: list[Document], expected_file_ids: Sequence[int] ) -> None: - expected_file_names, expected_file_texts = get_expected_file_names_and_texts( - expected_file_range - ) + expected_file_names = { + file_name_template.format(file_id) for file_id in expected_file_ids + } + expected_file_texts = { + file_text_template.format(file_id) for file_id in expected_file_ids + } - retrieved_file_names = set([doc.semantic_identifier for doc in docs]) - retrieved_texts = set([doc.sections[0].text for doc in docs]) + retrieved_file_names = set([doc.semantic_identifier for doc in retrieved_docs]) + retrieved_texts = set([doc.sections[0].text for doc in retrieved_docs]) # Check file names - if expected_file_names != retrieved_file_names: - print(expected_file_names) - print(retrieved_file_names) - print("Extra:") - print(retrieved_file_names - expected_file_names) - print("Missing:") - print(expected_file_names - retrieved_file_names) - assert ( - expected_file_names == retrieved_file_names - ), "Not all expected file names were found" + print_discrepencies(expected_file_names, retrieved_file_names) + assert expected_file_names == retrieved_file_names # Check file texts - if expected_file_texts != 
retrieved_texts: - print(expected_file_texts) - print(retrieved_texts) - print("Extra:") - print(retrieved_texts - expected_file_texts) - print("Missing:") - print(expected_file_texts - retrieved_texts) - assert ( - expected_file_texts == retrieved_texts - ), "Not all expected file texts were found" - - -def flatten_file_ranges(file_ranges: list[Sequence[object]]) -> list[int]: - expected_file_range = [] - for range in file_ranges: - if isinstance(range, list): - for i in range: - if isinstance(i, int): - expected_file_range.append(i) - else: - raise ValueError(f"Expected int, got {type(i)}") - else: - raise ValueError(f"Expected list, got {type(range)}") - return expected_file_range + print_discrepencies(expected_file_texts, retrieved_texts) + assert expected_file_texts == retrieved_texts diff --git a/backend/tests/daily/connectors/google_drive/test_google_drive_oauth.py b/backend/tests/daily/connectors/google_drive/test_google_drive_oauth.py index 70d5ebbee36..f39b15600b4 100644 --- a/backend/tests/daily/connectors/google_drive/test_google_drive_oauth.py +++ b/backend/tests/daily/connectors/google_drive/test_google_drive_oauth.py @@ -5,9 +5,12 @@ from danswer.connectors.google_drive.connector import GoogleDriveConnector from danswer.connectors.models import Document -from tests.daily.connectors.google_drive.helpers import DRIVE_MAPPING -from tests.daily.connectors.google_drive.helpers import flatten_file_ranges -from tests.daily.connectors.google_drive.helpers import validate_file_names_and_texts +from tests.daily.connectors.google_drive.helpers import ( + assert_retrieved_docs_match_expected, +) +from tests.daily.connectors.google_drive.helpers import DRIVE_ID_MAPPING +from tests.daily.connectors.google_drive.helpers import EMAIL_MAPPING +from tests.daily.connectors.google_drive.helpers import URL_MAPPING @patch( @@ -23,24 +26,26 @@ def test_include_all( include_shared_drives=True, include_my_drives=True, ) - docs: list[Document] = [] + retrieved_docs: list[Document] = [] for doc_batch in connector.poll_source(0, time.time()): - docs.extend(doc_batch) - - # Should get everything - expected_file_ranges = [ - DRIVE_MAPPING["ADMIN"]["range"], - DRIVE_MAPPING["SHARED_DRIVE_1"]["range"], - DRIVE_MAPPING["FOLDER_1"]["range"], - DRIVE_MAPPING["FOLDER_1_1"]["range"], - DRIVE_MAPPING["FOLDER_1_2"]["range"], - DRIVE_MAPPING["SHARED_DRIVE_2"]["range"], - DRIVE_MAPPING["FOLDER_2"]["range"], - DRIVE_MAPPING["FOLDER_2_1"]["range"], - DRIVE_MAPPING["FOLDER_2_2"]["range"], - ] - expected_file_range = flatten_file_ranges(expected_file_ranges) - validate_file_names_and_texts(docs, expected_file_range) + retrieved_docs.extend(doc_batch) + + # Should get everything in shared and admin's My Drive with oauth + expected_file_ids = ( + DRIVE_ID_MAPPING["ADMIN"] + + DRIVE_ID_MAPPING["SHARED_DRIVE_1"] + + DRIVE_ID_MAPPING["FOLDER_1"] + + DRIVE_ID_MAPPING["FOLDER_1_1"] + + DRIVE_ID_MAPPING["FOLDER_1_2"] + + DRIVE_ID_MAPPING["SHARED_DRIVE_2"] + + DRIVE_ID_MAPPING["FOLDER_2"] + + DRIVE_ID_MAPPING["FOLDER_2_1"] + + DRIVE_ID_MAPPING["FOLDER_2_2"] + ) + assert_retrieved_docs_match_expected( + retrieved_docs=retrieved_docs, + expected_file_ids=expected_file_ids, + ) @patch( @@ -56,23 +61,25 @@ def test_include_shared_drives_only( include_shared_drives=True, include_my_drives=False, ) - docs: list[Document] = [] + retrieved_docs: list[Document] = [] for doc_batch in connector.poll_source(0, time.time()): - docs.extend(doc_batch) + retrieved_docs.extend(doc_batch) # Should only get shared drives - expected_file_ranges = 
[ - DRIVE_MAPPING["SHARED_DRIVE_1"]["range"], - DRIVE_MAPPING["FOLDER_1"]["range"], - DRIVE_MAPPING["FOLDER_1_1"]["range"], - DRIVE_MAPPING["FOLDER_1_2"]["range"], - DRIVE_MAPPING["SHARED_DRIVE_2"]["range"], - DRIVE_MAPPING["FOLDER_2"]["range"], - DRIVE_MAPPING["FOLDER_2_1"]["range"], - DRIVE_MAPPING["FOLDER_2_2"]["range"], - ] - expected_file_range = flatten_file_ranges(expected_file_ranges) - validate_file_names_and_texts(docs, expected_file_range) + expected_file_ids = ( + DRIVE_ID_MAPPING["SHARED_DRIVE_1"] + + DRIVE_ID_MAPPING["FOLDER_1"] + + DRIVE_ID_MAPPING["FOLDER_1_1"] + + DRIVE_ID_MAPPING["FOLDER_1_2"] + + DRIVE_ID_MAPPING["SHARED_DRIVE_2"] + + DRIVE_ID_MAPPING["FOLDER_2"] + + DRIVE_ID_MAPPING["FOLDER_2_1"] + + DRIVE_ID_MAPPING["FOLDER_2_2"] + ) + assert_retrieved_docs_match_expected( + retrieved_docs=retrieved_docs, + expected_file_ids=expected_file_ids, + ) @patch( @@ -88,16 +95,16 @@ def test_include_my_drives_only( include_shared_drives=False, include_my_drives=True, ) - docs: list[Document] = [] + retrieved_docs: list[Document] = [] for doc_batch in connector.poll_source(0, time.time()): - docs.extend(doc_batch) + retrieved_docs.extend(doc_batch) # Should only get everyone's My Drives - expected_file_ranges = [ - DRIVE_MAPPING["ADMIN"]["range"], - ] - expected_file_range = flatten_file_ranges(expected_file_ranges) - validate_file_names_and_texts(docs, expected_file_range) + expected_file_ids = DRIVE_ID_MAPPING["ADMIN"] + assert_retrieved_docs_match_expected( + retrieved_docs=retrieved_docs, + expected_file_ids=expected_file_ids, + ) @patch( @@ -109,25 +116,29 @@ def test_drive_one_only( google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector], ) -> None: print("\n\nRunning test_drive_one_only") - urls = [DRIVE_MAPPING["SHARED_DRIVE_1"]["url"]] + drive_urls = [ + URL_MAPPING["SHARED_DRIVE_1"], + ] connector = google_drive_oauth_connector_factory( include_shared_drives=True, include_my_drives=False, - shared_drive_urls=",".join([str(url) for url in urls]), + shared_drive_urls=",".join([str(url) for url in drive_urls]), ) - docs: list[Document] = [] + retrieved_docs: list[Document] = [] for doc_batch in connector.poll_source(0, time.time()): - docs.extend(doc_batch) + retrieved_docs.extend(doc_batch) # We ignore shared_drive_urls if include_shared_drives is False - expected_file_ranges = [ - DRIVE_MAPPING["SHARED_DRIVE_1"]["range"], - DRIVE_MAPPING["FOLDER_1"]["range"], - DRIVE_MAPPING["FOLDER_1_1"]["range"], - DRIVE_MAPPING["FOLDER_1_2"]["range"], - ] - expected_file_range = flatten_file_ranges(expected_file_ranges) - validate_file_names_and_texts(docs, expected_file_range) + expected_file_ids = ( + DRIVE_ID_MAPPING["SHARED_DRIVE_1"] + + DRIVE_ID_MAPPING["FOLDER_1"] + + DRIVE_ID_MAPPING["FOLDER_1_1"] + + DRIVE_ID_MAPPING["FOLDER_1_2"] + ) + assert_retrieved_docs_match_expected( + retrieved_docs=retrieved_docs, + expected_file_ids=expected_file_ids, + ) @patch( @@ -139,33 +150,33 @@ def test_folder_and_shared_drive( google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector], ) -> None: print("\n\nRunning test_folder_and_shared_drive") - drive_urls = [ - DRIVE_MAPPING["SHARED_DRIVE_1"]["url"], - ] - folder_urls = [DRIVE_MAPPING["FOLDER_2"]["url"]] + drive_urls = [URL_MAPPING["SHARED_DRIVE_1"]] + folder_urls = [URL_MAPPING["FOLDER_2"]] connector = google_drive_oauth_connector_factory( include_shared_drives=True, include_my_drives=True, shared_drive_urls=",".join([str(url) for url in drive_urls]), shared_folder_urls=",".join([str(url) for url in 
folder_urls]),
     )
-    docs: list[Document] = []
+    retrieved_docs: list[Document] = []
     for doc_batch in connector.poll_source(0, time.time()):
-        docs.extend(doc_batch)
+        retrieved_docs.extend(doc_batch)
 
     # Should get the specified drive and folder (plus subfolders) and admin's My Drive
-    expected_file_ranges = [
-        DRIVE_MAPPING["ADMIN"]["range"],
-        DRIVE_MAPPING["SHARED_DRIVE_1"]["range"],
-        DRIVE_MAPPING["FOLDER_1"]["range"],
-        DRIVE_MAPPING["FOLDER_1_1"]["range"],
-        DRIVE_MAPPING["FOLDER_1_2"]["range"],
-        DRIVE_MAPPING["FOLDER_2"]["range"],
-        DRIVE_MAPPING["FOLDER_2_1"]["range"],
-        DRIVE_MAPPING["FOLDER_2_2"]["range"],
-    ]
-    expected_file_range = flatten_file_ranges(expected_file_ranges)
-    validate_file_names_and_texts(docs, expected_file_range)
+    expected_file_ids = (
+        DRIVE_ID_MAPPING["ADMIN"]
+        + DRIVE_ID_MAPPING["SHARED_DRIVE_1"]
+        + DRIVE_ID_MAPPING["FOLDER_1"]
+        + DRIVE_ID_MAPPING["FOLDER_1_1"]
+        + DRIVE_ID_MAPPING["FOLDER_1_2"]
+        + DRIVE_ID_MAPPING["FOLDER_2"]
+        + DRIVE_ID_MAPPING["FOLDER_2_1"]
+        + DRIVE_ID_MAPPING["FOLDER_2_2"]
+    )
+    assert_retrieved_docs_match_expected(
+        retrieved_docs=retrieved_docs,
+        expected_file_ids=expected_file_ids,
+    )
 
 
 @patch(
@@ -178,28 +189,30 @@ def test_folders_only(
 ) -> None:
     print("\n\nRunning test_folders_only")
     folder_urls = [
-        DRIVE_MAPPING["FOLDER_1_1"]["url"],
-        DRIVE_MAPPING["FOLDER_1_2"]["url"],
-        DRIVE_MAPPING["FOLDER_2_1"]["url"],
-        DRIVE_MAPPING["FOLDER_2_2"]["url"],
+        URL_MAPPING["FOLDER_1_1"],
+        URL_MAPPING["FOLDER_1_2"],
+        URL_MAPPING["FOLDER_2_1"],
+        URL_MAPPING["FOLDER_2_2"],
     ]
     connector = google_drive_oauth_connector_factory(
         include_shared_drives=False,
         include_my_drives=False,
         shared_folder_urls=",".join([str(url) for url in folder_urls]),
     )
-    docs: list[Document] = []
+    retrieved_docs: list[Document] = []
     for doc_batch in connector.poll_source(0, time.time()):
-        docs.extend(doc_batch)
+        retrieved_docs.extend(doc_batch)
 
-    expected_file_ranges = [
-        DRIVE_MAPPING["FOLDER_1_1"]["range"],
-        DRIVE_MAPPING["FOLDER_1_2"]["range"],
-        DRIVE_MAPPING["FOLDER_2_1"]["range"],
-        DRIVE_MAPPING["FOLDER_2_2"]["range"],
-    ]
-    expected_file_range = flatten_file_ranges(expected_file_ranges)
-    validate_file_names_and_texts(docs, expected_file_range)
+    expected_file_ids = (
+        DRIVE_ID_MAPPING["FOLDER_1_1"]
+        + DRIVE_ID_MAPPING["FOLDER_1_2"]
+        + DRIVE_ID_MAPPING["FOLDER_2_1"]
+        + DRIVE_ID_MAPPING["FOLDER_2_2"]
+    )
+    assert_retrieved_docs_match_expected(
+        retrieved_docs=retrieved_docs,
+        expected_file_ids=expected_file_ids,
+    )
 
 
 @patch(
@@ -212,20 +225,22 @@ def test_specific_emails(
 ) -> None:
     print("\n\nRunning test_specific_emails")
     my_drive_emails = [
-        DRIVE_MAPPING["TEST_USER_1"]["email"],
-        DRIVE_MAPPING["TEST_USER_3"]["email"],
+        EMAIL_MAPPING["TEST_USER_1"],
+        EMAIL_MAPPING["TEST_USER_3"],
     ]
     connector = google_drive_oauth_connector_factory(
         include_shared_drives=False,
         include_my_drives=True,
         my_drive_emails=",".join([str(email) for email in my_drive_emails]),
     )
-    docs: list[Document] = []
+    retrieved_docs: list[Document] = []
     for doc_batch in connector.poll_source(0, time.time()):
-        docs.extend(doc_batch)
+        retrieved_docs.extend(doc_batch)
 
     # No matter who is specified, when using oauth, if include_my_drives is True,
     # we will get all the files from the admin's My Drive
-    expected_file_ranges = [DRIVE_MAPPING["ADMIN"]["range"]]
-    expected_file_range = flatten_file_ranges(expected_file_ranges)
-    validate_file_names_and_texts(docs, expected_file_range)
+    expected_file_ids = DRIVE_ID_MAPPING["ADMIN"]
+    assert_retrieved_docs_match_expected(
+        retrieved_docs=retrieved_docs,
+        expected_file_ids=expected_file_ids,
+    )
diff
--git a/backend/tests/daily/connectors/google_drive/test_google_drive_service_acct.py b/backend/tests/daily/connectors/google_drive/test_google_drive_service_acct.py index ecdd5b2e149..b36a53b30f6 100644 --- a/backend/tests/daily/connectors/google_drive/test_google_drive_service_acct.py +++ b/backend/tests/daily/connectors/google_drive/test_google_drive_service_acct.py @@ -5,9 +5,12 @@ from danswer.connectors.google_drive.connector import GoogleDriveConnector from danswer.connectors.models import Document -from tests.daily.connectors.google_drive.helpers import DRIVE_MAPPING -from tests.daily.connectors.google_drive.helpers import flatten_file_ranges -from tests.daily.connectors.google_drive.helpers import validate_file_names_and_texts +from tests.daily.connectors.google_drive.helpers import ( + assert_retrieved_docs_match_expected, +) +from tests.daily.connectors.google_drive.helpers import DRIVE_ID_MAPPING +from tests.daily.connectors.google_drive.helpers import EMAIL_MAPPING +from tests.daily.connectors.google_drive.helpers import URL_MAPPING @patch( @@ -23,27 +26,29 @@ def test_include_all( include_shared_drives=True, include_my_drives=True, ) - docs: list[Document] = [] + retrieved_docs: list[Document] = [] for doc_batch in connector.poll_source(0, time.time()): - docs.extend(doc_batch) + retrieved_docs.extend(doc_batch) # Should get everything - expected_file_ranges = [ - DRIVE_MAPPING["ADMIN"]["range"], - DRIVE_MAPPING["TEST_USER_1"]["range"], - DRIVE_MAPPING["TEST_USER_2"]["range"], - DRIVE_MAPPING["TEST_USER_3"]["range"], - DRIVE_MAPPING["SHARED_DRIVE_1"]["range"], - DRIVE_MAPPING["FOLDER_1"]["range"], - DRIVE_MAPPING["FOLDER_1_1"]["range"], - DRIVE_MAPPING["FOLDER_1_2"]["range"], - DRIVE_MAPPING["SHARED_DRIVE_2"]["range"], - DRIVE_MAPPING["FOLDER_2"]["range"], - DRIVE_MAPPING["FOLDER_2_1"]["range"], - DRIVE_MAPPING["FOLDER_2_2"]["range"], - ] - expected_file_range = flatten_file_ranges(expected_file_ranges) - validate_file_names_and_texts(docs, expected_file_range) + expected_file_ids = ( + DRIVE_ID_MAPPING["ADMIN"] + + DRIVE_ID_MAPPING["TEST_USER_1"] + + DRIVE_ID_MAPPING["TEST_USER_2"] + + DRIVE_ID_MAPPING["TEST_USER_3"] + + DRIVE_ID_MAPPING["SHARED_DRIVE_1"] + + DRIVE_ID_MAPPING["FOLDER_1"] + + DRIVE_ID_MAPPING["FOLDER_1_1"] + + DRIVE_ID_MAPPING["FOLDER_1_2"] + + DRIVE_ID_MAPPING["SHARED_DRIVE_2"] + + DRIVE_ID_MAPPING["FOLDER_2"] + + DRIVE_ID_MAPPING["FOLDER_2_1"] + + DRIVE_ID_MAPPING["FOLDER_2_2"] + ) + assert_retrieved_docs_match_expected( + retrieved_docs=retrieved_docs, + expected_file_ids=expected_file_ids, + ) @patch( @@ -59,23 +64,25 @@ def test_include_shared_drives_only( include_shared_drives=True, include_my_drives=False, ) - docs: list[Document] = [] + retrieved_docs: list[Document] = [] for doc_batch in connector.poll_source(0, time.time()): - docs.extend(doc_batch) + retrieved_docs.extend(doc_batch) # Should only get shared drives - expected_file_ranges = [ - DRIVE_MAPPING["SHARED_DRIVE_1"]["range"], - DRIVE_MAPPING["FOLDER_1"]["range"], - DRIVE_MAPPING["FOLDER_1_1"]["range"], - DRIVE_MAPPING["FOLDER_1_2"]["range"], - DRIVE_MAPPING["SHARED_DRIVE_2"]["range"], - DRIVE_MAPPING["FOLDER_2"]["range"], - DRIVE_MAPPING["FOLDER_2_1"]["range"], - DRIVE_MAPPING["FOLDER_2_2"]["range"], - ] - expected_file_range = flatten_file_ranges(expected_file_ranges) - validate_file_names_and_texts(docs, expected_file_range) + expected_file_ids = ( + DRIVE_ID_MAPPING["SHARED_DRIVE_1"] + + DRIVE_ID_MAPPING["FOLDER_1"] + + DRIVE_ID_MAPPING["FOLDER_1_1"] + + DRIVE_ID_MAPPING["FOLDER_1_2"] + 
+ DRIVE_ID_MAPPING["SHARED_DRIVE_2"] + + DRIVE_ID_MAPPING["FOLDER_2"] + + DRIVE_ID_MAPPING["FOLDER_2_1"] + + DRIVE_ID_MAPPING["FOLDER_2_2"] + ) + assert_retrieved_docs_match_expected( + retrieved_docs=retrieved_docs, + expected_file_ids=expected_file_ids, + ) @patch( @@ -91,19 +98,21 @@ def test_include_my_drives_only( include_shared_drives=False, include_my_drives=True, ) - docs: list[Document] = [] + retrieved_docs: list[Document] = [] for doc_batch in connector.poll_source(0, time.time()): - docs.extend(doc_batch) + retrieved_docs.extend(doc_batch) # Should only get everyone's My Drives - expected_file_ranges = [ - DRIVE_MAPPING["ADMIN"]["range"], - DRIVE_MAPPING["TEST_USER_1"]["range"], - DRIVE_MAPPING["TEST_USER_2"]["range"], - DRIVE_MAPPING["TEST_USER_3"]["range"], - ] - expected_file_range = flatten_file_ranges(expected_file_ranges) - validate_file_names_and_texts(docs, expected_file_range) + expected_file_ids = ( + DRIVE_ID_MAPPING["ADMIN"] + + DRIVE_ID_MAPPING["TEST_USER_1"] + + DRIVE_ID_MAPPING["TEST_USER_2"] + + DRIVE_ID_MAPPING["TEST_USER_3"] + ) + assert_retrieved_docs_match_expected( + retrieved_docs=retrieved_docs, + expected_file_ids=expected_file_ids, + ) @patch( @@ -115,25 +124,27 @@ def test_drive_one_only( google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector], ) -> None: print("\n\nRunning test_drive_one_only") - urls = [DRIVE_MAPPING["SHARED_DRIVE_1"]["url"]] + urls = [URL_MAPPING["SHARED_DRIVE_1"]] connector = google_drive_service_acct_connector_factory( include_shared_drives=True, include_my_drives=False, shared_drive_urls=",".join([str(url) for url in urls]), ) - docs: list[Document] = [] + retrieved_docs: list[Document] = [] for doc_batch in connector.poll_source(0, time.time()): - docs.extend(doc_batch) + retrieved_docs.extend(doc_batch) # We ignore shared_drive_urls if include_shared_drives is False - expected_file_ranges = [ - DRIVE_MAPPING["SHARED_DRIVE_1"]["range"], - DRIVE_MAPPING["FOLDER_1"]["range"], - DRIVE_MAPPING["FOLDER_1_1"]["range"], - DRIVE_MAPPING["FOLDER_1_2"]["range"], - ] - expected_file_range = flatten_file_ranges(expected_file_ranges) - validate_file_names_and_texts(docs, expected_file_range) + expected_file_ids = ( + DRIVE_ID_MAPPING["SHARED_DRIVE_1"] + + DRIVE_ID_MAPPING["FOLDER_1"] + + DRIVE_ID_MAPPING["FOLDER_1_1"] + + DRIVE_ID_MAPPING["FOLDER_1_2"] + ) + assert_retrieved_docs_match_expected( + retrieved_docs=retrieved_docs, + expected_file_ids=expected_file_ids, + ) @patch( @@ -146,35 +157,37 @@ def test_folder_and_shared_drive( ) -> None: print("\n\nRunning test_folder_and_shared_drive") drive_urls = [ - DRIVE_MAPPING["SHARED_DRIVE_1"]["url"], + URL_MAPPING["SHARED_DRIVE_1"], ] - folder_urls = [DRIVE_MAPPING["FOLDER_2"]["url"]] + folder_urls = [URL_MAPPING["FOLDER_2"]] connector = google_drive_service_acct_connector_factory( include_shared_drives=True, include_my_drives=True, shared_drive_urls=",".join([str(url) for url in drive_urls]), shared_folder_urls=",".join([str(url) for url in folder_urls]), ) - docs: list[Document] = [] + retrieved_docs: list[Document] = [] for doc_batch in connector.poll_source(0, time.time()): - docs.extend(doc_batch) + retrieved_docs.extend(doc_batch) # Should - expected_file_ranges = [ - DRIVE_MAPPING["ADMIN"]["range"], - DRIVE_MAPPING["TEST_USER_1"]["range"], - DRIVE_MAPPING["TEST_USER_2"]["range"], - DRIVE_MAPPING["TEST_USER_3"]["range"], - DRIVE_MAPPING["SHARED_DRIVE_1"]["range"], - DRIVE_MAPPING["FOLDER_1"]["range"], - DRIVE_MAPPING["FOLDER_1_1"]["range"], - 
DRIVE_MAPPING["FOLDER_1_2"]["range"], - DRIVE_MAPPING["FOLDER_2"]["range"], - DRIVE_MAPPING["FOLDER_2_1"]["range"], - DRIVE_MAPPING["FOLDER_2_2"]["range"], - ] - expected_file_range = flatten_file_ranges(expected_file_ranges) - validate_file_names_and_texts(docs, expected_file_range) + expected_file_ids = ( + DRIVE_ID_MAPPING["ADMIN"] + + DRIVE_ID_MAPPING["TEST_USER_1"] + + DRIVE_ID_MAPPING["TEST_USER_2"] + + DRIVE_ID_MAPPING["TEST_USER_3"] + + DRIVE_ID_MAPPING["SHARED_DRIVE_1"] + + DRIVE_ID_MAPPING["FOLDER_1"] + + DRIVE_ID_MAPPING["FOLDER_1_1"] + + DRIVE_ID_MAPPING["FOLDER_1_2"] + + DRIVE_ID_MAPPING["FOLDER_2"] + + DRIVE_ID_MAPPING["FOLDER_2_1"] + + DRIVE_ID_MAPPING["FOLDER_2_2"] + ) + assert_retrieved_docs_match_expected( + retrieved_docs=retrieved_docs, + expected_file_ids=expected_file_ids, + ) @patch( @@ -187,28 +200,30 @@ def test_folders_only( ) -> None: print("\n\nRunning test_folders_only") folder_urls = [ - DRIVE_MAPPING["FOLDER_1_1"]["url"], - DRIVE_MAPPING["FOLDER_1_2"]["url"], - DRIVE_MAPPING["FOLDER_2_1"]["url"], - DRIVE_MAPPING["FOLDER_2_2"]["url"], + URL_MAPPING["FOLDER_1_1"], + URL_MAPPING["FOLDER_1_2"], + URL_MAPPING["FOLDER_2_1"], + URL_MAPPING["FOLDER_2_2"], ] connector = google_drive_service_acct_connector_factory( include_shared_drives=False, include_my_drives=False, shared_folder_urls=",".join([str(url) for url in folder_urls]), ) - docs: list[Document] = [] + retrieved_docs: list[Document] = [] for doc_batch in connector.poll_source(0, time.time()): - docs.extend(doc_batch) + retrieved_docs.extend(doc_batch) - expected_file_ranges = [ - DRIVE_MAPPING["FOLDER_1_1"]["range"], - DRIVE_MAPPING["FOLDER_1_2"]["range"], - DRIVE_MAPPING["FOLDER_2_1"]["range"], - DRIVE_MAPPING["FOLDER_2_2"]["range"], - ] - expected_file_range = flatten_file_ranges(expected_file_ranges) - validate_file_names_and_texts(docs, expected_file_range) + expected_file_ids = ( + DRIVE_ID_MAPPING["FOLDER_1_1"] + + DRIVE_ID_MAPPING["FOLDER_1_2"] + + DRIVE_ID_MAPPING["FOLDER_2_1"] + + DRIVE_ID_MAPPING["FOLDER_2_2"] + ) + assert_retrieved_docs_match_expected( + retrieved_docs=retrieved_docs, + expected_file_ids=expected_file_ids, + ) @patch( @@ -221,21 +236,22 @@ def test_specific_emails( ) -> None: print("\n\nRunning test_specific_emails") my_drive_emails = [ - DRIVE_MAPPING["TEST_USER_1"]["email"], - DRIVE_MAPPING["TEST_USER_3"]["email"], + EMAIL_MAPPING["TEST_USER_1"], + EMAIL_MAPPING["TEST_USER_3"], ] connector = google_drive_service_acct_connector_factory( include_shared_drives=False, include_my_drives=True, my_drive_emails=",".join([str(email) for email in my_drive_emails]), ) - docs: list[Document] = [] + retrieved_docs: list[Document] = [] for doc_batch in connector.poll_source(0, time.time()): - docs.extend(doc_batch) + retrieved_docs.extend(doc_batch) - expected_file_ranges = [ - DRIVE_MAPPING["TEST_USER_1"]["range"], - DRIVE_MAPPING["TEST_USER_3"]["range"], - ] - expected_file_range = flatten_file_ranges(expected_file_ranges) - validate_file_names_and_texts(docs, expected_file_range) + expected_file_ids = ( + DRIVE_ID_MAPPING["TEST_USER_1"] + DRIVE_ID_MAPPING["TEST_USER_3"] + ) + assert_retrieved_docs_match_expected( + retrieved_docs=retrieved_docs, + expected_file_ids=expected_file_ids, + ) diff --git a/backend/tests/daily/connectors/google_drive/test_google_drive_slim_docs.py b/backend/tests/daily/connectors/google_drive/test_google_drive_slim_docs.py index e762427ec06..e731c8b27ce 100644 --- a/backend/tests/daily/connectors/google_drive/test_google_drive_slim_docs.py +++ 
b/backend/tests/daily/connectors/google_drive/test_google_drive_slim_docs.py @@ -9,11 +9,11 @@ from ee.danswer.external_permissions.google_drive.doc_sync import ( _get_permissions_from_slim_doc, ) -from tests.daily.connectors.google_drive.helpers import DRIVE_MAPPING -from tests.daily.connectors.google_drive.helpers import flatten_file_ranges -from tests.daily.connectors.google_drive.helpers import ( - get_expected_file_names_and_texts, -) +from tests.daily.connectors.google_drive.helpers import ACCESS_MAPPING +from tests.daily.connectors.google_drive.helpers import DRIVE_ID_MAPPING +from tests.daily.connectors.google_drive.helpers import EMAIL_MAPPING +from tests.daily.connectors.google_drive.helpers import file_name_template +from tests.daily.connectors.google_drive.helpers import print_discrepencies from tests.daily.connectors.google_drive.helpers import PUBLIC_RANGE @@ -45,8 +45,9 @@ def get_keys_available_to_user_from_access_map( return accessible_file_names_for_user -def check_access_for_user( - user_dict: dict, +def assert_correct_access_for_user( + user_email: str, + expected_access_ids: list[int], group_map: dict[str, list[str]], retrieved_access_map: dict[str, ExternalAccess], ) -> None: @@ -55,18 +56,15 @@ def check_access_for_user( retrieved from the source """ retrieved_keys_available_to_user = get_keys_available_to_user_from_access_map( - user_dict["email"], group_map, retrieved_access_map + user_email, group_map, retrieved_access_map ) + retrieved_file_names = set(retrieved_keys_available_to_user) - expected_access_range = list(set(user_dict["access"] + PUBLIC_RANGE)) - - expected_file_names, _ = get_expected_file_names_and_texts(expected_access_range) + # Combine public and user-specific access IDs + all_accessible_ids = expected_access_ids + PUBLIC_RANGE + expected_file_names = {file_name_template.format(i) for i in all_accessible_ids} - retrieved_file_names = set(retrieved_keys_available_to_user) - if expected_file_names != retrieved_file_names: - print(user_dict["email"]) - print(expected_file_names) - print(retrieved_file_names) + print_discrepencies(expected_file_names, retrieved_file_names) assert expected_file_names == retrieved_file_names @@ -127,21 +125,20 @@ def test_all_permissions( for file_name, external_access in access_map.items(): print(file_name, external_access) - expected_file_ranges = [ - DRIVE_MAPPING["ADMIN"]["range"], - DRIVE_MAPPING["TEST_USER_1"]["range"], - DRIVE_MAPPING["TEST_USER_2"]["range"], - DRIVE_MAPPING["TEST_USER_3"]["range"], - DRIVE_MAPPING["SHARED_DRIVE_1"]["range"], - DRIVE_MAPPING["FOLDER_1"]["range"], - DRIVE_MAPPING["FOLDER_1_1"]["range"], - DRIVE_MAPPING["FOLDER_1_2"]["range"], - DRIVE_MAPPING["SHARED_DRIVE_2"]["range"], - DRIVE_MAPPING["FOLDER_2"]["range"], - DRIVE_MAPPING["FOLDER_2_1"]["range"], - DRIVE_MAPPING["FOLDER_2_2"]["range"], - ] - expected_file_range = flatten_file_ranges(expected_file_ranges) + expected_file_range = ( + DRIVE_ID_MAPPING["ADMIN"] + + DRIVE_ID_MAPPING["TEST_USER_1"] + + DRIVE_ID_MAPPING["TEST_USER_2"] + + DRIVE_ID_MAPPING["TEST_USER_3"] + + DRIVE_ID_MAPPING["SHARED_DRIVE_1"] + + DRIVE_ID_MAPPING["FOLDER_1"] + + DRIVE_ID_MAPPING["FOLDER_1_1"] + + DRIVE_ID_MAPPING["FOLDER_1_2"] + + DRIVE_ID_MAPPING["SHARED_DRIVE_2"] + + DRIVE_ID_MAPPING["FOLDER_2"] + + DRIVE_ID_MAPPING["FOLDER_2_1"] + + DRIVE_ID_MAPPING["FOLDER_2_2"] + ) # Should get everything assert len(access_map) == len(expected_file_range) @@ -150,7 +147,28 @@ def test_all_permissions( print("groups:\n", group_map) - 
check_access_for_user(DRIVE_MAPPING["ADMIN"], group_map, access_map) - check_access_for_user(DRIVE_MAPPING["TEST_USER_1"], group_map, access_map) - check_access_for_user(DRIVE_MAPPING["TEST_USER_2"], group_map, access_map) - check_access_for_user(DRIVE_MAPPING["TEST_USER_3"], group_map, access_map) + assert_correct_access_for_user( + user_email=EMAIL_MAPPING["ADMIN"], + expected_access_ids=ACCESS_MAPPING["ADMIN"], + group_map=group_map, + retrieved_access_map=access_map, + ) + assert_correct_access_for_user( + user_email=EMAIL_MAPPING["TEST_USER_1"], + expected_access_ids=ACCESS_MAPPING["TEST_USER_1"], + group_map=group_map, + retrieved_access_map=access_map, + ) + + assert_correct_access_for_user( + user_email=EMAIL_MAPPING["TEST_USER_2"], + expected_access_ids=ACCESS_MAPPING["TEST_USER_2"], + group_map=group_map, + retrieved_access_map=access_map, + ) + assert_correct_access_for_user( + user_email=EMAIL_MAPPING["TEST_USER_3"], + expected_access_ids=ACCESS_MAPPING["TEST_USER_3"], + group_map=group_map, + retrieved_access_map=access_map, + )
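
A note for reviewers on the permission model exercised by
test_google_drive_slim_docs.py: the access-map resolution is a small set
computation, so it can be sanity-checked in isolation. A minimal sketch
follows, assuming ExternalAccess can be constructed directly from the three
fields the test reads (external_user_emails, external_user_group_ids,
is_public); the file names and group email here are toy values, not fixtures
from this series:

    from danswer.access.models import ExternalAccess

    # Toy access map: one direct share, one group share, one public file.
    access_map = {
        "file_a": ExternalAccess(
            external_user_emails={"test_user_2@onyx-test.com"},
            external_user_group_ids=set(),
            is_public=False,
        ),
        "file_b": ExternalAccess(
            external_user_emails=set(),
            external_user_group_ids={"group_1@onyx-test.com"},
            is_public=False,
        ),
        "file_c": ExternalAccess(
            external_user_emails=set(),
            external_user_group_ids=set(),
            is_public=True,
        ),
    }
    group_map = {"group_1@onyx-test.com": ["test_user_2@onyx-test.com"]}

    # test_user_2 reaches file_a directly, file_b via group_1, and file_c
    # because it is public.
    assert set(
        get_keys_available_to_user_from_access_map(
            "test_user_2@onyx-test.com", group_map, access_map
        )
    ) == {"file_a", "file_b", "file_c"}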