Merged
backend/onyx/connectors/github/connector.py (41 additions, 23 deletions)

@@ -40,7 +40,7 @@
 logger = setup_logger()

 ITEMS_PER_PAGE = 100
-CURSOR_LOG_FREQUENCY = 100
+CURSOR_LOG_FREQUENCY = 50

 _MAX_NUM_RATE_LIMIT_RETRIES = 5

@@ -118,7 +118,7 @@ def _paginate_until_error(
             "This will retrieve all pages before the one we are resuming from, "
             "which may take a while and consume many API calls."
         )
-        pag_list = pag_list[prev_num_objs:]
+        pag_list = cast(PaginatedList[PullRequest | Issue], pag_list[prev_num_objs:])
         num_objs = 0

     try:
@@ -297,6 +297,19 @@ def reset(self) -> None:
         self.cursor_url = None


+def make_cursor_url_callback(
+    checkpoint: GithubConnectorCheckpoint,
+) -> Callable[[str | None, int], None]:
+    def cursor_url_callback(cursor_url: str | None, num_objs: int) -> None:
+        # we want to maintain the old cursor url so code after retrieval
+        # can determine that we are using the fallback cursor-based pagination strategy
+        if cursor_url:
+            checkpoint.cursor_url = cursor_url
+        checkpoint.num_retrieved = num_objs
+
+    return cursor_url_callback
+
+
 class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
     def __init__(
         self,
@@ -393,6 +406,20 @@ def _get_all_repos(
             _sleep_after_rate_limit_exception(github_client)
             return self._get_all_repos(github_client, attempt_num + 1)

+    def _pull_requests_func(
+        self, repo: Repository.Repository
+    ) -> Callable[[], PaginatedList[PullRequest]]:
+        return lambda: repo.get_pulls(
+            state=self.state_filter, sort="updated", direction="desc"
+        )
+
+    def _issues_func(
+        self, repo: Repository.Repository
+    ) -> Callable[[], PaginatedList[Issue]]:
+        return lambda: repo.get_issues(
+            state=self.state_filter, sort="updated", direction="desc"
+        )
+
     def _fetch_from_github(
         self,
         checkpoint: GithubConnectorCheckpoint,
@@ -433,7 +460,8 @@ def _fetch_from_github(
             # save checkpoint with repo ids retrieved
             return checkpoint

-        assert checkpoint.cached_repo is not None, "No repo saved in checkpoint"
+        if checkpoint.cached_repo is None:
+            raise ValueError("No repo saved in checkpoint")

         # Try to access the requester - different PyGithub versions may use different attribute names
         try:
@@ -455,22 +483,13 @@ def _fetch_from_github(
         repo_id = checkpoint.cached_repo.id
         repo = self.github_client.get_repo(repo_id)

-        def cursor_url_callback(cursor_url: str | None, num_objs: int) -> None:
-            checkpoint.cursor_url = cursor_url
-            checkpoint.num_retrieved = num_objs
+        cursor_url_callback = make_cursor_url_callback(checkpoint)

         # TODO: all PRs are also issues, so we should be able to _only_ get issues
         # and then filter appropriately whenever include_issues is True
         if self.include_prs and checkpoint.stage == GithubConnectorStage.PRS:
             logger.info(f"Fetching PRs for repo: {repo.name}")

-            def pull_requests_func() -> PaginatedList[PullRequest]:
-                return repo.get_pulls(
-                    state=self.state_filter, sort="updated", direction="desc"
-                )
-
             pr_batch = _get_batch_rate_limited(
-                pull_requests_func,
+                self._pull_requests_func(repo),
                 checkpoint.curr_page,
                 checkpoint.cursor_url,
                 checkpoint.num_retrieved,
@@ -521,15 +540,17 @@ def pull_requests_func() -> PaginatedList[PullRequest]:
             # if we found any PRs on the page and there are more PRs to get, return the checkpoint.
             # In offset mode, while indexing without time constraints, the pr batch
             # will be empty when we're done.
-            if num_prs > 0 and not done_with_prs and not checkpoint.cursor_url:
+            used_cursor = checkpoint.cursor_url is not None
+            logger.info(f"Fetched {num_prs} PRs for repo: {repo.name}")
+            if num_prs > 0 and not done_with_prs and not used_cursor:
                 return checkpoint

             # if we went past the start date during the loop or there are no more
             # prs to get, we move on to issues
             checkpoint.stage = GithubConnectorStage.ISSUES
             checkpoint.reset()

-            if checkpoint.cursor_url:
+            if used_cursor:
                 # save the checkpoint after changing stage; next run will continue from issues
                 return checkpoint
@@ ... @@
         if self.include_issues and checkpoint.stage == GithubConnectorStage.ISSUES:
             logger.info(f"Fetching issues for repo: {repo.name}")

-            def issues_func() -> PaginatedList[Issue]:
-                return repo.get_issues(
-                    state=self.state_filter, sort="updated", direction="desc"
-                )
-
             issue_batch = list(
                 _get_batch_rate_limited(
-                    issues_func,
+                    self._issues_func(repo),
                     checkpoint.curr_page,
                     checkpoint.cursor_url,
                     checkpoint.num_retrieved,
@@ -575,7 +591,6 @@ def issues_func() -> PaginatedList[Issue]:

                 if issue.pull_request is not None:
                     # PRs are handled separately
-                    # TODO: but they shouldn't always be
                     continue

                 try:
@@ ... @@
                     )
                     continue

+            logger.info(f"Fetched {num_issues} issues for repo: {repo.name}")
             # if we found any issues on the page, and we're not done, return the checkpoint.
             # don't return if we're using cursor-based pagination to avoid infinite loops
             if num_issues > 0 and not done_with_issues and not checkpoint.cursor_url:
@@ ... @@
                 raw_data=next_repo.raw_data,
             )

+        logger.info(f"{len(checkpoint.cached_repo_ids)} repos remaining")
+
         return checkpoint

     @override
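A minimal sketch of the closure pattern that make_cursor_url_callback introduces, with GithubConnectorCheckpoint reduced to a stand-in dataclass (an assumption for illustration; the real checkpoint class has more fields). The factory binds the checkpoint, and the callback overwrites cursor_url only when it receives a non-None value, so code that runs after retrieval can still tell that the fallback cursor-based pagination strategy was engaged:

from dataclasses import dataclass
from typing import Callable


@dataclass
class Checkpoint:
    """Stand-in for GithubConnectorCheckpoint (illustration only)."""

    cursor_url: str | None = None
    num_retrieved: int = 0


def make_cursor_url_callback(
    checkpoint: Checkpoint,
) -> Callable[[str | None, int], None]:
    def cursor_url_callback(cursor_url: str | None, num_objs: int) -> None:
        # Keep the last non-None cursor so later code can detect that
        # cursor-based pagination was used at some point during retrieval.
        if cursor_url:
            checkpoint.cursor_url = cursor_url
        checkpoint.num_retrieved = num_objs

    return cursor_url_callback


checkpoint = Checkpoint()
callback = make_cursor_url_callback(checkpoint)
callback("https://api.github.com/repositories/1/issues?page=3", 100)  # example URL
callback(None, 150)  # a None cursor no longer wipes the stored URL
assert checkpoint.cursor_url is not None  # cursor mode is still detectable
assert checkpoint.num_retrieved == 150

This also appears to be why the PR snapshots used_cursor before the stage transition: checkpoint.reset() sets cursor_url back to None (visible in the reset() context above), so the old check "if checkpoint.cursor_url:" placed after reset() could never be true; reading the flag first preserves the intent of returning a checkpoint whenever cursor pagination was in use.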
backend/requirements/dev.txt (1 addition, 1 deletion)

@@ -6,7 +6,7 @@ faker==37.1.0
 lxml==5.3.0
 lxml_html_clean==0.2.2
 mypy-extensions==1.0.0
-mypy==1.15.0
+mypy==1.13.0
 pandas-stubs==2.2.3.241009
 pandas==2.2.3
 posthog==3.7.4
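One more sketch, for the cast(...) change in _paginate_until_error near the top of the diff. Slicing a PyGithub PaginatedList is not statically typed as returning PaginatedList[...], which is presumably why the PR casts the slice back to the declared type; the helper below mirrors that line under the assumption that PyGithub is installed (the function name is hypothetical, not from the PR):

from typing import cast

from github.Issue import Issue
from github.PaginatedList import PaginatedList
from github.PullRequest import PullRequest


def skip_already_retrieved(
    pag_list: PaginatedList[PullRequest | Issue],
    prev_num_objs: int,
) -> PaginatedList[PullRequest | Issue]:
    """Skip objects that an earlier, interrupted run already consumed."""
    # The slice still pages through the skipped objects under the hood, which
    # is why the connector warns that resuming may consume many API calls.
    return cast(PaginatedList[PullRequest | Issue], pag_list[prev_num_objs:])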