|
| 1 | +# Copyright 2025 Google LLC |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +from __future__ import annotations |
| 16 | + |
| 17 | +from datetime import datetime |
| 18 | +from datetime import timezone |
| 19 | +import inspect |
| 20 | +import logging |
| 21 | +from typing import Any |
| 22 | +from typing import Awaitable |
| 23 | +from typing import Callable |
| 24 | +from typing import Optional |
| 25 | +import uuid |
| 26 | + |
| 27 | +from a2a.server.agent_execution import AgentExecutor |
| 28 | +from a2a.server.agent_execution.context import RequestContext |
| 29 | +from a2a.server.events.event_queue import EventQueue |
| 30 | +from a2a.types import Message |
| 31 | +from a2a.types import Role |
| 32 | +from a2a.types import TaskState |
| 33 | +from a2a.types import TaskStatus |
| 34 | +from a2a.types import TaskStatusUpdateEvent |
| 35 | +from a2a.types import TextPart |
| 36 | +from google.adk.runners import Runner |
| 37 | +from pydantic import BaseModel |
| 38 | +from typing_extensions import override |
| 39 | + |
| 40 | +from ...utils.feature_decorator import working_in_progress |
| 41 | +from ..converters.event_converter import convert_event_to_a2a_events |
| 42 | +from ..converters.request_converter import convert_a2a_request_to_adk_run_args |
| 43 | +from ..converters.utils import _get_adk_metadata_key |
| 44 | +from .task_result_aggregator import TaskResultAggregator |
| 45 | + |
| 46 | +logger = logging.getLogger('google_adk.' + __name__) |
| 47 | + |
| 48 | + |
| 49 | +@working_in_progress |
| 50 | +class A2aAgentExecutorConfig(BaseModel): |
| 51 | + """Configuration for the A2aAgentExecutor.""" |
| 52 | + |
| 53 | + pass |
| 54 | + |
| 55 | + |
| 56 | +@working_in_progress |
| 57 | +class A2aAgentExecutor(AgentExecutor): |
| 58 | + """An AgentExecutor that runs an ADK Agent against an A2A request and |
| 59 | + publishes updates to an event queue. |
| 60 | + """ |
| 61 | + |
| 62 | + def __init__( |
| 63 | + self, |
| 64 | + *, |
| 65 | + runner: Runner | Callable[..., Runner | Awaitable[Runner]], |
| 66 | + config: Optional[A2aAgentExecutorConfig] = None, |
| 67 | + ): |
| 68 | + super().__init__() |
| 69 | + self._runner = runner |
| 70 | + self._config = config |
| 71 | + |
| 72 | + async def _resolve_runner(self) -> Runner: |
| 73 | + """Resolve the runner, handling cases where it's a callable that returns a Runner.""" |
| 74 | + # If already resolved and cached, return it |
| 75 | + if isinstance(self._runner, Runner): |
| 76 | + return self._runner |
| 77 | + if callable(self._runner): |
| 78 | + # Call the function to get the runner |
| 79 | + result = self._runner() |
| 80 | + |
| 81 | + # Handle async callables |
| 82 | + if inspect.iscoroutine(result): |
| 83 | + resolved_runner = await result |
| 84 | + else: |
| 85 | + resolved_runner = result |
| 86 | + |
| 87 | + # Cache the resolved runner for future calls |
| 88 | + self._runner = resolved_runner |
| 89 | + return resolved_runner |
| 90 | + |
| 91 | + raise TypeError( |
| 92 | + 'Runner must be a Runner instance or a callable that returns a' |
| 93 | + f' Runner, got {type(self._runner)}' |
| 94 | + ) |
| 95 | + |
| 96 | + @override |
| 97 | + async def cancel(self, context: RequestContext, event_queue: EventQueue): |
| 98 | + """Cancel the execution.""" |
| 99 | + # TODO: Implement proper cancellation logic if needed |
| 100 | + raise NotImplementedError('Cancellation is not supported') |
| 101 | + |
| 102 | + @override |
| 103 | + async def execute( |
| 104 | + self, |
| 105 | + context: RequestContext, |
| 106 | + event_queue: EventQueue, |
| 107 | + ): |
| 108 | + """Executes an A2A request and publishes updates to the event queue |
| 109 | + specified. It runs as following: |
| 110 | + * Takes the input from the A2A request |
| 111 | + * Convert the input to ADK input content, and runs the ADK agent |
| 112 | + * Collects output events of the underlying ADK Agent |
| 113 | + * Converts the ADK output events into A2A task updates |
| 114 | + * Publishes the updates back to A2A server via event queue |
| 115 | + """ |
| 116 | + if not context.message: |
| 117 | + raise ValueError('A2A request must have a message') |
| 118 | + |
| 119 | + # for new task, create a task submitted event |
| 120 | + if not context.current_task: |
| 121 | + await event_queue.enqueue_event( |
| 122 | + TaskStatusUpdateEvent( |
| 123 | + taskId=context.task_id, |
| 124 | + status=TaskStatus( |
| 125 | + state=TaskState.submitted, |
| 126 | + message=context.message, |
| 127 | + timestamp=datetime.now(timezone.utc).isoformat(), |
| 128 | + ), |
| 129 | + contextId=context.context_id, |
| 130 | + final=False, |
| 131 | + ) |
| 132 | + ) |
| 133 | + |
| 134 | + # Handle the request and publish updates to the event queue |
| 135 | + try: |
| 136 | + await self._handle_request(context, event_queue) |
| 137 | + except Exception as e: |
| 138 | + logger.error('Error handling A2A request: %s', e, exc_info=True) |
| 139 | + # Publish failure event |
| 140 | + try: |
| 141 | + await event_queue.enqueue_event( |
| 142 | + TaskStatusUpdateEvent( |
| 143 | + taskId=context.task_id, |
| 144 | + status=TaskStatus( |
| 145 | + state=TaskState.failed, |
| 146 | + timestamp=datetime.now(timezone.utc).isoformat(), |
| 147 | + message=Message( |
| 148 | + messageId=str(uuid.uuid4()), |
| 149 | + role=Role.agent, |
| 150 | + parts=[TextPart(text=str(e))], |
| 151 | + ), |
| 152 | + ), |
| 153 | + contextId=context.context_id, |
| 154 | + final=True, |
| 155 | + ) |
| 156 | + ) |
| 157 | + except Exception as enqueue_error: |
| 158 | + logger.error( |
| 159 | + 'Failed to publish failure event: %s', enqueue_error, exc_info=True |
| 160 | + ) |
| 161 | + |
| 162 | + async def _handle_request( |
| 163 | + self, |
| 164 | + context: RequestContext, |
| 165 | + event_queue: EventQueue, |
| 166 | + ): |
| 167 | + # Resolve the runner instance |
| 168 | + runner = await self._resolve_runner() |
| 169 | + |
| 170 | + # Convert the a2a request to ADK run args |
| 171 | + run_args = convert_a2a_request_to_adk_run_args(context) |
| 172 | + |
| 173 | + # ensure the session exists |
| 174 | + session = await self._prepare_session(context, run_args, runner) |
| 175 | + |
| 176 | + # create invocation context |
| 177 | + invocation_context = runner._new_invocation_context( |
| 178 | + session=session, |
| 179 | + new_message=run_args['new_message'], |
| 180 | + run_config=run_args['run_config'], |
| 181 | + ) |
| 182 | + |
| 183 | + # publish the task working event |
| 184 | + await event_queue.enqueue_event( |
| 185 | + TaskStatusUpdateEvent( |
| 186 | + taskId=context.task_id, |
| 187 | + status=TaskStatus( |
| 188 | + state=TaskState.working, |
| 189 | + timestamp=datetime.now(timezone.utc).isoformat(), |
| 190 | + ), |
| 191 | + contextId=context.context_id, |
| 192 | + final=False, |
| 193 | + metadata={ |
| 194 | + _get_adk_metadata_key('app_name'): runner.app_name, |
| 195 | + _get_adk_metadata_key('user_id'): run_args['user_id'], |
| 196 | + _get_adk_metadata_key('session_id'): run_args['session_id'], |
| 197 | + }, |
| 198 | + ) |
| 199 | + ) |
| 200 | + |
| 201 | + task_result_aggregator = TaskResultAggregator() |
| 202 | + async for adk_event in runner.run_async(**run_args): |
| 203 | + for a2a_event in convert_event_to_a2a_events( |
| 204 | + adk_event, invocation_context, context.task_id, context.context_id |
| 205 | + ): |
| 206 | + task_result_aggregator.process_event(a2a_event) |
| 207 | + await event_queue.enqueue_event(a2a_event) |
| 208 | + |
| 209 | + # publish the task result event - this is final |
| 210 | + await event_queue.enqueue_event( |
| 211 | + TaskStatusUpdateEvent( |
| 212 | + taskId=context.task_id, |
| 213 | + status=TaskStatus( |
| 214 | + state=( |
| 215 | + task_result_aggregator.task_state |
| 216 | + if task_result_aggregator.task_state != TaskState.working |
| 217 | + else TaskState.completed |
| 218 | + ), |
| 219 | + timestamp=datetime.now(timezone.utc).isoformat(), |
| 220 | + message=task_result_aggregator.task_status_message, |
| 221 | + ), |
| 222 | + contextId=context.context_id, |
| 223 | + final=True, |
| 224 | + ) |
| 225 | + ) |
| 226 | + |
| 227 | + async def _prepare_session( |
| 228 | + self, context: RequestContext, run_args: dict[str, Any], runner: Runner |
| 229 | + ): |
| 230 | + |
| 231 | + session_id = run_args['session_id'] |
| 232 | + # create a new session if not exists |
| 233 | + user_id = run_args['user_id'] |
| 234 | + session = await runner.session_service.get_session( |
| 235 | + app_name=runner.app_name, |
| 236 | + user_id=user_id, |
| 237 | + session_id=session_id, |
| 238 | + ) |
| 239 | + if session is None: |
| 240 | + session = await runner.session_service.create_session( |
| 241 | + app_name=runner.app_name, |
| 242 | + user_id=user_id, |
| 243 | + state={}, |
| 244 | + session_id=session_id, |
| 245 | + ) |
| 246 | + # Update run_args with the new session_id |
| 247 | + run_args['session_id'] = session.id |
| 248 | + |
| 249 | + return session |
0 commit comments