diff --git a/backend/onyx/connectors/github/connector.py b/backend/onyx/connectors/github/connector.py
index 34cc703c605..5eec620ba4d 100644
--- a/backend/onyx/connectors/github/connector.py
+++ b/backend/onyx/connectors/github/connector.py
@@ -12,6 +12,7 @@
 from github import Github
 from github import RateLimitExceededException
 from github import Repository
+from github.ContentFile import ContentFile
 from github.GithubException import GithubException
 from github.Issue import Issue
 from github.NamedUser import NamedUser
@@ -343,6 +344,20 @@ def _convert_issue_to_document(issue: Issue) -> Document:
     )
 
 
+def _convert_file_to_document(file: ContentFile) -> Document:
+    return Document(
+        id=file.html_url,
+        sections=[TextSection(link=file.html_url, text=file.content or "")],
+        source=DocumentSource.GITHUB,
+        semantic_identifier=f"{file.repository.full_name}/{file.path}" if file.repository else file.path,
+        metadata={
+            "object_type": "File",
+            "repo": file.repository.full_name if file.repository else "",
+            "path": file.path,
+        },
+    )
+
+
 class SerializedRepository(BaseModel):
     # id is part of the raw_data as well, just pulled out for convenience
     id: int
@@ -359,6 +374,7 @@ class GithubConnectorStage(Enum):
     START = "start"
     PRS = "prs"
     ISSUES = "issues"
+    FILES_MD = "files_md"
 
 
 class GithubConnectorCheckpoint(ConnectorCheckpoint):
@@ -402,12 +418,14 @@ def __init__(
         state_filter: str = "all",
         include_prs: bool = True,
         include_issues: bool = False,
+        include_files_md: bool = False,
     ) -> None:
         self.repo_owner = repo_owner
         self.repositories = repositories
         self.state_filter = state_filter
         self.include_prs = include_prs
         self.include_issues = include_issues
+        self.include_files_md = include_files_md
         self.github_client: Github | None = None
 
     def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
@@ -504,6 +522,45 @@ def _issues_func(
             state=self.state_filter, sort="updated", direction="desc"
         )
 
