Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 60 additions & 23 deletions backend/onyx/connectors/google_drive/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import Protocol
from urllib.parse import urlparse

from google.auth.exceptions import RefreshError # type: ignore
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from googleapiclient.errors import HttpError # type: ignore
Expand Down Expand Up @@ -72,7 +73,9 @@
# TODO: Improve this by using the batch utility: https://googleapis.github.io/google-api-python-client/docs/batch.html
# All file retrievals could be batched and made at once

BATCHES_PER_CHECKPOINT = 10
BATCHES_PER_CHECKPOINT = 1

DRIVE_BATCH_SIZE = 80


def _extract_str_list_from_comma_str(string: str | None) -> list[str]:
Expand Down Expand Up @@ -184,8 +187,6 @@ def __init__(
"shared_folder_urls, or my_drive_emails"
)

self.batch_size = batch_size

specific_requests_made = False
if bool(shared_drive_urls) or bool(my_drive_emails) or bool(shared_folder_urls):
specific_requests_made = True
Expand Down Expand Up @@ -306,14 +307,14 @@ def _get_all_user_emails(self) -> list[str]:
return user_emails

def get_all_drive_ids(self) -> set[str]:
primary_drive_service = get_drive_service(
creds=self.creds,
user_email=self.primary_admin_email,
)
return self._get_all_drives_for_user(self.primary_admin_email)

def _get_all_drives_for_user(self, user_email: str) -> set[str]:
drive_service = get_drive_service(self.creds, user_email)
is_service_account = isinstance(self.creds, ServiceAccountCredentials)
all_drive_ids = set()
all_drive_ids: set[str] = set()
for drive in execute_paginated_retrieval(
retrieval_function=primary_drive_service.drives().list,
retrieval_function=drive_service.drives().list,
list_key="drives",
useDomainAdminAccess=is_service_account,
fields="drives(id),nextPageToken",
Expand Down Expand Up @@ -373,6 +374,10 @@ def record_drive_processing(drive_id: str) -> None:
if drive_id in self._retrieved_folder_and_drive_ids
else DriveIdStatus.AVAILABLE
)
logger.debug(
f"Drive id status: {len(drive_id_status)}, user email: {thread_id},"
f"processed drive ids: {len(completion.processed_drive_ids)}"
)
# wake up other threads waiting for work
cv.notify_all()

