|
import io
+import os
+import random
+import time
from collections.abc import Callable
from datetime import datetime
from typing import Any
from typing import cast

+import openai
from googleapiclient.errors import HttpError  # type: ignore
from googleapiclient.http import MediaIoBaseDownload  # type: ignore
from pydantic import BaseModel
|
|

logger = setup_logger()

+
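+# Schema for the structured classification response returned by the model.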
+class DocumentClassificationResult(BaseModel):
+    categories: list[str]
+    entities: list[str]
+
+
# This is not a standard valid unicode char, it is used by the docs advanced API to
# represent smart chips (elements like dates and doc links).
SMART_CHIP_CHAR = "\ue907"
@@ -406,6 +414,128 @@ def convert_drive_item_to_document(
    return first_error


+def _extract_categories_and_entities(
+    sections: list[TextSection | ImageSection],
+) -> dict[str, list[str]]:
+    """Extract categories and entities from document sections with retry logic.
+
+    Returns empty lists for both keys if extraction fails or no text is available.
+    """
+    prompt = """
+    Analyze this document, classify it with categories, and extract important entities.
+
+    CATEGORIES:
+    Create up to 5 simple categories that best capture what this document is about. Consider categories within:
+    - Document type (e.g., Manual, Report, Email, Transcript)
+    - Content domain (e.g., Technical, Financial, HR, Marketing)
+    - Purpose (e.g., Training, Reference, Announcement, Analysis)
+    - Industry/Topic area (e.g., Software Development, Sales, Legal)
+
+    Be creative and specific. Use clear, descriptive terms that someone searching for this document might use.
+    Categories should be up to 2 words each.
+
+    ENTITIES:
+    Extract up to 5 important proper nouns, such as:
+    - Company names (e.g., Microsoft, Google, Acme Corp)
+    - Product names (e.g., Office 365, Salesforce, iPhone)
+    - People's names (e.g., John, Jane, Ahmed, Wenjie)
+    - Department names (e.g., Engineering, Marketing, HR)
+    - Project names (e.g., Project Alpha, Migration 2024)
+    - Technology names (e.g., PostgreSQL, React, AWS)
+    - Location names (e.g., New York Office, Building A)
+    """
+
+    # Retry configuration
+    max_retries = 3
+    base_delay = 1.0  # seconds
+    backoff_factor = 2.0
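+    # With these settings, retry delays are roughly 1s, 2s, then 4s (plus jitter).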
+
+    for attempt in range(max_retries + 1):
+        try:
+            api_key = os.getenv("OPENAI_API_KEY")
+            if not api_key:
+                logger.warning("OPENAI_API_KEY not set, skipping metadata extraction")
+                return {"categories": [], "entities": []}
+
+            client = openai.OpenAI(api_key=api_key)
+
+            # Combine all section text
+            document_text = "\n\n".join(
+                [
+                    section.text
+                    for section in sections
+                    if isinstance(section, TextSection) and section.text.strip()
+                ]
+            )
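+            # ImageSection objects are skipped here; only text is sent for classification.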
+
+            # Skip if no text content
+            if not document_text.strip():
+                logger.debug("No text content found, skipping metadata extraction")
+                return {"categories": [], "entities": []}
+
+            # Truncate very long documents to avoid token limits
+            max_chars = 50000  # Roughly 12k tokens
+            if len(document_text) > max_chars:
+                document_text = document_text[:max_chars] + "..."
+                logger.debug(f"Truncated document text to {max_chars} characters")
+
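+            # Ask the Responses API to parse the output directly into
+            # DocumentClassificationResult via structured outputs.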
+            response = client.responses.parse(
+                model="o3",
+                input=[
+                    {
+                        "role": "system",
+                        "content": "Extract categories and entities from the document.",
+                    },
+                    {
+                        "role": "user",
+                        "content": prompt + "\n\nDOCUMENT: " + document_text,
+                    },
+                ],
+                text_format=DocumentClassificationResult,
+            )
+
+            classification_result = response.output_parsed
+            if classification_result is None:
+                # Treat an unparseable response as a failure so the retry logic applies
+                raise ValueError("Model response could not be parsed")
+
+            result = {
+                "categories": classification_result.categories,
+                "entities": classification_result.entities,
+            }
+
+            logger.debug(f"Successfully extracted metadata: {result}")
+            return result
+
+        except Exception as e:
+            attempt_num = attempt + 1
+            is_last_attempt = attempt == max_retries
+
+            # Log the error
+            if is_last_attempt:
+                logger.error(
+                    f"Failed to extract categories and entities after {max_retries + 1} attempts: {e}"
+                )
+            else:
+                logger.warning(
+                    f"Attempt {attempt_num} failed to extract metadata: {e}. Retrying..."
+                )
+
+            # If this is the last attempt, return empty results
+            if is_last_attempt:
+                return {"categories": [], "entities": []}
+
+            # Calculate delay with exponential backoff and jitter
+            delay = base_delay * (backoff_factor**attempt)
+            jitter = random.uniform(0.1, 0.3)  # add 0.1-0.3 seconds of jitter
+            total_delay = delay + jitter
+
+            logger.debug(
+                f"Waiting {total_delay:.2f} seconds before attempt {attempt_num + 1}"
+            )
+            time.sleep(total_delay)
+
+    # Should never reach here, but just in case
+    return {"categories": [], "entities": []}
+
+
def _convert_drive_item_to_document(
    creds: Any,
    allow_images: bool,
@@ -499,17 +629,23 @@ def _get_docs_service() -> GoogleDocsService:
        else None
    )

+    # Extract categories and entities from the drive item and store them in metadata
+    categories_and_entities = _extract_categories_and_entities(sections)
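+    # On failure this returns empty lists, so indexing proceeds without the extra metadata.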
+    metadata = {
+        "owner_names": ", ".join(
+            owner.get("displayName", "") for owner in file.get("owners", [])
+        ),
+        "categories": categories_and_entities.get("categories", []),
+        "entities": categories_and_entities.get("entities", []),
+    }
+
    # Create the document
    return Document(
        id=doc_id,
        sections=sections,
        source=DocumentSource.GOOGLE_DRIVE,
        semantic_identifier=file.get("name", ""),
-        metadata={
-            "owner_names": ", ".join(
-                owner.get("displayName", "") for owner in file.get("owners", [])
-            ),
-        },
+        metadata=metadata,
        doc_updated_at=datetime.fromisoformat(
            file.get("modifiedTime", "").replace("Z", "+00:00")
        ),
|
|