Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 81 additions & 4 deletions src/google/adk/cli/fast_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import asyncio
from contextlib import asynccontextmanager
import json
import logging
import os
from pathlib import Path
Expand All @@ -32,7 +33,6 @@
from fastapi import HTTPException
from fastapi import Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.responses import RedirectResponse
from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles
Expand All @@ -53,7 +53,6 @@
from ..agents import RunConfig
from ..agents.live_request_queue import LiveRequest
from ..agents.live_request_queue import LiveRequestQueue
from ..agents.llm_agent import Agent
from ..agents.run_config import StreamingMode
from ..artifacts.gcs_artifact_service import GcsArtifactService
from ..artifacts.in_memory_artifact_service import InMemoryArtifactService
Expand All @@ -65,8 +64,6 @@
from ..evaluation.eval_metrics import EvalMetricResult
from ..evaluation.eval_metrics import EvalMetricResultPerInvocation
from ..evaluation.eval_result import EvalSetResult
from ..evaluation.gcs_eval_set_results_manager import GcsEvalSetResultsManager
from ..evaluation.gcs_eval_sets_manager import GcsEvalSetsManager
from ..evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager
from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager
from ..events.event import Event
Expand Down Expand Up @@ -965,6 +962,86 @@ async def _get_runner_async(app_name: str) -> Runner:
runner_dict[app_name] = runner
return runner

if a2a:
try:
from a2a.server.apps import A2AStarletteApplication
from a2a.server.request_handlers import DefaultRequestHandler
from a2a.server.tasks import InMemoryTaskStore
from a2a.types import AgentCard

from ..a2a.executor.a2a_agent_executor import A2aAgentExecutor

except ImportError as e:
import sys

if sys.version_info < (3, 10):
raise ImportError(
"A2A requires Python 3.10 or above. Please upgrade your Python"
" version."
) from e
else:
raise e
# locate all a2a agent apps in the agents directory
base_path = Path.cwd() / agents_dir
# the root agents directory should be an existing folder
if base_path.exists() and base_path.is_dir():
a2a_task_store = InMemoryTaskStore()

def create_a2a_runner_loader(captured_app_name: str):
"""Factory function to create A2A runner with proper closure."""

async def _get_a2a_runner_async() -> Runner:
return await _get_runner_async(captured_app_name)

return _get_a2a_runner_async

for p in base_path.iterdir():
# only folders with an agent.json file representing agent card are valid
# a2a agents
if (
p.is_file()
or p.name.startswith((".", "__pycache__"))
or not (p / "agent.json").is_file()
):
continue

app_name = p.name
logger.info("Setting up A2A agent: %s", app_name)

try:
a2a_rpc_path = f"http://{host}:{port}/a2a/{app_name}"

agent_executor = A2aAgentExecutor(
runner=create_a2a_runner_loader(app_name),
)

request_handler = DefaultRequestHandler(
agent_executor=agent_executor, task_store=a2a_task_store
)

with (p / "agent.json").open("r", encoding="utf-8") as f:
data = json.load(f)
agent_card = AgentCard(**data)
agent_card.url = a2a_rpc_path

a2a_app = A2AStarletteApplication(
agent_card=agent_card,
http_handler=request_handler,
)

routes = a2a_app.routes(
rpc_url=f"/a2a/{app_name}",
agent_card_url=f"/a2a/{app_name}/.well-known/agent.json",
)

for new_route in routes:
app.router.routes.append(new_route)

logger.info("Successfully configured A2A agent: %s", app_name)

except Exception as e:
logger.error("Failed to setup A2A agent %s: %s", app_name, e)
# Continue with other agents even if one fails
if web:
import mimetypes

Expand Down
160 changes: 159 additions & 1 deletion tests/unittests/cli/test_fast_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,14 @@
# limitations under the License.

import asyncio
import json
import logging
import os
from pathlib import Path
import sys
import tempfile
import time
from typing import Any
from typing import Optional
from unittest.mock import MagicMock
from unittest.mock import patch

Expand Down Expand Up @@ -465,6 +469,9 @@ def test_app(
artifact_service_uri="",
memory_service_uri="",
allow_origins=["*"],
a2a=False, # Disable A2A for most tests
host="127.0.0.1",
port=8000,
)

