Skip to content
116 changes: 113 additions & 3 deletions backend/onyx/connectors/github/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from github import Github
from github import RateLimitExceededException
from github import Repository
from github.ContentFile import ContentFile
from github.GithubException import GithubException
from github.Issue import Issue
from github.NamedUser import NamedUser
Expand Down Expand Up @@ -343,6 +344,20 @@ def _convert_issue_to_document(issue: Issue) -> Document:
)


def _convert_file_to_document(file: ContentFile) -> Document:
    """Convert a GitHub repository file (ContentFile) into an indexable Document.

    Note: ContentFile.content is base64-encoded, and is None for files larger
    than 1MB (the API returns encoding "none" for those). We therefore use
    decoded_content, which performs the base64 decode, and fall back to empty
    text when the content is unavailable or not valid UTF-8 (e.g. binary files).
    """
    try:
        # decoded_content raises AssertionError when the encoding is not
        # base64 (e.g. large files fetched with encoding "none"), and
        # .decode raises UnicodeDecodeError for non-text payloads.
        file_text = file.decoded_content.decode("utf-8")
    except (AssertionError, AttributeError, UnicodeDecodeError):
        file_text = ""
    return Document(
        id=file.html_url,
        sections=[TextSection(link=file.html_url, text=file_text)],
        source=DocumentSource.GITHUB,
        # guard repository access consistently with the metadata below
        semantic_identifier=(
            f"{file.repository.full_name}/{file.path}" if file.repository else file.path
        ),
        metadata={
            "object_type": "File",
            "repo": file.repository.full_name if file.repository else "",
            "path": file.path,
        },
    )


class SerializedRepository(BaseModel):
# id is part of the raw_data as well, just pulled out for convenience
id: int
Expand All @@ -359,6 +374,7 @@ class GithubConnectorStage(Enum):
START = "start"
PRS = "prs"
ISSUES = "issues"
FILES_MD = "files_md"


class GithubConnectorCheckpoint(ConnectorCheckpoint):
Expand Down Expand Up @@ -402,12 +418,14 @@ def __init__(
state_filter: str = "all",
include_prs: bool = True,
include_issues: bool = False,
include_files_md: bool = False,
) -> None:
self.repo_owner = repo_owner
self.repositories = repositories
self.state_filter = state_filter
self.include_prs = include_prs
self.include_issues = include_issues
self.include_files_md = include_files_md
self.github_client: Github | None = None

def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
Expand Down Expand Up @@ -504,6 +522,61 @@ def _issues_func(
state=self.state_filter, sort="updated", direction="desc"
)

def _files_md_func(self, repo: Repository.Repository) -> list[ContentFile]:
    """Collect every Markdown (.md) file in the repository.

    Performs a breadth-first traversal of the repo's directory tree starting
    at the root, retrying on GitHub rate limits and skipping directories that
    error out.
    """
    from collections import deque

    github_client = cast(Github, self.github_client)

    def _get_contents_rate_limited(
        github_client: Github, path: str, attempt_num: int = 0
    ) -> list[ContentFile]:
        # Fetch the contents at `path`, always returning a list:
        # repo.get_contents returns a single ContentFile for a file path
        # and a list for a directory, so we normalize here. Rate limits are
        # handled by sleeping and retrying up to _MAX_NUM_RATE_LIMIT_RETRIES.
        if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
            raise RuntimeError(
                "Re-tried fetching contents too many times. Something is going wrong with fetching objects from Github"
            )
        try:
            contents = repo.get_contents(path)
            if isinstance(contents, ContentFile):
                return [contents]
            return cast(list[ContentFile], contents)
        except RateLimitExceededException:
            _sleep_after_rate_limit_exception(github_client)
            return _get_contents_rate_limited(github_client, path, attempt_num + 1)
        except GithubException as e:
            logger.error(f"Error accessing directory {path}: {e}")
            return []

    md_files: list[ContentFile] = []
    # BFS queue seeded with the repo root. deque gives O(1) popleft, unlike
    # list.pop(0) which is O(n) per pop. _get_contents_rate_limited already
    # normalizes its result to a list and handles rate limiting internally,
    # so no per-item isinstance checks or RateLimitExceededException
    # handling are needed in the loop.
    queue: deque[ContentFile] = deque(_get_contents_rate_limited(github_client, ""))

    while queue:
        file = queue.popleft()
        if file.type == "dir":
            # descend into the directory by enqueueing its contents
            queue.extend(_get_contents_rate_limited(github_client, file.path))
        elif file.type == "file" and file.name.endswith(".md"):
            md_files.append(file)

    return md_files