Expand Down Expand Up @@ -423,13 +428,15 @@ def _impersonate_user_for_retrieval(
curr_stage = checkpoint.completion_map[user_email]
resuming = True
if curr_stage.stage == DriveRetrievalStage.START:
logger.info(f"Setting stage to {DriveRetrievalStage.MY_DRIVE_FILES.value}")
curr_stage.stage = DriveRetrievalStage.MY_DRIVE_FILES
resuming = False
drive_service = get_drive_service(self.creds, user_email)

# validate that the user has access to the drive APIs by performing a simple
# request and checking for a 401
try:
logger.debug(f"Getting root folder id for user {user_email}")
# default is ~17mins of retries, don't do that here for cases so we don't
# waste 17mins everytime we run into a user without access to drive APIs
retry_builder(tries=3, delay=1)(get_root_folder_id)(drive_service)
Expand All @@ -445,14 +452,29 @@ def _impersonate_user_for_retrieval(
curr_stage.stage = DriveRetrievalStage.DONE
return
raise

except RefreshError as e:
logger.warning(
f"User '{user_email}' could not refresh their token. Error: {e}"
)
# mark this user as done so we don't try to retrieve anything for them
# again
yield RetrievedDriveFile(
completion_stage=DriveRetrievalStage.DONE,
drive_file={},
user_email=user_email,
error=e,
)
curr_stage.stage = DriveRetrievalStage.DONE
return
# if we are including my drives, try to get the current user's my
# drive if any of the following are true:
# - include_my_drives is true
# - the current user's email is in the requested emails
if curr_stage.stage == DriveRetrievalStage.MY_DRIVE_FILES:
if self.include_my_drives or user_email in self._requested_my_drive_emails:
logger.info(f"Getting all files in my drive as '{user_email}'")
logger.info(
f"Getting all files in my drive as '{user_email}. Resuming: {resuming}"
)

yield from add_retrieval_info(
get_all_files_in_my_drive_and_shared(
Expand Down Expand Up @@ -505,7 +527,7 @@ def _yield_from_drive(

for drive_id in concurrent_drive_itr(user_email):
logger.info(
f"Getting files in shared drive '{drive_id}' as '{user_email}'"
f"Getting files in shared drive '{drive_id}' as '{user_email}. Resuming: {resuming}"
)
curr_stage.completed_until = 0
curr_stage.current_folder_or_drive_id = drive_id
Expand Down Expand Up @@ -577,6 +599,14 @@ def _manage_service_account_retrieval(
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[RetrievedDriveFile]:
"""
The current implementation of the service account retrieval does some
initial setup work using the primary admin email, then runs MAX_DRIVE_WORKERS
concurrent threads, each of which impersonates a different user and retrieves
files for that user. Technically, the actual work each thread does is "yield the
next file retrieved by the user", at which point it returns to the thread pool;
see parallel_yield for more details.
"""
if checkpoint.completion_stage == DriveRetrievalStage.START:
checkpoint.completion_stage = DriveRetrievalStage.USER_EMAILS

Expand All @@ -602,6 +632,7 @@ def _manage_service_account_retrieval(
checkpoint.completion_map[email] = StageCompletion(
stage=DriveRetrievalStage.START,
completed_until=0,
processed_drive_ids=set(),
)

# we've found all users and drives, now time to actually start
Expand All @@ -627,7 +658,7 @@ def _manage_service_account_retrieval(
# to the drive APIs. Without this, we could loop through these emails for
# more than 3 hours, causing a timeout and stalling progress.
email_batch_takes_us_to_completion = True
MAX_EMAILS_TO_PROCESS_BEFORE_CHECKPOINTING = 50
MAX_EMAILS_TO_PROCESS_BEFORE_CHECKPOINTING = MAX_DRIVE_WORKERS
if len(non_completed_org_emails) > MAX_EMAILS_TO_PROCESS_BEFORE_CHECKPOINTING:
non_completed_org_emails = non_completed_org_emails[
:MAX_EMAILS_TO_PROCESS_BEFORE_CHECKPOINTING
Expand Down Expand Up @@ -871,6 +902,10 @@ def _checkpointed_retrieval(
return

for file in drive_files:
logger.debug(
f"Updating checkpoint for file: {file.drive_file.get('name')}. "
f"Seen: {file.drive_file.get('id') in checkpoint.all_retrieved_file_ids}"
)
checkpoint.completion_map[file.user_email].update(
stage=file.completion_stage,
completed_until=datetime.fromisoformat(
Expand Down Expand Up @@ -1047,24 +1082,22 @@ def _yield_batch(
continue
files_batch.append(retrieved_file)

if len(files_batch) < self.batch_size:
if len(files_batch) < DRIVE_BATCH_SIZE:
continue

logger.info(
f"Yielding batch of {len(files_batch)} files; num seen doc ids: {len(checkpoint.all_retrieved_file_ids)}"
)
yield from _yield_batch(files_batch)
files_batch = []

if batches_complete > BATCHES_PER_CHECKPOINT:
checkpoint.retrieved_folder_and_drive_ids = (
self._retrieved_folder_and_drive_ids
)
return # create a new checkpoint

logger.info(
f"Processing remaining files: {[file.drive_file.get('name') for file in files_batch]}"
)
# Process any remaining files
if files_batch:
yield from _yield_batch(files_batch)
checkpoint.retrieved_folder_and_drive_ids = (
self._retrieved_folder_and_drive_ids
)

except Exception as e:
logger.exception(f"Error extracting documents from Google Drive: {e}")
raise e
Expand All @@ -1083,6 +1116,10 @@ def load_from_checkpoint(
"Credentials missing, should not call this method before calling load_credentials"
)

logger.info(
f"Loading from checkpoint with completion stage: {checkpoint.completion_stage},"
f"num retrieved ids: {len(checkpoint.all_retrieved_file_ids)}"
)
checkpoint = copy.deepcopy(checkpoint)
self._retrieved_folder_and_drive_ids = checkpoint.retrieved_folder_and_drive_ids
try:
Expand Down
6 changes: 5 additions & 1 deletion backend/onyx/connectors/google_drive/doc_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,12 +327,16 @@ def convert_drive_item_to_document(
doc_or_failure = _convert_drive_item_to_document(
creds, allow_images, size_threshold, retriever_email, file
)

# There are a variety of permissions-based errors that occasionally occur
# when retrieving files. Often when these occur, there is another user
# that can successfully retrieve the file, so we try the next user.
if (
doc_or_failure is None
or isinstance(doc_or_failure, Document)
or not (
isinstance(doc_or_failure.exception, HttpError)
and doc_or_failure.exception.status_code in [403, 404]
and doc_or_failure.exception.status_code in [401, 403, 404]
)
):
return doc_or_failure
Expand Down