# Create a TestClient that doesn't start a real server
Expand Down Expand Up @@ -520,6 +527,134 @@ async def create_test_eval_set(
return test_session_info


@pytest.fixture
@pytest.mark.skipif(
sys.version_info < (3, 10), reason="A2A requires Python 3.10+"
)
def temp_agents_dir_with_a2a():
"""Create a temporary agents directory with A2A agent configurations for testing."""
with tempfile.TemporaryDirectory() as temp_dir:
# Create test agent directory
agent_dir = Path(temp_dir) / "test_a2a_agent"
agent_dir.mkdir()

# Create agent.json file
agent_card = {
"name": "test_a2a_agent",
"description": "Test A2A agent",
"version": "1.0.0",
"author": "test",
"capabilities": ["text"],
}

with open(agent_dir / "agent.json", "w") as f:
json.dump(agent_card, f)

# Create a simple agent.py file
agent_py_content = """
from google.adk.agents.base_agent import BaseAgent

class TestA2AAgent(BaseAgent):
def __init__(self):
super().__init__(name="test_a2a_agent")
"""

with open(agent_dir / "agent.py", "w") as f:
f.write(agent_py_content)

yield temp_dir


@pytest.fixture
@pytest.mark.skipif(
sys.version_info < (3, 10), reason="A2A requires Python 3.10+"
)
def test_app_with_a2a(
mock_session_service,
mock_artifact_service,
mock_memory_service,
mock_agent_loader,
mock_eval_sets_manager,
mock_eval_set_results_manager,
temp_agents_dir_with_a2a,
):
"""Create a TestClient for the FastAPI app with A2A enabled."""

# Mock A2A related classes
with (
patch("signal.signal", return_value=None),
patch(
"google.adk.cli.fast_api.InMemorySessionService",
return_value=mock_session_service,
),
patch(
"google.adk.cli.fast_api.InMemoryArtifactService",
return_value=mock_artifact_service,
),
patch(
"google.adk.cli.fast_api.InMemoryMemoryService",
return_value=mock_memory_service,
),
patch(
"google.adk.cli.fast_api.AgentLoader",
return_value=mock_agent_loader,
),
patch(
"google.adk.cli.fast_api.LocalEvalSetsManager",
return_value=mock_eval_sets_manager,
),
patch(
"google.adk.cli.fast_api.LocalEvalSetResultsManager",
return_value=mock_eval_set_results_manager,
),
patch(
"google.adk.cli.cli_eval.run_evals",
new=mock_run_evals_for_fast_api,
),
patch("a2a.server.tasks.InMemoryTaskStore") as mock_task_store,
patch(
"google.adk.a2a.executor.a2a_agent_executor.A2aAgentExecutor"
) as mock_executor,
patch(
"a2a.server.request_handlers.DefaultRequestHandler"
) as mock_handler,
patch("a2a.server.apps.A2AStarletteApplication") as mock_a2a_app,
):
# Configure mocks
mock_task_store.return_value = MagicMock()
mock_executor.return_value = MagicMock()
mock_handler.return_value = MagicMock()

# Mock A2AStarletteApplication
mock_app_instance = MagicMock()
mock_app_instance.routes.return_value = (
[]
) # Return empty routes for testing
mock_a2a_app.return_value = mock_app_instance

# Change to temp directory
original_cwd = os.getcwd()
os.chdir(temp_agents_dir_with_a2a)

try:
app = get_fast_api_app(
agents_dir=".",
web=True,
session_service_uri="",
artifact_service_uri="",
memory_service_uri="",
allow_origins=["*"],
a2a=True,
host="127.0.0.1",
port=8000,
)

client = TestClient(app)
yield client
finally:
os.chdir(original_cwd)


#################################################
# Test Cases
#################################################
Expand Down Expand Up @@ -760,5 +895,28 @@ def test_debug_trace(test_app):
logger.info("Debug trace test completed successfully")


@pytest.mark.skipif(
sys.version_info < (3, 10), reason="A2A requires Python 3.10+"
)
def test_a2a_agent_discovery(test_app_with_a2a):
"""Test that A2A agents are properly discovered and configured."""
# This test mainly verifies that the A2A setup doesn't break the app
response = test_app_with_a2a.get("/list-apps")
assert response.status_code == 200
logger.info("A2A agent discovery test passed")


@pytest.mark.skipif(
sys.version_info < (3, 10), reason="A2A requires Python 3.10+"
)
def test_a2a_disabled_by_default(test_app):
"""Test that A2A functionality is disabled by default."""
# The regular test_app fixture has a2a=False
# This test ensures no A2A routes are added
response = test_app.get("/list-apps")
assert response.status_code == 200
logger.info("A2A disabled by default test passed")


if __name__ == "__main__":
pytest.main(["-xvs", __file__])
Loading