def _fetch_from_github(
self,
checkpoint: GithubConnectorCheckpoint,
Expand Down Expand Up @@ -587,6 +660,8 @@ def _fetch_from_github(
for pr in pr_batch:
num_prs += 1

pr = cast(PullRequest, pr)

# we iterate backwards in time, so at this point we stop processing prs
if (
start is not None
Expand All @@ -603,7 +678,7 @@ def _fetch_from_github(
):
continue
try:
yield _convert_pr_to_document(cast(PullRequest, pr))
yield _convert_pr_to_document(pr)
except Exception as e:
error_msg = f"Error converting PR to document: {e}"
logger.exception(error_msg)
Expand Down Expand Up @@ -659,14 +734,14 @@ def _fetch_from_github(
for issue in issue_batch:
num_issues += 1
issue = cast(Issue, issue)
# we iterate backwards in time, so at this point we stop processing prs
# we iterate backwards in time, so at this point we stop processing Issues
if (
start is not None
and issue.updated_at.replace(tzinfo=timezone.utc) < start
):
done_with_issues = True
break
# Skip PRs updated after the end date
# Skip Issues updated after the end date
if (
end is not None
and issue.updated_at.replace(tzinfo=timezone.utc) > end
Expand Down Expand Up @@ -700,9 +775,44 @@ def _fetch_from_github(

# if we went past the start date during the loop or there are no more
# issues to get, we move on to the next repo
checkpoint.stage = GithubConnectorStage.FILES_MD
checkpoint.reset()

checkpoint.stage = GithubConnectorStage.FILES_MD

if self.include_files_md and checkpoint.stage == GithubConnectorStage.FILES_MD:
logger.info(f"Fetching Markdown files for repo: {repo.name}")

md_files = self._files_md_func(repo)

checkpoint.curr_page += 1
num_files_md = 0
for file in md_files:
num_files_md += 1
file = cast(ContentFile, file)
try:
yield _convert_file_to_document(file)
except Exception as e:
error_msg = f"Error converting Markdown file to document: {e}"
logger.exception(error_msg)
yield ConnectorFailure(
failed_document=DocumentFailure(
document_id=str(file.html_url),
document_link=file.html_url,
),
failure_message=error_msg,
exception=e,
)

continue

logger.info(f"Fetched {num_files_md} Markdown files for repo: {repo.name}")
logger.info(f"Fetched {num_files_md} Markdown files for repo: {repo.name}")
checkpoint.stage = GithubConnectorStage.PRS
checkpoint.reset()

checkpoint.has_more = len(checkpoint.cached_repo_ids) > 0

checkpoint.has_more = len(checkpoint.cached_repo_ids) > 0
if checkpoint.cached_repo_ids:
next_id = checkpoint.cached_repo_ids.pop()
Expand Down
18 changes: 17 additions & 1 deletion backend/tests/daily/connectors/github/test_github_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from onyx.configs.constants import DocumentSource
from onyx.connectors.github.connector import GithubConnector
from onyx.connectors.models import Document
from tests.daily.connectors.utils import load_all_docs_from_checkpoint_connector


Expand All @@ -15,6 +16,7 @@ def github_connector() -> GithubConnector:
repositories="documentation",
include_prs=True,
include_issues=True,
include_files_md=True,
)
connector.load_credentials(
{
Expand All @@ -32,9 +34,16 @@ def test_github_connector_basic(github_connector: GithubConnector) -> None:
)
assert len(docs) > 1 # We expect at least one PR and one Issue to exist

def get_issue_doc(docs: list[Document]) -> Document | None:
    """Return the first document tagged as a GitHub Issue, or None.

    Identifies Issue documents by the "object_type" metadata field set by
    the connector; returns None when no Issue document is present in docs.
    """
    for doc in docs:
        if doc.metadata.get("object_type") == "Issue":
            return doc
    return None

# Test the first document's structure
pr_doc = docs[0]
issue_doc = docs[-1]
issue_doc = get_issue_doc(docs)
file_doc = docs[-1]

# Verify basic document properties
assert pr_doc.source == DocumentSource.GITHUB
Expand All @@ -60,6 +69,7 @@ def test_github_connector_basic(github_connector: GithubConnector) -> None:
assert "created_at" in pr_doc.metadata

# Verify Issue-specific properties
assert issue_doc is not None, "Issue document not found"
assert issue_doc.metadata is not None
assert issue_doc.metadata.get("object_type") == "Issue"
assert "id" in issue_doc.metadata
Expand All @@ -70,6 +80,12 @@ def test_github_connector_basic(github_connector: GithubConnector) -> None:
assert "labels" in issue_doc.metadata
assert "created_at" in issue_doc.metadata

# Verify File-specific properties
assert file_doc.metadata is not None
assert file_doc.metadata.get("object_type") == "File"
assert "repo" in file_doc.metadata
assert "path" in file_doc.metadata

# Verify sections
assert len(pr_doc.sections) == 1
section = pr_doc.sections[0]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import pytest
from github import Github
from github import RateLimitExceededException
from github.ContentFile import ContentFile
from github.GithubException import GithubException
from github.Issue import Issue
from github.PaginatedList import PaginatedList
Expand Down Expand Up @@ -65,6 +66,7 @@ def _github_connector(
repositories=repositories,
include_prs=True,
include_issues=True,
include_files_md=True,
)
connector.github_client = mock_github_client
return connector
Expand Down Expand Up @@ -126,6 +128,26 @@ def _create_mock_issue(
return _create_mock_issue


@pytest.fixture
def create_mock_file() -> Callable[..., MagicMock]:
    """Factory fixture producing MagicMock stand-ins for Github ContentFile objects."""

    def _create_mock_file(
        name: str = "README.md",
        html_url: str = "https://github.com/onyx-dot-app/onyx/blob/main/README.md",
        content: str = "# README",
        path: str = "README.md",
        type: str = "file",
    ) -> MagicMock:
        mock_file = MagicMock(spec=ContentFile)
        # Set the attributes the connector reads off a ContentFile.
        # `name` must be assigned after construction (via setattr) because
        # passing name= to the MagicMock constructor has special meaning.
        for attr, value in (
            ("name", name),
            ("html_url", html_url),
            ("content", content),
            ("path", path),
            ("type", type),
        ):
            setattr(mock_file, attr, value)
        return mock_file

    return _create_mock_file


@pytest.fixture
def create_mock_repo() -> Callable[..., MagicMock]:
def _create_mock_repo(
Expand Down Expand Up @@ -162,6 +184,7 @@ def test_load_from_checkpoint_happy_path(
create_mock_repo: Callable[..., MagicMock],
create_mock_pr: Callable[..., MagicMock],
create_mock_issue: Callable[..., MagicMock],
create_mock_file: Callable[..., MagicMock],
) -> None:
"""Test loading from checkpoint - happy path"""
# Set up mocked repo
Expand All @@ -170,11 +193,25 @@ def test_load_from_checkpoint_happy_path(
github_connector.github_client = mock_github_client
mock_github_client.get_repo.return_value = mock_repo

# Set up mocked PRs and issues
# Set up mocked PRs, issues, and markdown files
mock_pr1 = create_mock_pr(number=1, title="PR 1")
mock_pr2 = create_mock_pr(number=2, title="PR 2")
mock_issue1 = create_mock_issue(number=1, title="Issue 1")
mock_issue2 = create_mock_issue(number=2, title="Issue 2")
mock_file1 = create_mock_file(
name="README.md",
html_url="https://github.com/test-org/test-repo/blob/main/README.md",
content="# README",
path="README.md",
type="file",
)
mock_file2 = create_mock_file(
name="CONTRIBUTING.md",
html_url="https://github.com/test-org/test-repo/blob/main/CONTRIBUTING.md",
content="# CONTRIBUTING.md",
path="CONTRIBUTING.md",
type="file",
)

# Mock get_pulls and get_issues methods
mock_repo.get_pulls.return_value = MagicMock()
Expand All @@ -187,6 +224,7 @@ def test_load_from_checkpoint_happy_path(
[mock_issue1, mock_issue2],
[],
]
mock_repo.get_contents.return_value = [mock_file1, mock_file2]

# Mock SerializedRepository.to_Repository to return our mock repo
with patch.object(SerializedRepository, "to_Repository", return_value=mock_repo):
Expand Down Expand Up @@ -225,10 +263,23 @@ def test_load_from_checkpoint_happy_path(
)
assert second_batch.next_checkpoint.has_more

# Check third batch (finished checkpoint)
# Check third batch (Markdown files)
third_batch = outputs[3]
assert len(third_batch.items) == 0
assert third_batch.next_checkpoint.has_more is False
assert len(third_batch.items) == 2
assert isinstance(third_batch.items[0], Document)
assert (
third_batch.items[0].id
== "https://github.com/test-org/test-repo/blob/main/README.md"
)
assert isinstance(third_batch.items[1], Document)
assert (
third_batch.items[1].id
== "https://github.com/test-org/test-repo/blob/main/CONTRIBUTING.md"
)

# Check final batch (finished checkpoint)
final_batch = outputs[-1]
assert final_batch.next_checkpoint.has_more is False


def test_load_from_checkpoint_with_rate_limit(
Expand Down Expand Up @@ -294,11 +345,12 @@ def test_load_from_checkpoint_with_empty_repo(
github_connector.github_client = mock_github_client
mock_github_client.get_repo.return_value = mock_repo

# Mock get_pulls and get_issues to return empty lists
# Mock get_pulls, get_issues, and get_contents to return empty lists
mock_repo.get_pulls.return_value = MagicMock()
mock_repo.get_pulls.return_value.get_page.return_value = []
mock_repo.get_issues.return_value = MagicMock()
mock_repo.get_issues.return_value.get_page.return_value = []
mock_repo.get_contents.return_value = []

# Mock SerializedRepository.to_Repository to return our mock repo
with patch.object(SerializedRepository, "to_Repository", return_value=mock_repo):
Expand All @@ -309,8 +361,8 @@ def test_load_from_checkpoint_with_empty_repo(
)

# Check that we got no documents
assert len(outputs) == 2
assert len(outputs[-1].items) == 0
assert len(outputs) == 3
assert len(outputs[1].items) == 0
assert not outputs[-1].next_checkpoint.has_more


Expand Down Expand Up @@ -886,7 +938,7 @@ def to_repository_side_effect(
assert cp4.cached_repo is not None
assert cp4.cached_repo.id == mock_repo1.id # Last processed repo
assert (
cp4.stage == GithubConnectorStage.PRS
cp4.stage == GithubConnectorStage.FILES_MD
) # Reset for a hypothetical next run/repo
assert cp4.curr_page == 0
assert cp4.num_retrieved == 0
Expand Down
Loading
Loading