10
10
from typing import Protocol
11
11
from urllib .parse import urlparse
12
12
13
+ from google .auth .exceptions import RefreshError # type: ignore
13
14
from google .oauth2 .credentials import Credentials as OAuthCredentials # type: ignore
14
15
from google .oauth2 .service_account import Credentials as ServiceAccountCredentials # type: ignore
15
16
from googleapiclient .errors import HttpError # type: ignore
72
73
# TODO: Improve this by using the batch utility: https://googleapis.github.io/google-api-python-client/docs/batch.html
73
74
# All file retrievals could be batched and made at once
74
75
75
- BATCHES_PER_CHECKPOINT = 10
76
+ BATCHES_PER_CHECKPOINT = 1
77
+
78
+ DRIVE_BATCH_SIZE = 80
76
79
77
80
78
81
def _extract_str_list_from_comma_str (string : str | None ) -> list [str ]:
@@ -184,8 +187,6 @@ def __init__(
184
187
"shared_folder_urls, or my_drive_emails"
185
188
)
186
189
187
- self .batch_size = batch_size
188
-
189
190
specific_requests_made = False
190
191
if bool (shared_drive_urls ) or bool (my_drive_emails ) or bool (shared_folder_urls ):
191
192
specific_requests_made = True
@@ -306,14 +307,14 @@ def _get_all_user_emails(self) -> list[str]:
306
307
return user_emails
307
308
308
309
def get_all_drive_ids (self ) -> set [str ]:
309
- primary_drive_service = get_drive_service (
310
- creds = self . creds ,
311
- user_email = self . primary_admin_email ,
312
- )
310
+ return self . _get_all_drives_for_user ( self . primary_admin_email )
311
+
312
+ def _get_all_drives_for_user ( self , user_email : str ) -> set [ str ]:
313
+ drive_service = get_drive_service ( self . creds , user_email )
313
314
is_service_account = isinstance (self .creds , ServiceAccountCredentials )
314
- all_drive_ids = set ()
315
+ all_drive_ids : set [ str ] = set ()
315
316
for drive in execute_paginated_retrieval (
316
- retrieval_function = primary_drive_service .drives ().list ,
317
+ retrieval_function = drive_service .drives ().list ,
317
318
list_key = "drives" ,
318
319
useDomainAdminAccess = is_service_account ,
319
320
fields = "drives(id),nextPageToken" ,
@@ -373,6 +374,10 @@ def record_drive_processing(drive_id: str) -> None:
373
374
if drive_id in self ._retrieved_folder_and_drive_ids
374
375
else DriveIdStatus .AVAILABLE
375
376
)
377
+ logger .debug (
378
+ f"Drive id status: { len (drive_id_status )} , user email: { thread_id } ,"
379
+ f"processed drive ids: { len (completion .processed_drive_ids )} "
380
+ )
376
381
# wake up other threads waiting for work
377
382
cv .notify_all ()
378
383
@@ -423,13 +428,15 @@ def _impersonate_user_for_retrieval(
423
428
curr_stage = checkpoint .completion_map [user_email ]
424
429
resuming = True
425
430
if curr_stage .stage == DriveRetrievalStage .START :
431
+ logger .info (f"Setting stage to { DriveRetrievalStage .MY_DRIVE_FILES .value } " )
426
432
curr_stage .stage = DriveRetrievalStage .MY_DRIVE_FILES
427
433
resuming = False
428
434
drive_service = get_drive_service (self .creds , user_email )
429
435
430
436
# validate that the user has access to the drive APIs by performing a simple
431
437
# request and checking for a 401
432
438
try :
439
+ logger .debug (f"Getting root folder id for user { user_email } " )
433
440
# the default is ~17mins of retries; don't do that here, so that we don't
434
441
# waste 17mins every time we run into a user without access to drive APIs
435
442
retry_builder (tries = 3 , delay = 1 )(get_root_folder_id )(drive_service )
@@ -445,14 +452,29 @@ def _impersonate_user_for_retrieval(
445
452
curr_stage .stage = DriveRetrievalStage .DONE
446
453
return
447
454
raise
448
-
455
+ except RefreshError as e :
456
+ logger .warning (
457
+ f"User '{ user_email } ' could not refresh their token. Error: { e } "
458
+ )
459
+ # mark this user as done so we don't try to retrieve anything for them
460
+ # again
461
+ yield RetrievedDriveFile (
462
+ completion_stage = DriveRetrievalStage .DONE ,
463
+ drive_file = {},
464
+ user_email = user_email ,
465
+ error = e ,
466
+ )
467
+ curr_stage .stage = DriveRetrievalStage .DONE
468
+ return
449
469
# if we are including my drives, try to get the current user's my
450
470
# drive if any of the following are true:
451
471
# - include_my_drives is true
452
472
# - the current user's email is in the requested emails
453
473
if curr_stage .stage == DriveRetrievalStage .MY_DRIVE_FILES :
454
474
if self .include_my_drives or user_email in self ._requested_my_drive_emails :
455
- logger .info (f"Getting all files in my drive as '{ user_email } '" )
475
+ logger .info (
476
+ f"Getting all files in my drive as '{ user_email } . Resuming: { resuming } "
477
+ )
456
478
457
479
yield from add_retrieval_info (
458
480
get_all_files_in_my_drive_and_shared (
@@ -505,7 +527,7 @@ def _yield_from_drive(
505
527
506
528
for drive_id in concurrent_drive_itr (user_email ):
507
529
logger .info (
508
- f"Getting files in shared drive '{ drive_id } ' as '{ user_email } ' "
530
+ f"Getting files in shared drive '{ drive_id } ' as '{ user_email } . Resuming: { resuming } "
509
531
)
510
532
curr_stage .completed_until = 0
511
533
curr_stage .current_folder_or_drive_id = drive_id
@@ -577,6 +599,14 @@ def _manage_service_account_retrieval(
577
599
start : SecondsSinceUnixEpoch | None = None ,
578
600
end : SecondsSinceUnixEpoch | None = None ,
579
601
) -> Iterator [RetrievedDriveFile ]:
602
+ """
603
+ The current implementation of the service account retrieval does some
604
+ initial setup work using the primary admin email, then runs MAX_DRIVE_WORKERS
605
+ concurrent threads, each of which impersonates a different user and retrieves
606
+ files for that user. Technically, the actual work each thread does is "yield the
607
+ next file retrieved by the user", at which point it returns to the thread pool;
608
+ see parallel_yield for more details.
609
+ """
580
610
if checkpoint .completion_stage == DriveRetrievalStage .START :
581
611
checkpoint .completion_stage = DriveRetrievalStage .USER_EMAILS
582
612
@@ -602,6 +632,7 @@ def _manage_service_account_retrieval(
602
632
checkpoint .completion_map [email ] = StageCompletion (
603
633
stage = DriveRetrievalStage .START ,
604
634
completed_until = 0 ,
635
+ processed_drive_ids = set (),
605
636
)
606
637
607
638
# we've found all users and drives, now time to actually start
@@ -627,7 +658,7 @@ def _manage_service_account_retrieval(
627
658
# to the drive APIs. Without this, we could loop through these emails for
628
659
# more than 3 hours, causing a timeout and stalling progress.
629
660
email_batch_takes_us_to_completion = True
630
- MAX_EMAILS_TO_PROCESS_BEFORE_CHECKPOINTING = 50
661
+ MAX_EMAILS_TO_PROCESS_BEFORE_CHECKPOINTING = MAX_DRIVE_WORKERS
631
662
if len (non_completed_org_emails ) > MAX_EMAILS_TO_PROCESS_BEFORE_CHECKPOINTING :
632
663
non_completed_org_emails = non_completed_org_emails [
633
664
:MAX_EMAILS_TO_PROCESS_BEFORE_CHECKPOINTING
@@ -871,6 +902,10 @@ def _checkpointed_retrieval(
871
902
return
872
903
873
904
for file in drive_files :
905
+ logger .debug (
906
+ f"Updating checkpoint for file: { file .drive_file .get ('name' )} . "
907
+ f"Seen: { file .drive_file .get ('id' ) in checkpoint .all_retrieved_file_ids } "
908
+ )
874
909
checkpoint .completion_map [file .user_email ].update (
875
910
stage = file .completion_stage ,
876
911
completed_until = datetime .fromisoformat (
@@ -1047,24 +1082,22 @@ def _yield_batch(
1047
1082
continue
1048
1083
files_batch .append (retrieved_file )
1049
1084
1050
- if len (files_batch ) < self . batch_size :
1085
+ if len (files_batch ) < DRIVE_BATCH_SIZE :
1051
1086
continue
1052
1087
1053
- logger .info (
1054
- f"Yielding batch of { len (files_batch )} files; num seen doc ids: { len (checkpoint .all_retrieved_file_ids )} "
1055
- )
1056
1088
yield from _yield_batch (files_batch )
1057
1089
files_batch = []
1058
1090
1059
- if batches_complete > BATCHES_PER_CHECKPOINT :
1060
- checkpoint .retrieved_folder_and_drive_ids = (
1061
- self ._retrieved_folder_and_drive_ids
1062
- )
1063
- return # create a new checkpoint
1064
-
1091
+ logger .info (
1092
+ f"Processing remaining files: { [file .drive_file .get ('name' ) for file in files_batch ]} "
1093
+ )
1065
1094
# Process any remaining files
1066
1095
if files_batch :
1067
1096
yield from _yield_batch (files_batch )
1097
+ checkpoint .retrieved_folder_and_drive_ids = (
1098
+ self ._retrieved_folder_and_drive_ids
1099
+ )
1100
+
1068
1101
except Exception as e :
1069
1102
logger .exception (f"Error extracting documents from Google Drive: { e } " )
1070
1103
raise e
@@ -1083,6 +1116,10 @@ def load_from_checkpoint(
1083
1116
"Credentials missing, should not call this method before calling load_credentials"
1084
1117
)
1085
1118
1119
+ logger .info (
1120
+ f"Loading from checkpoint with completion stage: { checkpoint .completion_stage } ,"
1121
+ f"num retrieved ids: { len (checkpoint .all_retrieved_file_ids )} "
1122
+ )
1086
1123
checkpoint = copy .deepcopy (checkpoint )
1087
1124
self ._retrieved_folder_and_drive_ids = checkpoint .retrieved_folder_and_drive_ids
1088
1125
try :
0 commit comments