Skip to content

Commit 630f167

Browse files
seanzhougooglecopybara-github
authored andcommitted
chore: Add a2a agent executor
PiperOrigin-RevId: 775983689
1 parent 2f55de6 commit 630f167

File tree

2 files changed

+1078
-0
lines changed

2 files changed

+1078
-0
lines changed
Lines changed: 249 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from datetime import datetime
18+
from datetime import timezone
19+
import inspect
20+
import logging
21+
from typing import Any
22+
from typing import Awaitable
23+
from typing import Callable
24+
from typing import Optional
25+
import uuid
26+
27+
from a2a.server.agent_execution import AgentExecutor
28+
from a2a.server.agent_execution.context import RequestContext
29+
from a2a.server.events.event_queue import EventQueue
30+
from a2a.types import Message
31+
from a2a.types import Role
32+
from a2a.types import TaskState
33+
from a2a.types import TaskStatus
34+
from a2a.types import TaskStatusUpdateEvent
35+
from a2a.types import TextPart
36+
from google.adk.runners import Runner
37+
from pydantic import BaseModel
38+
from typing_extensions import override
39+
40+
from ...utils.feature_decorator import working_in_progress
41+
from ..converters.event_converter import convert_event_to_a2a_events
42+
from ..converters.request_converter import convert_a2a_request_to_adk_run_args
43+
from ..converters.utils import _get_adk_metadata_key
44+
from .task_result_aggregator import TaskResultAggregator
45+
46+
logger = logging.getLogger('google_adk.' + __name__)
47+
48+
49+
@working_in_progress
50+
class A2aAgentExecutorConfig(BaseModel):
51+
"""Configuration for the A2aAgentExecutor."""
52+
53+
pass
54+
55+
56+
@working_in_progress
57+
class A2aAgentExecutor(AgentExecutor):
58+
"""An AgentExecutor that runs an ADK Agent against an A2A request and
59+
publishes updates to an event queue.
60+
"""
61+
62+
def __init__(
63+
self,
64+
*,
65+
runner: Runner | Callable[..., Runner | Awaitable[Runner]],
66+
config: Optional[A2aAgentExecutorConfig] = None,
67+
):
68+
super().__init__()
69+
self._runner = runner
70+
self._config = config
71+
72+
async def _resolve_runner(self) -> Runner:
73+
"""Resolve the runner, handling cases where it's a callable that returns a Runner."""
74+
# If already resolved and cached, return it
75+
if isinstance(self._runner, Runner):
76+
return self._runner
77+
if callable(self._runner):
78+
# Call the function to get the runner
79+
result = self._runner()
80+
81+
# Handle async callables
82+
if inspect.iscoroutine(result):
83+
resolved_runner = await result
84+
else:
85+
resolved_runner = result
86+
87+
# Cache the resolved runner for future calls
88+
self._runner = resolved_runner
89+
return resolved_runner
90+
91+
raise TypeError(
92+
'Runner must be a Runner instance or a callable that returns a'
93+
f' Runner, got {type(self._runner)}'
94+
)
95+
96+
@override
97+
async def cancel(self, context: RequestContext, event_queue: EventQueue):
98+
"""Cancel the execution."""
99+
# TODO: Implement proper cancellation logic if needed
100+
raise NotImplementedError('Cancellation is not supported')
101+
102+
@override
103+
async def execute(
104+
self,
105+
context: RequestContext,
106+
event_queue: EventQueue,
107+
):
108+
"""Executes an A2A request and publishes updates to the event queue
109+
specified. It runs as following:
110+
* Takes the input from the A2A request
111+
* Convert the input to ADK input content, and runs the ADK agent
112+
* Collects output events of the underlying ADK Agent
113+
* Converts the ADK output events into A2A task updates
114+
* Publishes the updates back to A2A server via event queue
115+
"""
116+
if not context.message:
117+
raise ValueError('A2A request must have a message')
118+
119+
# for new task, create a task submitted event
120+
if not context.current_task:
121+
await event_queue.enqueue_event(
122+
TaskStatusUpdateEvent(
123+
taskId=context.task_id,
124+
status=TaskStatus(
125+
state=TaskState.submitted,
126+
message=context.message,
127+
timestamp=datetime.now(timezone.utc).isoformat(),
128+
),
129+
contextId=context.context_id,
130+
final=False,
131+
)
132+
)
133+
134+
# Handle the request and publish updates to the event queue
135+
try:
136+
await self._handle_request(context, event_queue)
137+
except Exception as e:
138+
logger.error('Error handling A2A request: %s', e, exc_info=True)
139+
# Publish failure event
140+
try:
141+
await event_queue.enqueue_event(
142+
TaskStatusUpdateEvent(
143+
taskId=context.task_id,
144+
status=TaskStatus(
145+
state=TaskState.failed,
146+
timestamp=datetime.now(timezone.utc).isoformat(),
147+
message=Message(
148+
messageId=str(uuid.uuid4()),
149+
role=Role.agent,
150+
parts=[TextPart(text=str(e))],
151+
),
152+
),
153+
contextId=context.context_id,
154+
final=True,
155+
)
156+
)
157+
except Exception as enqueue_error:
158+
logger.error(
159+
'Failed to publish failure event: %s', enqueue_error, exc_info=True
160+
)
161+
162+
async def _handle_request(
163+
self,
164+
context: RequestContext,
165+
event_queue: EventQueue,
166+
):
167+
# Resolve the runner instance
168+
runner = await self._resolve_runner()
169+
170+
# Convert the a2a request to ADK run args
171+
run_args = convert_a2a_request_to_adk_run_args(context)
172+
173+
# ensure the session exists
174+
session = await self._prepare_session(context, run_args, runner)
175+
176+
# create invocation context
177+
invocation_context = runner._new_invocation_context(
178+
session=session,
179+
new_message=run_args['new_message'],
180+
run_config=run_args['run_config'],
181+
)
182+
183+
# publish the task working event
184+
await event_queue.enqueue_event(
185+
TaskStatusUpdateEvent(
186+
taskId=context.task_id,
187+
status=TaskStatus(
188+
state=TaskState.working,
189+
timestamp=datetime.now(timezone.utc).isoformat(),
190+
),
191+
contextId=context.context_id,
192+
final=False,
193+
metadata={
194+
_get_adk_metadata_key('app_name'): runner.app_name,
195+
_get_adk_metadata_key('user_id'): run_args['user_id'],
196+
_get_adk_metadata_key('session_id'): run_args['session_id'],
197+
},
198+
)
199+
)
200+
201+
task_result_aggregator = TaskResultAggregator()
202+
async for adk_event in runner.run_async(**run_args):
203+
for a2a_event in convert_event_to_a2a_events(
204+
adk_event, invocation_context, context.task_id, context.context_id
205+
):
206+
task_result_aggregator.process_event(a2a_event)
207+
await event_queue.enqueue_event(a2a_event)
208+
209+
# publish the task result event - this is final
210+
await event_queue.enqueue_event(
211+
TaskStatusUpdateEvent(
212+
taskId=context.task_id,
213+
status=TaskStatus(
214+
state=(
215+
task_result_aggregator.task_state
216+
if task_result_aggregator.task_state != TaskState.working
217+
else TaskState.completed
218+
),
219+
timestamp=datetime.now(timezone.utc).isoformat(),
220+
message=task_result_aggregator.task_status_message,
221+
),
222+
contextId=context.context_id,
223+
final=True,
224+
)
225+
)
226+
227+
async def _prepare_session(
228+
self, context: RequestContext, run_args: dict[str, Any], runner: Runner
229+
):
230+
231+
session_id = run_args['session_id']
232+
# create a new session if not exists
233+
user_id = run_args['user_id']
234+
session = await runner.session_service.get_session(
235+
app_name=runner.app_name,
236+
user_id=user_id,
237+
session_id=session_id,
238+
)
239+
if session is None:
240+
session = await runner.session_service.create_session(
241+
app_name=runner.app_name,
242+
user_id=user_id,
243+
state={},
244+
session_id=session_id,
245+
)
246+
# Update run_args with the new session_id
247+
run_args['session_id'] = session.id
248+
249+
return session

0 commit comments

Comments
 (0)