Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 24 additions & 3 deletions src/connectors/gemini_cloud_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@

# mypy: disable-error-code="no-untyped-call,no-untyped-def,no-any-return,has-type,var-annotated"
import asyncio
import contextlib
import json
import logging
import os
Expand Down Expand Up @@ -836,6 +837,9 @@ async def _validate_project_access(self) -> None:
if logger.isEnabledFor(logging.ERROR):
logger.error(f"Failed to validate project access: {e}", exc_info=True)
raise
finally:
with contextlib.suppress(Exception):
auth_session.close()

async def _resolve_gemini_api_config(
self,
Expand Down Expand Up @@ -870,6 +874,7 @@ async def _resolve_gemini_api_config(

async def _perform_health_check(self) -> bool:
"""Perform a health check by testing API connectivity with project."""
session = None
try:
# With ADC, token handling is internal; proceed to simple request

Expand Down Expand Up @@ -922,6 +927,10 @@ async def _perform_health_check(self) -> bool:
f"Health check failed - unexpected error: {e}", exc_info=True
)
return False
finally:
if session is not None:
with contextlib.suppress(Exception):
session.close()

def _generate_user_prompt_id(self, request_data: Any) -> str:
"""Generate a unique user_prompt_id for Code Assist requests."""
Expand Down Expand Up @@ -1045,9 +1054,9 @@ async def _chat_completions_standard(
**kwargs: Any,
) -> ResponseEnvelope:
"""Handle non-streaming chat completions."""
auth_session = self._get_adc_authorized_session()
try:
# Use ADC for API calls (matches gemini CLI behavior for project-id auth)
auth_session = self._get_adc_authorized_session()

# Ensure project is onboarded for standard-tier
project_id = await self._ensure_project_onboarded(auth_session)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Keep ADC session creation inside error-handling block

Moving auth_session = self._get_adc_authorized_session() outside the surrounding try means any failure in acquiring Application Default Credentials (for example google.auth.exceptions.DefaultCredentialsError) will now escape without being logged or wrapped in the function’s BackendError fallback. Callers of _chat_completions_standard previously only saw BackendError/AuthenticationError, so this regression will leak raw exceptions and bypass the existing error logging. Consider creating the session inside the try and closing it in finally with a guard so the original error handling remains intact.

Useful? React with 👍 / 👎.

Expand Down Expand Up @@ -1194,6 +1203,9 @@ async def _chat_completions_standard(
if logger.isEnabledFor(logging.ERROR):
logger.error(f"Unexpected error during API call: {e}", exc_info=True)
raise BackendError(f"Unexpected error during API call: {e}")
finally:
with contextlib.suppress(Exception):
auth_session.close()

async def _chat_completions_streaming(
self,
Expand All @@ -1203,9 +1215,10 @@ async def _chat_completions_streaming(
**kwargs: Any,
) -> StreamingResponseEnvelope:
"""Handle streaming chat completions."""
auth_session = self._get_adc_authorized_session()
stream_prepared = False
try:
# Use ADC for streaming API calls
auth_session = self._get_adc_authorized_session()

# Ensure project is onboarded for standard-tier
project_id = await self._ensure_project_onboarded(auth_session)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Streaming path now bypasses error wrapping for ADC session failures

The streaming variant now obtains the ADC AuthorizedSession before entering the try/except. If _get_adc_authorized_session() raises (e.g. missing credentials or networking errors), the exception bypasses the function’s error logging and conversion to BackendError, and will propagate as a raw Exception. This diverges from prior behaviour and from the non-streaming code paths that callers rely on for consistent error semantics. Acquire the session inside the try and keep the finally guarded so it is still closed when an error occurs.

Useful? React with 👍 / 👎.

Expand Down Expand Up @@ -1387,9 +1400,13 @@ async def stream_generator() -> AsyncGenerator[ProcessedResponse, None]:
finally:
if response: # Ensure response is defined before closing
response.close() # Use synchronous close
with contextlib.suppress(Exception):
auth_session.close()

generator = stream_generator()
stream_prepared = True
return StreamingResponseEnvelope(
content=stream_generator(),
content=generator,
media_type="text/event-stream",
headers={},
)
Expand All @@ -1402,6 +1419,10 @@ async def stream_generator() -> AsyncGenerator[ProcessedResponse, None]:
f"Unexpected error during streaming API call: {e}", exc_info=True
)
raise BackendError(f"Unexpected error during streaming API call: {e}")
finally:
if not stream_prepared:
with contextlib.suppress(Exception):
auth_session.close()

def _build_generation_config(self, request_data: Any) -> dict[str, Any]:
cfg: dict[str, Any] = {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import asyncio
from typing import Any

import httpx
import pytest

from src.connectors.gemini_cloud_project import GeminiCloudProjectConnector
from src.core.config.app_config import AppConfig
from src.core.services.translation_service import TranslationService


class _DummyResponse:
def __init__(self, status_code: int = 200, json_data: dict[str, Any] | None = None) -> None:
self.status_code = status_code
self._json_data = json_data or {}
self.text = ""

def json(self) -> dict[str, Any]:
return self._json_data


class _DummySession:
def __init__(self) -> None:
self.closed = False

def close(self) -> None:
self.closed = True


@pytest.fixture()
def connector() -> GeminiCloudProjectConnector:
cfg = AppConfig()
client = httpx.AsyncClient()
backend = GeminiCloudProjectConnector(
client,
cfg,
translation_service=TranslationService(),
gcp_project_id="test-project",
)
backend.gemini_api_base_url = "https://example.com"
return backend
Comment on lines +30 to +41
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Ensure the AsyncClient is closed in the fixture

httpx.AsyncClient keeps background resources alive until aclose() runs; leaving it open in the fixture causes resource leaks and runtime warnings after the tests complete. Convert the fixture to an async fixture that async with-closes the client so each test frees the session.

-@pytest.fixture()
-def connector() -> GeminiCloudProjectConnector:
-    cfg = AppConfig()
-    client = httpx.AsyncClient()
-    backend = GeminiCloudProjectConnector(
-        client,
-        cfg,
-        translation_service=TranslationService(),
-        gcp_project_id="test-project",
-    )
-    backend.gemini_api_base_url = "https://example.com"
-    return backend
+@pytest_asyncio.fixture()
+async def connector() -> GeminiCloudProjectConnector:
+    cfg = AppConfig()
+    async with httpx.AsyncClient() as client:
+        backend = GeminiCloudProjectConnector(
+            client,
+            cfg,
+            translation_service=TranslationService(),
+            gcp_project_id="test-project",
+        )
+        backend.gemini_api_base_url = "https://example.com"
+        yield backend

Remember to import pytest_asyncio alongside the other pytest imports.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
@pytest.fixture()
def connector() -> GeminiCloudProjectConnector:
cfg = AppConfig()
client = httpx.AsyncClient()
backend = GeminiCloudProjectConnector(
client,
cfg,
translation_service=TranslationService(),
gcp_project_id="test-project",
)
backend.gemini_api_base_url = "https://example.com"
return backend
@pytest_asyncio.fixture()
async def connector() -> GeminiCloudProjectConnector:
cfg = AppConfig()
async with httpx.AsyncClient() as client:
backend = GeminiCloudProjectConnector(
client,
cfg,
translation_service=TranslationService(),
gcp_project_id="test-project",
)
backend.gemini_api_base_url = "https://example.com"
yield backend
🤖 Prompt for AI Agents
In tests/unit/connectors/test_gemini_cloud_project_resource_management.py around
lines 30 to 41, the fixture creates an httpx.AsyncClient but never closes it;
convert the fixture to an async fixture (use pytest_asyncio import) and create
the client with an async context manager (async with httpx.AsyncClient() as
client:) or call await client.aclose() in a finally block, then yield the
GeminiCloudProjectConnector instance so tests receive it and the AsyncClient is
properly closed after each test.



@pytest.mark.asyncio
async def test_validate_project_access_closes_session(
connector: GeminiCloudProjectConnector, monkeypatch: pytest.MonkeyPatch
) -> None:
class _Session(_DummySession):
def request(self, *args: Any, **kwargs: Any) -> _DummyResponse:
return _DummyResponse(
json_data={"cloudaicompanionProject": {"id": connector.gcp_project_id}}
)

session = _Session()

async def _immediate_to_thread(func: Any, *args: Any, **kwargs: Any) -> Any:
return func(*args, **kwargs)

monkeypatch.setattr(connector, "_get_adc_authorized_session", lambda: session)
monkeypatch.setattr(asyncio, "to_thread", _immediate_to_thread)

await connector._validate_project_access()

assert session.closed is True


@pytest.mark.asyncio
async def test_perform_health_check_closes_session(
connector: GeminiCloudProjectConnector, monkeypatch: pytest.MonkeyPatch
) -> None:
class _Credentials:
def __init__(self) -> None:
self.token = "token"

def refresh(self, request: Any) -> None: # pragma: no cover - simple stub
self.token = "new-token"

class _Session(_DummySession):
def __init__(self) -> None:
super().__init__()
self.credentials = _Credentials()

async def _fake_get(url: str, headers: dict[str, str], timeout: float) -> Any:
return _DummyResponse(status_code=200)

session = _Session()

monkeypatch.setattr(connector, "_get_adc_authorized_session", lambda: session)
monkeypatch.setattr(connector.client, "get", _fake_get)

result = await connector._perform_health_check()

assert result is True
assert session.closed is True
Loading