+    def _files_md_func(self, repo: Repository.Repository) -> list[ContentFile]:
+        """Recursively collect every Markdown (.md) file in the repo's default branch."""
+        github_client = cast(Github, self.github_client)
+
+        def _get_contents_rate_limited(
+            github_client: Github, path: str, attempt_num: int = 0
+        ) -> list[ContentFile]:
+            # Always returns a list: a lone ContentFile result is wrapped, and
+            # inaccessible paths are logged and skipped rather than aborting the run.
+            if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
+                raise RuntimeError(
+                    "Re-tried fetching contents too many times. Something is going wrong with fetching objects from Github"
+                )
+            try:
+                contents = repo.get_contents(path)
+                if isinstance(contents, ContentFile):
+                    return [contents]
+                return cast(list[ContentFile], contents)
+            except RateLimitExceededException:
+                _sleep_after_rate_limit_exception(github_client)
+                return _get_contents_rate_limited(github_client, path, attempt_num + 1)
+            except GithubException as e:
+                logger.error(f"Error accessing directory {path}: {e}")
+                return []
+
+        md_files: list[ContentFile] = []
+        # BFS over the repo tree; _get_contents_rate_limited already normalizes
+        # its result to a list, so no re-wrapping is needed here or below.
+        contents = _get_contents_rate_limited(github_client, "")
+
+        while contents:
+            file = contents.pop(0)
+            if file.type == "dir":
+                # Rate-limit retries are handled inside _get_contents_rate_limited.
+                contents.extend(_get_contents_rate_limited(github_client, file.path))
+            elif file.type == "file" and file.name.endswith(".md"):
+                md_files.append(file)
+
+        return md_files
+
 def _fetch_from_github(
         self,
         checkpoint: GithubConnectorCheckpoint,
@@ -587,6 +644,8 @@ def _fetch_from_github(
             for pr in pr_batch:
                 num_prs += 1
 
+                pr = cast(PullRequest, pr)
+
                 # we iterate backwards in time, so at this point we stop processing prs
                 if (
                     start is not None
@@ -603,7 +662,7 @@ def _fetch_from_github(
                 ):
                     continue
                 try:
-                    yield _convert_pr_to_document(cast(PullRequest, pr))
+                    yield _convert_pr_to_document(pr)
                 except Exception as e:
                     error_msg = f"Error converting PR to document: {e}"
                     logger.exception(error_msg)
@@ -659,14 +718,14 @@ def _fetch_from_github(
             for issue in issue_batch:
                 num_issues += 1
                 issue = cast(Issue, issue)
-                # we iterate backwards in time, so at this point we stop processing prs
+                # we iterate backwards in time, so at this point we stop processing Issues
                 if (
                     start is not None
                     and issue.updated_at.replace(tzinfo=timezone.utc) < start
                 ):
                     done_with_issues = True
                     break
-                # Skip PRs updated after the end date
+                # Skip Issues updated after the end date
                 if (
                     end is not None
                     and issue.updated_at.replace(tzinfo=timezone.utc) > end
@@ -700,9 +759,39 @@ def _fetch_from_github(
             # if we went past the start date during the loop or there are no more
             # issues to get, we move on to the next repo
+            checkpoint.stage = GithubConnectorStage.FILES_MD
+            checkpoint.reset()
+
+        if self.include_files_md and checkpoint.stage == GithubConnectorStage.FILES_MD:
+            logger.info(f"Fetching Markdown files for repo: {repo.name}")
+
+            md_files = self._files_md_func(repo)
+
+            checkpoint.curr_page += 1
+            num_files_md = 0
+            # NOTE(review): all markdown files are emitted in a single pass; this
+            # stage re-fetches them if the run is resumed mid-repo — confirm OK.
+            for file in md_files:
+                num_files_md += 1
+                try:
+                    yield _convert_file_to_document(file)
+                except Exception as e:
+                    error_msg = f"Error converting Markdown file to document: {e}"
+                    logger.exception(error_msg)
+                    yield ConnectorFailure(
+                        failed_document=DocumentFailure(
+                            document_id=str(file.html_url),
+                            document_link=file.html_url,
+                        ),
+                        failure_message=error_msg,
+                        exception=e,
+                    )
+
+            logger.info(f"Fetched {num_files_md} Markdown files for repo: {repo.name}")
+
         checkpoint.stage = GithubConnectorStage.PRS
         checkpoint.reset()
 
+        checkpoint.has_more = len(checkpoint.cached_repo_ids) > 0
         if checkpoint.cached_repo_ids:
             next_id = checkpoint.cached_repo_ids.pop()
diff --git a/backend/tests/daily/connectors/github/test_github_basic.py b/backend/tests/daily/connectors/github/test_github_basic.py
index 75ad30a1ca8..a048f3a2493 100644
--- a/backend/tests/daily/connectors/github/test_github_basic.py
+++ b/backend/tests/daily/connectors/github/test_github_basic.py
@@ -5,6 +5,7 @@
 from onyx.configs.constants import DocumentSource
 from onyx.connectors.github.connector import GithubConnector
+from onyx.connectors.models import Document
 from tests.daily.connectors.utils import load_all_docs_from_checkpoint_connector
 
 
@@ -15,6 +16,7 @@ def github_connector() -> GithubConnector:
         repositories="documentation",
         include_prs=True,
         include_issues=True,
+        include_files_md=True,
     )
     connector.load_credentials(
         {
@@ -32,9 +34,18 @@ def test_github_connector_basic(github_connector: GithubConnector) -> None:
     )
     assert len(docs) > 1  # We expect at least one PR and one Issue to exist
 
+    def get_doc_by_type(docs: list[Document], object_type: str) -> Document | None:
+        # Document ordering is an implementation detail of the connector, so
+        # look documents up by object_type instead of relying on position.
+        for doc in docs:
+            if doc.metadata.get("object_type") == object_type:
+                return doc
+        return None
+
     # Test the first document's structure
     pr_doc = docs[0]
-    issue_doc = docs[-1]
+    issue_doc = get_doc_by_type(docs, "Issue")
+    file_doc = get_doc_by_type(docs, "File")
 
     # Verify basic document properties
     assert pr_doc.source == DocumentSource.GITHUB
@@ -60,6 +71,7 @@ def test_github_connector_basic(github_connector: GithubConnector) -> None:
     assert "created_at" in pr_doc.metadata
 
     # Verify Issue-specific properties
+    assert issue_doc is not None, "Issue document not found"
     assert issue_doc.metadata is not None
     assert issue_doc.metadata.get("object_type") == "Issue"
     assert "id" in issue_doc.metadata
@@ -70,6 +82,13 @@ def test_github_connector_basic(github_connector: GithubConnector) -> None:
     assert "labels" in issue_doc.metadata
     assert "created_at" in issue_doc.metadata
 
+    # Verify File-specific properties
+    assert file_doc is not None, "File document not found"
+    assert file_doc.metadata is not None
+    assert file_doc.metadata.get("object_type") == "File"
+    assert "repo" in file_doc.metadata
+    assert "path" in file_doc.metadata
+
     # Verify sections
     assert len(pr_doc.sections) == 1
     section = pr_doc.sections[0]
diff --git a/backend/tests/unit/onyx/connectors/github/test_github_checkpointing.py b/backend/tests/unit/onyx/connectors/github/test_github_checkpointing.py
index e79f8f89a7a..5b9404b23f4 100644
--- a/backend/tests/unit/onyx/connectors/github/test_github_checkpointing.py
+++ b/backend/tests/unit/onyx/connectors/github/test_github_checkpointing.py
@@ -10,6 +10,7 @@
 import pytest
 from github import Github
 from github import RateLimitExceededException
+from github.ContentFile import ContentFile
 from github.GithubException import GithubException
 from github.Issue import Issue
 from github.PaginatedList import PaginatedList
@@ -65,6 +66,7 @@ def _github_connector(
         repositories=repositories,
         include_prs=True,
         include_issues=True,
+        include_files_md=True,
     )
     connector.github_client = mock_github_client
     return connector
@@ -126,6 +128,26 @@ def _create_mock_issue(
     return _create_mock_issue
 
 
+@pytest.fixture
+def create_mock_file() -> Callable[..., MagicMock]:
+    def _create_mock_file(
+        name: str = "README.md",
+        html_url: str = "https://github.com/onyx-dot-app/onyx/blob/main/README.md",
+        content: str = "# README",
+        path: str = "README.md",
+        type: str = "file",
+    ) -> MagicMock:
+        mock_file = MagicMock(spec=ContentFile)
+        mock_file.name = name
+        mock_file.html_url = html_url
+        mock_file.content = content
+        mock_file.path = path
+        mock_file.type = type
+        return mock_file
+
+    return _create_mock_file
+
+
 @pytest.fixture
 def create_mock_repo() -> Callable[..., MagicMock]:
     def _create_mock_repo(
@@ -162,6 +184,7 @@ def test_load_from_checkpoint_happy_path(
     create_mock_repo: Callable[..., MagicMock],
     create_mock_pr: Callable[..., MagicMock],
     create_mock_issue: Callable[..., MagicMock],
+    create_mock_file: Callable[..., MagicMock],
 ) -> None:
     """Test loading from checkpoint - happy path"""
     # Set up mocked repo
@@ -170,11 +193,25 @@ def test_load_from_checkpoint_happy_path(
     github_connector.github_client = mock_github_client
     mock_github_client.get_repo.return_value = mock_repo
 
-    # Set up mocked PRs and issues
+    # Set up mocked PRs, issues, and markdown files
    mock_pr1 = create_mock_pr(number=1, title="PR 1")
     mock_pr2 = create_mock_pr(number=2, title="PR 2")
     mock_issue1 = create_mock_issue(number=1, title="Issue 1")
     mock_issue2 = create_mock_issue(number=2, title="Issue 2")
+    mock_file1 = create_mock_file(
+        name="README.md",
+        html_url="https://github.com/test-org/test-repo/blob/main/README.md",
+        content="# README",
+        path="README.md",
+        type="file",
+    )
+    mock_file2 = create_mock_file(
+        name="CONTRIBUTING.md",
+        html_url="https://github.com/test-org/test-repo/blob/main/CONTRIBUTING.md",
+        content="# CONTRIBUTING.md",
+        path="CONTRIBUTING.md",
+        type="file",
+    )
 
     # Mock get_pulls and get_issues methods
     mock_repo.get_pulls.return_value = MagicMock()
@@ -187,6 +224,7 @@ def test_load_from_checkpoint_happy_path(
         [mock_issue1, mock_issue2],
         [],
     ]
+    mock_repo.get_contents.return_value = [mock_file1, mock_file2]
 
     # Mock SerializedRepository.to_Repository to return our mock repo
     with patch.object(SerializedRepository, "to_Repository", return_value=mock_repo):
@@ -225,10 +263,23 @@ def test_load_from_checkpoint_happy_path(
     )
     assert second_batch.next_checkpoint.has_more
 
-    # Check third batch (finished checkpoint)
+    # Check third batch (Markdown files)
     third_batch = outputs[3]
-    assert len(third_batch.items) == 0
-    assert third_batch.next_checkpoint.has_more is False
+    assert len(third_batch.items) == 2
+    assert isinstance(third_batch.items[0], Document)
+    assert (
+        third_batch.items[0].id
+        == "https://github.com/test-org/test-repo/blob/main/README.md"
+    )
+    assert isinstance(third_batch.items[1], Document)
+    assert (
+        third_batch.items[1].id
+        == "https://github.com/test-org/test-repo/blob/main/CONTRIBUTING.md"
+    )
+
+    # Check final batch (finished checkpoint)
+    final_batch = outputs[-1]
+    assert final_batch.next_checkpoint.has_more is False
 
 
 def test_load_from_checkpoint_with_rate_limit(
@@ -294,11 +345,12 @@ def test_load_from_checkpoint_with_empty_repo(
     github_connector.github_client = mock_github_client
     mock_github_client.get_repo.return_value = mock_repo
 
-    # Mock get_pulls and get_issues to return empty lists
+    # Mock get_pulls, get_issues, and get_contents to return empty lists
     mock_repo.get_pulls.return_value = MagicMock()
     mock_repo.get_pulls.return_value.get_page.return_value = []
     mock_repo.get_issues.return_value = MagicMock()
     mock_repo.get_issues.return_value.get_page.return_value = []
+    mock_repo.get_contents.return_value = []
 
     # Mock SerializedRepository.to_Repository to return our mock repo
     with patch.object(SerializedRepository, "to_Repository", return_value=mock_repo):
@@ -309,8 +361,8 @@ def test_load_from_checkpoint_with_empty_repo(
     )
 
     # Check that we got no documents
-    assert len(outputs) == 2
-    assert len(outputs[-1].items) == 0
+    assert len(outputs) == 3
+    assert len(outputs[1].items) == 0
     assert not outputs[-1].next_checkpoint.has_more
 
 
@@ -886,7 +938,7 @@ def to_repository_side_effect(
     assert cp4.cached_repo is not None
     assert cp4.cached_repo.id == mock_repo1.id  # Last processed repo
     assert (
-        cp4.stage == GithubConnectorStage.PRS
+        cp4.stage == GithubConnectorStage.FILES_MD
     )  # Reset for a hypothetical next run/repo
     assert cp4.curr_page == 0
     assert cp4.num_retrieved == 0
diff --git a/web/src/lib/connectors/connectors.tsx b/web/src/lib/connectors/connectors.tsx
index 04ca5074a9e..9f624d62335 100644
--- a/web/src/lib/connectors/connectors.tsx
+++ b/web/src/lib/connectors/connectors.tsx
@@ -242,6 +242,14 @@ export const connectorConfigs: Record<
         description: "Index issues from repositories",
         optional: true,
       },
+      {
+        type: "checkbox",
+        query: "Include Markdown files?",
+        label: "Include Markdown files?",
+        name: "include_files_md",
+        description: "Index Markdown files from repositories",
+        optional: true,
+      },
     ],
     advanced_values: [],
   },
@@ -1492,6 +1500,7 @@ export interface GithubConfig {
   repositories: string; // Comma-separated list of repository names
   include_prs: boolean;
   include_issues: boolean;
+  include_files_md?: boolean;
 }
 
 export interface GitlabConfig {