41 changes: 25 additions & 16 deletions backend/onyx/connectors/confluence/connector.py
@@ -78,6 +78,7 @@ def _should_propagate_error(e: Exception) -> bool:

class ConfluenceCheckpoint(ConnectorCheckpoint):
last_updated: SecondsSinceUnixEpoch
last_seen_doc_ids: list[str]


class ConfluenceConnector(
@@ -108,7 +109,6 @@ def __init__(
self.index_recursively = index_recursively
self.cql_query = cql_query
self.batch_size = batch_size
self.continue_on_failure = continue_on_failure
self.labels_to_skip = labels_to_skip
self.timezone_offset = timezone_offset
self._confluence_client: OnyxConfluence | None = None
@@ -159,6 +159,9 @@ def __init__(
"max_backoff_seconds": 60,
}

# deprecated
self.continue_on_failure = continue_on_failure

def set_allow_images(self, value: bool) -> None:
logger.info(f"Setting allow_images to {value}.")
self.allow_images = value
@@ -417,18 +420,16 @@ def _fetch_page_attachments(
f"Failed to extract/summarize attachment {attachment['title']}",
exc_info=e,
)
if not self.continue_on_failure:
if _should_propagate_error(e):
raise
# TODO: should we remove continue_on_failure entirely now that we have checkpointing?
return ConnectorFailure(
failed_document=DocumentFailure(
document_id=doc.id,
document_link=object_url,
),
failure_message=f"Failed to extract/summarize attachment {attachment['title']} for doc {doc.id}",
exception=e,
)
if _should_propagate_error(e):
raise
return ConnectorFailure(
failed_document=DocumentFailure(
document_id=doc.id,
document_link=object_url,
),
failure_message=f"Failed to extract/summarize attachment {attachment['title']} for doc {doc.id}",
exception=e,
)
return doc

def _fetch_document_batches(
@@ -447,16 +448,23 @@ def _fetch_document_batches(
doc_count = 0

checkpoint = copy.deepcopy(checkpoint)

prev_doc_ids = checkpoint.last_seen_doc_ids
checkpoint.last_seen_doc_ids = []
# use "start" when last_updated is 0
page_query = self._construct_page_query(checkpoint.last_updated or start, end)
logger.debug(f"page_query: {page_query}")

# most requests will include a few already-seen pages to skip, so we fetch up to
# 2 * batch_size per request so that most checkpoint runs need only one request
for page in self.confluence_client.paginated_cql_retrieval(
cql=page_query,
expand=",".join(_PAGE_EXPANSION_FIELDS),
limit=self.batch_size,
limit=2 * self.batch_size,
):
if page["id"] in prev_doc_ids:
# There are a few seconds of fuzziness in the request's time filter,
# so we skip pages that were already yielded on the last run
continue
# Build doc from page
doc_or_failure = self._convert_page_to_document(page)

@@ -477,6 +485,7 @@

# yield completed document
doc_count += 1
checkpoint.last_seen_doc_ids.append(page["id"])
yield doc_or_failure

# create checkpoint after enough documents have been processed
@@ -507,7 +516,7 @@ def load_from_checkpoint(

@override
def build_dummy_checkpoint(self) -> ConfluenceCheckpoint:
return ConfluenceCheckpoint(last_updated=0, has_more=True)
return ConfluenceCheckpoint(last_updated=0, has_more=True, last_seen_doc_ids=[])

@override
def validate_checkpoint_json(self, checkpoint_json: str) -> ConfluenceCheckpoint:
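For context, a minimal sketch of the dedup idea behind the new `last_seen_doc_ids` field, using simplified standalone types rather than the connector's real classes (pages are assumed to be dicts with an "id" string and an "updated" epoch-seconds float): the checkpoint carries the IDs yielded on the previous run, and because the CQL time filter is only accurate to a few seconds, any page whose ID is already in that list is skipped instead of being yielded twice.

```python
from dataclasses import dataclass, field


@dataclass
class SketchCheckpoint:
    # Epoch seconds of the most recently processed page.
    last_updated: float = 0.0
    # IDs yielded on the previous run; the CQL time filter is fuzzy by a few
    # seconds, so these pages may be returned again and must be skipped.
    last_seen_doc_ids: list[str] = field(default_factory=list)


def next_batch(
    checkpoint: SketchCheckpoint, pages: list[dict]
) -> tuple[list[dict], SketchCheckpoint]:
    """Return pages not yet yielded, plus the checkpoint for the next run."""
    prev_ids = set(checkpoint.last_seen_doc_ids)
    new_checkpoint = SketchCheckpoint(last_updated=checkpoint.last_updated)
    batch: list[dict] = []
    for page in pages:
        if page["id"] in prev_ids:
            continue  # already yielded on the previous run
        batch.append(page)
        new_checkpoint.last_seen_doc_ids.append(page["id"])
        new_checkpoint.last_updated = max(new_checkpoint.last_updated, page["updated"])
    return batch, new_checkpoint
```

With this scheme a page updated inside the fuzziness window may be fetched twice but is yielded only once, which is what the updated test below exercises.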
@@ -23,6 +23,9 @@
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import SlimDocument
from tests.unit.onyx.connectors.utils import load_everything_from_checkpoint_connector
from tests.unit.onyx.connectors.utils import (
load_everything_from_checkpoint_connector_from_checkpoint,
)

PAGE_SIZE = 2

@@ -175,6 +178,7 @@ def test_load_from_checkpoint_happy_path(
assert checkpoint_output1.next_checkpoint == ConfluenceCheckpoint(
last_updated=first_updated.timestamp(),
has_more=True,
last_seen_doc_ids=["1", "2"],
)

checkpoint_output2 = outputs[1]
@@ -183,8 +187,7 @@
assert isinstance(document3, Document)
assert document3.id == f"{confluence_connector.wiki_base}/spaces/TEST/pages/3"
assert checkpoint_output2.next_checkpoint == ConfluenceCheckpoint(
last_updated=last_updated.timestamp(),
has_more=False,
last_updated=last_updated.timestamp(), has_more=False, last_seen_doc_ids=["3"]
)


@@ -332,7 +335,8 @@ def test_checkpoint_progress(
confluence_connector: ConfluenceConnector,
create_mock_page: Callable[..., dict[str, Any]],
) -> None:
"""Test that the checkpoint's last_updated field is properly updated after processing pages"""
"""Test that the checkpoint's last_updated field is properly updated after processing pages
and that processed document IDs are stored to avoid reprocessing."""
# Set up mocked pages with different timestamps
earlier_timestamp = datetime(2023, 1, 1, 12, 0, tzinfo=timezone.utc)
later_timestamp = datetime(2023, 1, 2, 12, 0, tzinfo=timezone.utc)
@@ -356,28 +360,61 @@
[], # No more pages
]

# Call load_from_checkpoint
# First run - process both pages
end_time = datetime(2023, 1, 3, tzinfo=timezone.utc).timestamp()

outputs = load_everything_from_checkpoint_connector(
confluence_connector, 0, end_time
)

assert len(outputs) == 2

first_checkpoint = outputs[0].next_checkpoint
last_checkpoint = outputs[-1].next_checkpoint

assert last_checkpoint == ConfluenceCheckpoint(
assert first_checkpoint == ConfluenceCheckpoint(
last_updated=later_timestamp.timestamp(),
has_more=False,
has_more=True,
last_seen_doc_ids=["1", "2"],
)
# Convert the expected timestamp to epoch seconds
expected_timestamp = datetime(2023, 1, 2, 12, 0, tzinfo=timezone.utc).timestamp()

# The checkpoint's last_updated should be set to the latest page's timestamp
assert last_checkpoint.last_updated == expected_timestamp
assert not last_checkpoint.has_more # No more pages to process
# Verify the final checkpoint keeps the latest timestamp with an empty last_seen_doc_ids (the final call returned no pages)
assert last_checkpoint == ConfluenceCheckpoint(
last_updated=later_timestamp.timestamp(), has_more=False, last_seen_doc_ids=[]
)

assert len(outputs) == 2
# Verify we got both documents
assert len(outputs[0].items) == 2
assert isinstance(outputs[0].items[0], Document)
assert outputs[0].items[0].semantic_identifier == "Page 1"
assert isinstance(outputs[0].items[1], Document)
assert outputs[0].items[1].semantic_identifier == "Page 2"

latest_timestamp = datetime(2024, 1, 2, 12, 0, tzinfo=timezone.utc)
mock_page3 = create_mock_page(
id="3", title="Page 3", updated=latest_timestamp.isoformat()
)
# Second run - same time range but with checkpoint from first run
# Reset the mock to return the same pages plus a newer third page
paginated_cql_mock.side_effect = [
[mock_page1, mock_page2, mock_page3], # Return all three pages
[], # No comments for page 1
[], # No attachments for page 1
[], # No comments for page 2
[], # No attachments for page 2
[], # No more pages
]

# Use the checkpoint from first run
outputs_with_checkpoint = load_everything_from_checkpoint_connector_from_checkpoint(
confluence_connector, 0, end_time, first_checkpoint
)

# Verify pages 1 and 2 were skipped (already in last_seen_doc_ids) and only the new page was processed
assert len(outputs_with_checkpoint) == 1
assert len(outputs_with_checkpoint[0].items) == 1
assert isinstance(outputs_with_checkpoint[0].items[0], Document)
assert outputs_with_checkpoint[0].items[0].semantic_identifier == "Page 3"
assert outputs_with_checkpoint[0].next_checkpoint == ConfluenceCheckpoint(
last_updated=latest_timestamp.timestamp(),
has_more=False,
last_seen_doc_ids=["3"],
)
13 changes: 12 additions & 1 deletion backend/tests/unit/onyx/connectors/utils.py
@@ -27,9 +27,20 @@ def load_everything_from_checkpoint_connector(
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
) -> list[SingleConnectorCallOutput[CT]]:
num_iterations = 0

checkpoint = cast(CT, connector.build_dummy_checkpoint())
return load_everything_from_checkpoint_connector_from_checkpoint(
connector, start, end, checkpoint
)


def load_everything_from_checkpoint_connector_from_checkpoint(
connector: CheckpointedConnector[CT],
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: CT,
) -> list[SingleConnectorCallOutput[CT]]:
num_iterations = 0
outputs: list[SingleConnectorCallOutput[CT]] = []
while checkpoint.has_more:
items: list[Document | ConnectorFailure] = []
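As a usage sketch of the refactored helpers (assuming a `connector`, `start`, and `end` set up as in the tests above), the original entry point simply builds a dummy checkpoint and delegates, while the new variant lets a test resume from a checkpoint captured on an earlier pass:

```python
# First pass: start from a dummy checkpoint and keep the checkpoint it produced.
outputs = load_everything_from_checkpoint_connector(connector, start, end)
resume_checkpoint = outputs[0].next_checkpoint

# Second pass: resume from that checkpoint; documents recorded in
# last_seen_doc_ids should not be yielded again.
resumed_outputs = load_everything_from_checkpoint_connector_from_checkpoint(
    connector, start, end, resume_checkpoint
)
```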