
Commit 0eab6ab

fix drive slowness (#4668)
* fix slowness
* no more silent failing for users
* nits
* no silly info transfer
1 parent ee09cb9 commit 0eab6ab

File tree: 2 files changed (+65 -24 lines)
backend/onyx/connectors/google_drive/connector.py

Lines changed: 60 additions & 23 deletions
@@ -10,6 +10,7 @@
 from typing import Protocol
 from urllib.parse import urlparse
 
+from google.auth.exceptions import RefreshError  # type: ignore
 from google.oauth2.credentials import Credentials as OAuthCredentials  # type: ignore
 from google.oauth2.service_account import Credentials as ServiceAccountCredentials  # type: ignore
 from googleapiclient.errors import HttpError  # type: ignore
@@ -72,7 +73,9 @@
 # TODO: Improve this by using the batch utility: https://googleapis.github.io/google-api-python-client/docs/batch.html
 # All file retrievals could be batched and made at once
 
-BATCHES_PER_CHECKPOINT = 10
+BATCHES_PER_CHECKPOINT = 1
+
+DRIVE_BATCH_SIZE = 80
 
 
 def _extract_str_list_from_comma_str(string: str | None) -> list[str]:
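Note on the TODO kept above: the batch utility it references lets many Drive requests share a single HTTP round trip. A minimal sketch of that idea, assuming an already-built Drive v3 service object and illustrative field/function names (this is not code from the commit):

def fetch_metadata_in_batch(service, file_ids):
    """Fetch metadata for many files in one batched HTTP request."""
    results = {}

    def _collect(request_id, response, exception):
        # request_id is the file id passed to batch.add() below
        results[request_id] = exception if exception is not None else response

    batch = service.new_batch_http_request(callback=_collect)
    for file_id in file_ids:
        batch.add(
            service.files().get(fileId=file_id, fields="id, name, modifiedTime"),
            request_id=file_id,
        )
    batch.execute()
    return results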
@@ -184,8 +187,6 @@ def __init__(
             "shared_folder_urls, or my_drive_emails"
         )
 
-        self.batch_size = batch_size
-
         specific_requests_made = False
         if bool(shared_drive_urls) or bool(my_drive_emails) or bool(shared_folder_urls):
             specific_requests_made = True
@@ -306,14 +307,14 @@ def _get_all_user_emails(self) -> list[str]:
         return user_emails
 
     def get_all_drive_ids(self) -> set[str]:
-        primary_drive_service = get_drive_service(
-            creds=self.creds,
-            user_email=self.primary_admin_email,
-        )
+        return self._get_all_drives_for_user(self.primary_admin_email)
+
+    def _get_all_drives_for_user(self, user_email: str) -> set[str]:
+        drive_service = get_drive_service(self.creds, user_email)
         is_service_account = isinstance(self.creds, ServiceAccountCredentials)
-        all_drive_ids = set()
+        all_drive_ids: set[str] = set()
         for drive in execute_paginated_retrieval(
-            retrieval_function=primary_drive_service.drives().list,
+            retrieval_function=drive_service.drives().list,
             list_key="drives",
             useDomainAdminAccess=is_service_account,
             fields="drives(id),nextPageToken",
@@ -373,6 +374,10 @@ def record_drive_processing(drive_id: str) -> None:
                 if drive_id in self._retrieved_folder_and_drive_ids
                 else DriveIdStatus.AVAILABLE
             )
+            logger.debug(
+                f"Drive id status: {len(drive_id_status)}, user email: {thread_id},"
+                f"processed drive ids: {len(completion.processed_drive_ids)}"
+            )
             # wake up other threads waiting for work
             cv.notify_all()
 
@@ -423,13 +428,15 @@ def _impersonate_user_for_retrieval(
         curr_stage = checkpoint.completion_map[user_email]
         resuming = True
         if curr_stage.stage == DriveRetrievalStage.START:
+            logger.info(f"Setting stage to {DriveRetrievalStage.MY_DRIVE_FILES.value}")
             curr_stage.stage = DriveRetrievalStage.MY_DRIVE_FILES
             resuming = False
         drive_service = get_drive_service(self.creds, user_email)
 
         # validate that the user has access to the drive APIs by performing a simple
         # request and checking for a 401
         try:
+            logger.debug(f"Getting root folder id for user {user_email}")
             # default is ~17mins of retries, don't do that here for cases so we don't
             # waste 17mins everytime we run into a user without access to drive APIs
             retry_builder(tries=3, delay=1)(get_root_folder_id)(drive_service)
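Note: retry_builder is Onyx's own retry helper; the diff only shows that it accepts tries and delay. A rough, hypothetical equivalent illustrating the idea of a few quick attempts instead of the long default backoff (assumed names, not the project's implementation):

import time
from functools import wraps

def simple_retry(tries: int = 3, delay: float = 1.0):
    """Return a decorator that retries a callable a few times with a short pause."""
    def decorator(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            last_exc: Exception | None = None
            for attempt in range(tries):
                try:
                    return fn(*args, **kwargs)
                except Exception as e:
                    last_exc = e
                    if attempt < tries - 1:
                        time.sleep(delay)  # brief pause before the next attempt
            assert last_exc is not None
            raise last_exc
        return wrapper
    return decorator

# mirrors the call shape in the diff: simple_retry(tries=3, delay=1)(get_root_folder_id)(drive_service)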
@@ -445,14 +452,29 @@ def _impersonate_user_for_retrieval(
                 curr_stage.stage = DriveRetrievalStage.DONE
                 return
             raise
-
+        except RefreshError as e:
+            logger.warning(
+                f"User '{user_email}' could not refresh their token. Error: {e}"
+            )
+            # mark this user as done so we don't try to retrieve anything for them
+            # again
+            yield RetrievedDriveFile(
+                completion_stage=DriveRetrievalStage.DONE,
+                drive_file={},
+                user_email=user_email,
+                error=e,
+            )
+            curr_stage.stage = DriveRetrievalStage.DONE
+            return
         # if we are including my drives, try to get the current user's my
         # drive if any of the following are true:
         # - include_my_drives is true
         # - the current user's email is in the requested emails
         if curr_stage.stage == DriveRetrievalStage.MY_DRIVE_FILES:
             if self.include_my_drives or user_email in self._requested_my_drive_emails:
-                logger.info(f"Getting all files in my drive as '{user_email}'")
+                logger.info(
+                    f"Getting all files in my drive as '{user_email}. Resuming: {resuming}"
+                )
 
                 yield from add_retrieval_info(
                     get_all_files_in_my_drive_and_shared(
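Note on the new RefreshError handling ("no more silent failing for users"): with a service account, each user is impersonated via domain-wide delegation, and the delegated token for a specific user can fail to refresh (for example a suspended or deleted account). A hedged sketch of that setup using the public google-auth / google-api-python-client APIs, not the repo's get_drive_service helper:

from google.auth.exceptions import RefreshError  # type: ignore
from google.oauth2 import service_account  # type: ignore
from googleapiclient.discovery import build  # type: ignore

SCOPES = ["https://www.googleapis.com/auth/drive.readonly"]

def build_drive_service_for_user(sa_info: dict, user_email: str):
    """Build a Drive v3 client that acts as user_email via domain-wide delegation."""
    creds = service_account.Credentials.from_service_account_info(
        sa_info, scopes=SCOPES
    ).with_subject(user_email)
    return build("drive", "v3", credentials=creds)

# The token refresh happens on the first API call, so that is where RefreshError
# surfaces; catching it per user lets a caller skip that user instead of failing
# the whole run, which is the behavior the change above adds.
def can_access_drive(sa_info: dict, user_email: str) -> bool:
    try:
        service = build_drive_service_for_user(sa_info, user_email)
        service.files().list(pageSize=1).execute()
        return True
    except RefreshError:
        return False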
@@ -505,7 +527,7 @@ def _yield_from_drive(
 
         for drive_id in concurrent_drive_itr(user_email):
             logger.info(
-                f"Getting files in shared drive '{drive_id}' as '{user_email}'"
+                f"Getting files in shared drive '{drive_id}' as '{user_email}. Resuming: {resuming}"
             )
             curr_stage.completed_until = 0
             curr_stage.current_folder_or_drive_id = drive_id
@@ -577,6 +599,14 @@ def _manage_service_account_retrieval(
         start: SecondsSinceUnixEpoch | None = None,
         end: SecondsSinceUnixEpoch | None = None,
     ) -> Iterator[RetrievedDriveFile]:
+        """
+        The current implementation of the service account retrieval does some
+        initial setup work using the primary admin email, then runs MAX_DRIVE_WORKERS
+        concurrent threads, each of which impersonates a different user and retrieves
+        files for that user. Technically, the actual work each thread does is "yield the
+        next file retrieved by the user", at which point it returns to the thread pool;
+        see parallel_yield for more details.
+        """
         if checkpoint.completion_stage == DriveRetrievalStage.START:
             checkpoint.completion_stage = DriveRetrievalStage.USER_EMAILS
 
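Note on the new docstring: the fan-out it describes (each thread's unit of work is "give me the next file from one user's generator") can be sketched roughly as below. This is an illustrative stand-in for Onyx's parallel_yield, not its actual implementation:

from collections.abc import Iterator
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
from typing import TypeVar

T = TypeVar("T")
_SENTINEL = object()

def parallel_yield_sketch(gens: list[Iterator[T]], max_workers: int = 8) -> Iterator[T]:
    """Interleave items from several generators, advancing each in a worker thread."""
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # at most one in-flight "next item" request per generator
        futures = {pool.submit(next, gen, _SENTINEL): gen for gen in gens}
        while futures:
            done, _ = wait(futures, return_when=FIRST_COMPLETED)
            for fut in done:
                gen = futures.pop(fut)
                item = fut.result()
                if item is _SENTINEL:
                    continue  # this generator is exhausted
                yield item
                # ask the same generator for its next item
                futures[pool.submit(next, gen, _SENTINEL)] = gen

Because a slow user only blocks its own worker, files from fast users keep flowing, which is the point of the per-user threading described above.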
@@ -602,6 +632,7 @@ def _manage_service_account_retrieval(
             checkpoint.completion_map[email] = StageCompletion(
                 stage=DriveRetrievalStage.START,
                 completed_until=0,
+                processed_drive_ids=set(),
             )
 
         # we've found all users and drives, now time to actually start
@@ -627,7 +658,7 @@ def _manage_service_account_retrieval(
         # to the drive APIs. Without this, we could loop through these emails for
         # more than 3 hours, causing a timeout and stalling progress.
         email_batch_takes_us_to_completion = True
-        MAX_EMAILS_TO_PROCESS_BEFORE_CHECKPOINTING = 50
+        MAX_EMAILS_TO_PROCESS_BEFORE_CHECKPOINTING = MAX_DRIVE_WORKERS
         if len(non_completed_org_emails) > MAX_EMAILS_TO_PROCESS_BEFORE_CHECKPOINTING:
             non_completed_org_emails = non_completed_org_emails[
                 :MAX_EMAILS_TO_PROCESS_BEFORE_CHECKPOINTING
@@ -871,6 +902,10 @@ def _checkpointed_retrieval(
             return
 
         for file in drive_files:
+            logger.debug(
+                f"Updating checkpoint for file: {file.drive_file.get('name')}. "
+                f"Seen: {file.drive_file.get('id') in checkpoint.all_retrieved_file_ids}"
+            )
             checkpoint.completion_map[file.user_email].update(
                 stage=file.completion_stage,
                 completed_until=datetime.fromisoformat(
@@ -1047,24 +1082,22 @@ def _yield_batch(
                    continue
                files_batch.append(retrieved_file)
 
-                if len(files_batch) < self.batch_size:
+                if len(files_batch) < DRIVE_BATCH_SIZE:
                    continue
 
-                logger.info(
-                    f"Yielding batch of {len(files_batch)} files; num seen doc ids: {len(checkpoint.all_retrieved_file_ids)}"
-                )
                yield from _yield_batch(files_batch)
                files_batch = []
 
-                if batches_complete > BATCHES_PER_CHECKPOINT:
-                    checkpoint.retrieved_folder_and_drive_ids = (
-                        self._retrieved_folder_and_drive_ids
-                    )
-                    return  # create a new checkpoint
-
+            logger.info(
+                f"Processing remaining files: {[file.drive_file.get('name') for file in files_batch]}"
+            )
            # Process any remaining files
            if files_batch:
                yield from _yield_batch(files_batch)
+            checkpoint.retrieved_folder_and_drive_ids = (
+                self._retrieved_folder_and_drive_ids
+            )
+
        except Exception as e:
            logger.exception(f"Error extracting documents from Google Drive: {e}")
            raise e
@@ -1083,6 +1116,10 @@ def load_from_checkpoint(
                 "Credentials missing, should not call this method before calling load_credentials"
             )
 
+        logger.info(
+            f"Loading from checkpoint with completion stage: {checkpoint.completion_stage},"
+            f"num retrieved ids: {len(checkpoint.all_retrieved_file_ids)}"
+        )
         checkpoint = copy.deepcopy(checkpoint)
         self._retrieved_folder_and_drive_ids = checkpoint.retrieved_folder_and_drive_ids
         try:

backend/onyx/connectors/google_drive/doc_conversion.py

Lines changed: 5 additions & 1 deletion
@@ -327,12 +327,16 @@ def convert_drive_item_to_document(
         doc_or_failure = _convert_drive_item_to_document(
             creds, allow_images, size_threshold, retriever_email, file
         )
+
+        # There are a variety of permissions-based errors that occasionally occur
+        # when retrieving files. Often when these occur, there is another user
+        # that can successfully retrieve the file, so we try the next user.
         if (
             doc_or_failure is None
             or isinstance(doc_or_failure, Document)
             or not (
                 isinstance(doc_or_failure.exception, HttpError)
-                and doc_or_failure.exception.status_code in [403, 404]
+                and doc_or_failure.exception.status_code in [401, 403, 404]
             )
         ):
             return doc_or_failure
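Note on the widened status-code list: the retry-with-another-user idea the new comment describes can be sketched as follows, with a hypothetical fetch_fn standing in for the real retrieval call (this is not the repo's code):

from googleapiclient.errors import HttpError  # type: ignore

PERMISSION_STATUS_CODES = {401, 403, 404}

def fetch_with_fallback(file_id: str, retriever_emails: list[str], fetch_fn):
    """Try each user in turn; only permission-style HttpErrors trigger a fallback."""
    last_error: Exception | None = None
    for email in retriever_emails:
        try:
            return fetch_fn(file_id, email)
        except HttpError as e:
            if e.status_code not in PERMISSION_STATUS_CODES:
                raise  # unrelated failure, surface it immediately
            last_error = e  # another user may still be able to see the file
    raise last_error if last_error else RuntimeError("no retriever emails provided")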
