diff --git a/contributing/samples/array_iterator_agent/README.md b/contributing/samples/array_iterator_agent/README.md
new file mode 100644
index 000000000..c46b0ae4c
--- /dev/null
+++ b/contributing/samples/array_iterator_agent/README.md
@@ -0,0 +1,208 @@
+# ArrayIteratorAgent Sample
+
+This sample demonstrates how to use the `ArrayIteratorAgent` to process arrays of data with a single sub-agent.
+
+## Overview
+
+The `ArrayIteratorAgent` is designed to:
+- **Iterate over arrays** in session state (supports nested keys with dot notation)
+- **Apply a single sub-agent** to each array item
+- **Optionally collect results** into an output array
+- **Handle escalation** to stop processing when needed
+
+## Key Features
+
+### πŸ”§ **Single Sub-Agent Focus**
+- Accepts exactly **one sub-agent** (enforced by validation)
+- For complex processing, use a `SequentialAgent` or `ParallelAgent` as the single sub-agent
+
+### πŸ—‚οΈ **Nested Key Support**
+- Array key: `"documents"` or `"user.profile.documents"`
+- Output key: `"results"` or `"processed.batch_results"`
+
+### πŸ“Š **Result Collection**
+- Sub-agent results are collected automatically when `output_key` is specified
+- Results are stored as an array in session state
+
+### ⚑ **Escalation Handling**
+- Stops iteration when the sub-agent escalates
+- Gracefully cleans up temporary state
+
+## Usage Examples
+
+### 1. Simple Document Processing
+
+```python
+from google.adk.agents import ArrayIteratorAgent, LlmAgent
+
+# Document analyzer
+analyzer = LlmAgent(
+    name="document_analyzer",
+    model="gemini-2.0-flash",
+    instruction="Analyze document in {current_doc}",
+    output_key="analysis"
+)
+
+# Array processor
+processor = ArrayIteratorAgent(
+    name="doc_processor",
+    array_key="documents",      # Array in session state
+    item_key="current_doc",     # Key for current item
+    output_key="analyses",      # Collect results
+    sub_agents=[analyzer]       # Single sub-agent
+)
+```
+
+**Session State:**
+```json
+{
+  "documents": [
+    {"title": "Doc1", "content": "..."},
+    {"title": "Doc2", "content": "..."}
+  ]
+}
+```
+
+**Result:**
+```json
+{
+  "documents": [/* original docs */],
+  "analyses": ["Analysis of Doc1", "Analysis of Doc2"]
+}
+```
+
+### 2. Nested Data Processing
+
+```python
+# Process nested customer array
+customer_processor = ArrayIteratorAgent(
+    name="customer_processor",
+    array_key="company.customers",    # Nested array access
+    item_key="current_customer",
+    output_key="company.processed",   # Nested output
+    sub_agents=[customer_analyzer]
+)
+```
+
+**Session State:**
+```json
+{
+  "company": {
+    "name": "TechCorp",
+    "customers": [
+      {"name": "Alice", "spend": 12000},
+      {"name": "Bob", "spend": 7500}
+    ]
+  }
+}
+```
+
+### 3. Complex Pipeline Processing
+
+```python
+# Multi-step pipeline as single sub-agent
+pipeline = SequentialAgent(
+    name="processing_pipeline",
+    sub_agents=[extractor, validator, transformer]
+)
+
+# Use pipeline in array iterator
+batch_processor = ArrayIteratorAgent(
+    name="batch_processor",
+    array_key="raw_data",
+    item_key="current_item",
+    output_key="processed_batch",
+    sub_agents=[pipeline]    # Pipeline as single sub-agent
+)
+```
+
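+The `agent.py` sample chains a producer agent with the iterator inside a
+`SequentialAgent`, so the array never has to be hand-populated. A minimal
+sketch of that pattern (the names here are illustrative, reusing `processor`
+from Example 1):
+
+```python
+# Producer writes an array into session state; the iterator then fans out over it.
+finder = LlmAgent(
+    name="document_finder",
+    model="gemini-2.0-flash",
+    instruction="List documents relevant to {user_query} as a JSON array.",
+    output_key="documents"           # Produces the array consumed by `processor`
+)
+
+workflow = SequentialAgent(
+    name="find_then_analyze",
+    sub_agents=[finder, processor]   # Producer runs first, iterator second
+)
+```
+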
+### 4. Without Result Collection
+
+```python
+# Process without collecting results
+notifier = ArrayIteratorAgent(
+    name="notification_sender",
+    array_key="users",
+    item_key="current_user",
+    # No output_key: results are not collected
+    sub_agents=[notification_agent]
+)
+```
+
+## Configuration Options
+
+| Parameter | Type | Required | Description |
+|-----------|------|----------|-------------|
+| `name` | `str` | βœ… | Agent name |
+| `array_key` | `str` | βœ… | Path to the array (supports `dot.notation`) |
+| `item_key` | `str` | Optional | Key for the current item (default: `"current_item"`) |
+| `output_key` | `str` | Optional | Key to store the results array |
+| `sub_agents` | `list[BaseAgent]` | βœ… | **Exactly one sub-agent** |
+
+## Error Handling
+
+### Validation Errors
+```python
+# ❌ No sub-agents
+ArrayIteratorAgent(name="bad", array_key="items", sub_agents=[])
+# ValueError: ArrayIteratorAgent requires exactly one sub-agent
+
+# ❌ Multiple sub-agents
+ArrayIteratorAgent(name="bad", array_key="items", sub_agents=[agent1, agent2])
+# ValueError: ArrayIteratorAgent accepts only one sub-agent, but 2 were provided
+```
+
+### Runtime Errors
+```python
+# ❌ Missing array key
+# ValueError: Array key 'missing_key' not found or invalid
+
+# ❌ Non-array value
+# TypeError: Value at 'not_array' is not a list. Got str
+```
+
+## Best Practices
+
+### βœ… **Do:**
+- Keep the single sub-agent focused on processing one item at a time
+- Leverage nested keys for complex data structures
+- Use `SequentialAgent`/`ParallelAgent` as the sub-agent for multi-step workflows
+- Handle escalation gracefully in sub-agents
+
+### ❌ **Don't:**
+- Pass multiple sub-agents directly (wrap them in a `SequentialAgent` or `ParallelAgent` instead)
+- Assume arrays are always non-empty
+- Forget to populate the array key in session state before the iterator runs
+- Mix iteration logic with per-item processing logic
+
+## Sample Data
+
+The sample's `agent.py` includes `EXAMPLE_SESSION_STATES`, which shows how session state evolves in each workflow:
+
+```python
+EXAMPLE_SESSION_STATES = {
+    # user_query -> discovered_documents -> document_analyses
+    "document_workflow": {...},
+    # company_id -> nested customers -> segmented_customers
+    "customer_workflow": {...},
+    # data_source -> raw_data -> processed_batch
+    "data_processing_workflow": {...}
+}
+```
+
+## Running the Sample
+
+```bash
+cd adk-python/contributing/samples/array_iterator_agent
+python agent.py
+```
+
+This prints the available workflow configurations and a summary of how the `ArrayIteratorAgent` fits into each one.
+
+## Related Agents
+
+- **`LoopAgent`**: Repeats its sub-agents for a fixed number of iterations
+- **`SequentialAgent`**: Runs sub-agents one after another
+- **`ParallelAgent`**: Runs sub-agents concurrently
+
+The `ArrayIteratorAgent` complements these by providing **data-driven iteration** over arrays.
\ No newline at end of file
diff --git a/contributing/samples/array_iterator_agent/__init__.py b/contributing/samples/array_iterator_agent/__init__.py
new file mode 100644
index 000000000..4985f01b6
--- /dev/null
+++ b/contributing/samples/array_iterator_agent/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
\ No newline at end of file diff --git a/contributing/samples/array_iterator_agent/agent.py b/contributing/samples/array_iterator_agent/agent.py new file mode 100644 index 000000000..1a4387513 --- /dev/null +++ b/contributing/samples/array_iterator_agent/agent.py @@ -0,0 +1,472 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Sample demonstrating ArrayIteratorAgent usage patterns with realistic data flow.""" + +from google.adk.agents import LlmAgent, SequentialAgent +from pydantic import BaseModel, Field +from typing import List, Dict, Any + +# Import ArrayIteratorAgent - try package first, fallback to local +try: + from google.adk.agents.array_iterator_agent import ArrayIteratorAgent +except ImportError: + # If package import fails, try local import + import sys + import os + sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../../src')) + from google.adk.agents.array_iterator_agent import ArrayIteratorAgent + + +# === Pydantic Models for Structured Output === + +class DocumentMetadata(BaseModel): + """Structured output for document discovery.""" + title: str = Field(..., description="Document title") + content: str = Field(..., description="Document content") + url: str = Field(..., description="Document URL") + importance: int = Field(..., description="Importance score 1-10") + +class DocumentAnalysis(BaseModel): + """Structured output for document analysis.""" + title: str = Field(..., description="Extracted title") + summary: str = Field(..., description="2-sentence summary") + key_topics: List[str] = Field(..., description="List of key topics") + sentiment: str = Field(..., description="Overall sentiment") + +class CustomerData(BaseModel): + """Structured output for customer data.""" + name: str = Field(..., description="Customer name") + email: str = Field(..., description="Customer email") + annual_spend: float = Field(..., description="Annual spending amount") + tier: str = Field(..., description="Customer tier") + + +# === Agent Workflows with Realistic Data Flow === + +def create_document_discovery_and_processing_workflow(): + """ + REALISTIC WORKFLOW: Document discovery β†’ Processing + + 1. Document Finder Agent discovers documents (produces array) + 2. ArrayIteratorAgent processes each document + """ + + # Step 1: Agent that discovers/fetches documents and produces structured array + document_finder = LlmAgent( + name="document_finder", + model="gemini-2.0-flash", + instruction=""" + Based on the user's query in {user_query}, find and list relevant documents. + + For each document, provide: + - title: Document title + - content: Brief content excerpt + - url: Document URL + - importance: Relevance score 1-10 + + Return as a JSON array of document objects. 
+ """, + output_schema=List[DocumentMetadata], # Produces structured array + output_key="discovered_documents" # Stored in session state + ) + + # Step 2: Document analyzer (processes individual documents) + document_analyzer = LlmAgent( + name="document_analyzer", + model="gemini-2.0-flash", + instruction=""" + Analyze the document provided in {current_document}. + + Extract: + - title: Clean document title + - summary: Exactly 2 sentences summarizing the content + - key_topics: List of 3-5 main topics/keywords + - sentiment: positive/negative/neutral + """, + output_schema=DocumentAnalysis, # Structured output per document + output_key="document_analysis" + ) + + # Step 3: Array iterator processes the discovered documents + document_processor = ArrayIteratorAgent( + name="document_processor", + array_key="discovered_documents", # Array from document_finder + item_key="current_document", # Current doc for analyzer + output_key="document_analyses", # Collected analyses + sub_agents=[document_analyzer] + ) + + # Step 4: Complete workflow + workflow = SequentialAgent( + name="document_workflow", + description="Discovers documents then processes each one", + sub_agents=[ + document_finder, # Produces array in session state + document_processor # Processes the array + ] + ) + + return workflow + + +def create_customer_segmentation_workflow(): + """ + REALISTIC WORKFLOW: Customer data ingestion β†’ Segmentation + + 1. CRM Data Agent fetches customer data (produces nested array) + 2. ArrayIteratorAgent processes each customer for segmentation + """ + + # Step 1: Agent that fetches customer data from CRM/database + crm_data_agent = LlmAgent( + name="crm_data_agent", + model="gemini-2.0-flash", + instruction=""" + Based on the company ID in {company_id}, fetch customer data from CRM. + + Return company info with customer array: + { + "company": { + "name": "Company Name", + "industry": "Industry Type", + "customers": [ + {"name": "Customer Name", "email": "email", "annual_spend": 0000} + ] + } + } + + Include 3-5 customers with realistic data. + """, + output_key="company_data" # Creates nested structure + ) + + # Step 2: Customer segmentation processor + customer_segmenter = LlmAgent( + name="customer_segmenter", + model="gemini-2.0-flash", + instruction=""" + Process the customer data in {current_customer}. + + Analyze and determine: + - Tier level based on annual_spend: + * VIP: > $10,000 annual spend + * Premium: $5,000-$10,000 annual spend + * Standard: < $5,000 annual spend + - Personalized greeting + - Recommended actions + + Return structured customer analysis. + """, + output_schema=CustomerData, # Structured output per customer + output_key="customer_segment" + ) + + # Step 3: Array iterator processes customers from nested path + customer_processor = ArrayIteratorAgent( + name="customer_processor", + array_key="company_data.company.customers", # Nested array access + item_key="current_customer", + output_key="company_data.company.segmented_customers", # Nested output + sub_agents=[customer_segmenter] + ) + + # Step 4: Complete workflow + segmentation_workflow = SequentialAgent( + name="customer_segmentation_workflow", + description="Fetches customer data then segments each customer", + sub_agents=[ + crm_data_agent, # Produces nested data structure + customer_processor # Processes nested customer array + ] + ) + + return segmentation_workflow + + +def create_data_ingestion_and_processing_workflow(): + """ + REALISTIC WORKFLOW: Data ingestion β†’ ETL processing + + 1. 
Data Collector Agent fetches raw data (produces array) + 2. ArrayIteratorAgent processes each record through ETL pipeline + """ + + # Step 1: Data collector that fetches raw events/records + data_collector = LlmAgent( + name="data_collector", + model="gemini-2.0-flash", + instruction=""" + Based on the data source in {data_source}, collect raw data records. + + Return array of raw data records: + [ + {"source": "API", "data": "raw_event_data", "timestamp": "2024-01-01"}, + {"source": "DB", "data": "database_record", "timestamp": "2024-01-02"} + ] + + Include 5-10 records with various sources and realistic data. + """, + output_key="raw_data" # Produces array for processing + ) + + # ETL Pipeline Steps + + # Step 2a: Extract data + extractor = LlmAgent( + name="data_extractor", + model="gemini-2.0-flash", + instruction=""" + Extract structured data from {current_item}. + Parse the raw data and extract key fields like ID, type, value, etc. + """, + output_key="extracted_data" + ) + + # Step 2b: Validate data + validator = LlmAgent( + name="data_validator", + model="gemini-2.0-flash", + instruction=""" + Validate the extracted data in {extracted_data}. + Check for completeness, format, and business rules. + Return validation status and cleaned data. + """, + output_key="validation_result" + ) + + # Step 2c: Transform data + transformer = LlmAgent( + name="data_transformer", + model="gemini-2.0-flash", + instruction=""" + Transform validated data {validation_result} into final format. + Apply business logic, standardize formats, enrich with metadata. + """, + output_key="transformed_data" + ) + + # Step 3: Sequential ETL pipeline + etl_pipeline = SequentialAgent( + name="etl_pipeline", + description="Extract β†’ Transform β†’ Load pipeline for single record", + sub_agents=[extractor, validator, transformer] + ) + + # Step 4: Array iterator applies ETL pipeline to each raw record + batch_processor = ArrayIteratorAgent( + name="batch_processor", + array_key="raw_data", # Array from data_collector + item_key="current_item", # Current record for ETL + output_key="processed_batch", # Final processed results + sub_agents=[etl_pipeline] # ETL pipeline as single sub-agent + ) + + # Step 5: Complete data processing workflow + data_workflow = SequentialAgent( + name="data_processing_workflow", + description="Collects raw data then processes each record through ETL", + sub_agents=[ + data_collector, # Produces raw data array + batch_processor # Processes each record through ETL + ] + ) + + return data_workflow + + +def create_quality_assurance_workflow(): + """ + REALISTIC WORKFLOW: Content generation β†’ Quality check + + 1. Content Generator creates articles (produces array) + 2. ArrayIteratorAgent runs QA checks on each article + 3. Handles escalation when quality issues are found + """ + + # Step 1: Content generator that creates articles + content_generator = LlmAgent( + name="content_generator", + model="gemini-2.0-flash", + instruction=""" + Based on the topic list in {topics}, generate article drafts. + + Return array of article objects: + [ + {"title": "Article Title", "content": "Article content...", "status": "draft"}, + {"title": "Another Title", "content": "More content...", "status": "draft"} + ] + + Include 4-5 articles. Occasionally include problematic content with + "PLAGIARISM" or "LOW_QUALITY" markers to test QA escalation. 
+ """, + output_key="generated_articles" # Produces article array + ) + + # Step 2: Quality assurance checker + qa_checker = LlmAgent( + name="qa_checker", + model="gemini-2.0-flash", + instruction=""" + Review the article in {current_article} for quality issues. + + Check for: + - Plagiarism indicators (if content contains "PLAGIARISM") + - Quality issues (if content contains "LOW_QUALITY") + - Grammar and coherence + + If serious issues found (PLAGIARISM/LOW_QUALITY), escalate to stop processing. + Otherwise, return quality score and recommendations. + """, + output_key="qa_result" + ) + + # Step 3: Array iterator with escalation handling + quality_processor = ArrayIteratorAgent( + name="quality_processor", + array_key="generated_articles", # Array from content_generator + item_key="current_article", # Current article for QA + output_key="qa_results", # QA results (until escalation) + sub_agents=[qa_checker] + ) + + # Step 4: Complete QA workflow + qa_workflow = SequentialAgent( + name="content_qa_workflow", + description="Generates content then runs QA checks with escalation", + sub_agents=[ + content_generator, # Produces article array + quality_processor # QA checks each article (stops on issues) + ] + ) + + return qa_workflow + + +# === Session State Examples === +# In real scenarios, this data comes from previous agents, not hardcoded constants! + +def create_standalone_iterator_example(): + """ + EXAMPLE: Using ArrayIteratorAgent with pre-populated session state + (For testing when you already have array data) + """ + + # Simple processor for when data is already in session state + simple_processor = LlmAgent( + name="simple_processor", + model="gemini-2.0-flash", + instruction="Process the item in {current_item} and return a summary", + output_key="item_summary" + ) + + # Array iterator (assumes data already in session state) + standalone_iterator = ArrayIteratorAgent( + name="standalone_iterator", + array_key="existing_data", # Must exist in session state + item_key="current_item", + output_key="processing_results", + sub_agents=[simple_processor] + ) + + return standalone_iterator + + +# === How Session State Gets Populated === + +EXAMPLE_SESSION_STATES = { + "document_workflow": { + # Initial state - user provides query + "user_query": "Find articles about AI in healthcare", + # After document_finder runs: + "discovered_documents": [ + {"title": "AI in Medical Diagnosis", "content": "AI systems are...", "url": "https://...", "importance": 9}, + {"title": "ML for Drug Discovery", "content": "Machine learning...", "url": "https://...", "importance": 8} + ], + # After ArrayIteratorAgent runs: + "document_analyses": [ + {"title": "AI in Medical Diagnosis", "summary": "...", "key_topics": ["AI", "medical"], "sentiment": "positive"}, + {"title": "ML for Drug Discovery", "summary": "...", "key_topics": ["ML", "pharma"], "sentiment": "positive"} + ] + }, + + "customer_workflow": { + # Initial state + "company_id": "CORP-123", + # After crm_data_agent runs: + "company_data": { + "company": { + "name": "TechCorp Inc", + "industry": "Software", + "customers": [ + {"name": "Alice Johnson", "email": "alice@example.com", "annual_spend": 15000}, + {"name": "Bob Smith", "email": "bob@example.com", "annual_spend": 7500} + ], + # After ArrayIteratorAgent runs: + "segmented_customers": [ + {"name": "Alice Johnson", "email": "alice@example.com", "annual_spend": 15000, "tier": "VIP"}, + {"name": "Bob Smith", "email": "bob@example.com", "annual_spend": 7500, "tier": "Premium"} + ] + } + } + }, + + 
"data_processing_workflow": { + # Initial state + "data_source": "production_logs", + # After data_collector runs: + "raw_data": [ + {"source": "API", "data": "user_login_event", "timestamp": "2024-01-01T10:00:00Z"}, + {"source": "DB", "data": "order_created", "timestamp": "2024-01-01T10:05:00Z"} + ], + # After ArrayIteratorAgent runs: + "processed_batch": [ + {"event_type": "login", "user_id": "123", "processed_at": "2024-01-01T10:00:00Z", "status": "valid"}, + {"event_type": "order", "order_id": "456", "processed_at": "2024-01-01T10:05:00Z", "status": "valid"} + ] + } +} + + +if __name__ == "__main__": + print("πŸ”„ ArrayIteratorAgent: Realistic Workflow Examples") + print("=" * 60) + + print("\nπŸ“‹ Available Workflows:") + print("1. Document Discovery & Processing:", create_document_discovery_and_processing_workflow().name) + print("2. Customer Segmentation:", create_customer_segmentation_workflow().name) + print("3. Data Ingestion & ETL:", create_data_ingestion_and_processing_workflow().name) + print("4. Content QA Pipeline:", create_quality_assurance_workflow().name) + print("5. Standalone Iterator:", create_standalone_iterator_example().name) + + print("\nπŸ”„ How ArrayIteratorAgent Works:") + print("β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”") + print("β”‚ Agent A │───▢│ Session State │───▢│ ArrayIterator β”‚") + print("β”‚ (Produces Array)β”‚ β”‚ {array_key: [...]}β”‚ β”‚ (Processes Each)β”‚") + print("β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜") + + print("\nπŸ“Š Session State Evolution Example:") + print("Initial: {user_query: 'Find AI articles'}") + print("After Agent A: {user_query: '...', discovered_docs: [doc1, doc2, doc3]}") + print("After Iterator: {user_query: '...', discovered_docs: [...], analyses: [analysis1, analysis2, analysis3]}") + + print("\nπŸš€ Key Benefits:") + print("βœ… Structured data flow between agents") + print("βœ… Automatic result collection") + print("βœ… Nested key support for complex data") + print("βœ… Escalation handling for quality control") + print("βœ… Reusable iteration patterns") + + print("\nπŸ’‘ Usage: Combine agents in SequentialAgent for complete workflows") + print(" Example: DataCollector β†’ ArrayIteratorAgent β†’ ResultProcessor") \ No newline at end of file diff --git a/src/google/adk/agents/__init__.py b/src/google/adk/agents/__init__.py index e1f773c47..ff11f5d9a 100644 --- a/src/google/adk/agents/__init__.py +++ b/src/google/adk/agents/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .array_iterator_agent import ArrayIteratorAgent from .base_agent import BaseAgent from .live_request_queue import LiveRequest from .live_request_queue import LiveRequestQueue @@ -24,6 +25,7 @@ __all__ = [ 'Agent', + 'ArrayIteratorAgent', 'BaseAgent', 'LlmAgent', 'LoopAgent', diff --git a/src/google/adk/agents/array_iterator_agent.py b/src/google/adk/agents/array_iterator_agent.py new file mode 100644 index 000000000..91f47b5f2 --- /dev/null +++ b/src/google/adk/agents/array_iterator_agent.py @@ -0,0 +1,235 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Array iterator agent implementation.""" + +from __future__ import annotations + +import logging +from typing import Any, AsyncGenerator, Optional + +from pydantic import ConfigDict, Field, model_validator +from typing_extensions import override + +from ..agents.invocation_context import InvocationContext +from ..events.event import Event +from .base_agent import BaseAgent + +logger = logging.getLogger(__name__) + + +def _get_nested_value(data: dict[str, Any], key_path: str) -> Any: + """Get value from nested dictionary using dot notation. + + Args: + data: The dictionary to search in. + key_path: The key path using dot notation (e.g., 'user.profile.name'). + + Returns: + The value at the specified path. + + Raises: + KeyError: If the key path is not found. + TypeError: If trying to access a key on a non-dict value. + """ + if not key_path: + raise KeyError("Key path cannot be empty") + + keys = key_path.split('.') + current = data + + for i, key in enumerate(keys): + if not isinstance(current, dict): + path_so_far = '.'.join(keys[:i]) + raise TypeError( + f"Cannot access key '{key}' on non-dict value at path '{path_so_far}'" + ) + + if key not in current: + path_so_far = '.'.join(keys[:i+1]) + raise KeyError(f"Key path '{path_so_far}' not found") + + current = current[key] + + return current + + +def _set_nested_value(data: dict[str, Any], key_path: str, value: Any) -> None: + """Set value in nested dictionary using dot notation. + + Args: + data: The dictionary to modify. + key_path: The key path using dot notation (e.g., 'user.profile.name'). + value: The value to set. + + Raises: + ValueError: If the key path is invalid. + TypeError: If trying to set a key on a non-dict value. + """ + if not key_path: + raise ValueError("Key path cannot be empty") + + keys = key_path.split('.') + current = data + + # Navigate to the parent of the final key + for i, key in enumerate(keys[:-1]): + if key not in current: + current[key] = {} + elif not isinstance(current[key], dict): + path_so_far = '.'.join(keys[:i+1]) + raise TypeError( + f"Cannot set nested key on non-dict value at path '{path_so_far}'" + ) + current = current[key] + + # Set the final value + current[keys[-1]] = value + + +class ArrayIteratorAgent(BaseAgent): + """Agent that iterates over an array and applies a single sub-agent to each item. + + This agent focuses solely on iteration - it takes an array from session state, + applies one sub-agent to each item, and optionally collects the results. 
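+
+  Both `array_key` and `output_key` accept dot notation (for example,
+  `"company.customers"`) for reading from and writing to nested session state.
+  The `item_key` entry is temporary and is restored to its previous value (or
+  removed) once iteration completes.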
+ + Example: + ```python + processor = ArrayIteratorAgent( + name="document_processor", + array_key="documents", # Can be nested: "user.documents" + item_key="current_doc", + output_key="processed_results", + sub_agents=[document_analyzer] # Exactly one sub-agent + ) + ``` + """ + + model_config = ConfigDict(extra="forbid", exclude_none=True) + + array_key: str = Field(..., description="Path to array in session state (supports dot notation)") + item_key: str = Field(default="current_item", description="Key to store current item in session state") + output_key: Optional[str] = Field(default=None, description="Key to store collected results array") + + @model_validator(mode='after') + def _validate_single_sub_agent(self) -> 'ArrayIteratorAgent': + """Validate that exactly one sub-agent is provided.""" + if len(self.sub_agents) == 0: + raise ValueError("ArrayIteratorAgent requires exactly one sub-agent") + + if len(self.sub_agents) > 1: + raise ValueError( + f"ArrayIteratorAgent accepts only one sub-agent, but {len(self.sub_agents)} were provided. " + f"If you need multiple agents, use SequentialAgent or ParallelAgent as the single sub-agent." + ) + + return self + + @override + async def _run_async_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + """Execute the array iteration logic.""" + + # Get the array from session state + try: + state_dict = ctx.session.state.to_dict() + array_data = _get_nested_value(state_dict, self.array_key) + except (KeyError, TypeError) as e: + logger.error(f"Failed to get array from key '{self.array_key}': {e}") + raise ValueError(f"Array key '{self.array_key}' not found or invalid: {e}") + + # Validate that we have an array + if not isinstance(array_data, list): + raise TypeError( + f"Value at '{self.array_key}' is not a list. 
Got {type(array_data).__name__}" + ) + + if not array_data: + logger.info(f"Array at '{self.array_key}' is empty, skipping iteration") + return + + logger.info(f"Starting iteration over {len(array_data)} items from '{self.array_key}'") + + # Collect results if output_key is specified + results = [] if self.output_key else None + sub_agent = self.sub_agents[0] + + # Store original item_key value to restore later + original_item_value = ctx.session.state.get(self.item_key) + + try: + # Iterate over each item in the array + for i, item in enumerate(array_data): + logger.debug(f"Processing item {i+1}/{len(array_data)}") + + # Inject current item into session state + ctx.session.state[self.item_key] = item + + # Execute sub-agent for this item + item_results = [] + async for event in sub_agent.run_async(ctx): + yield event + item_results.append(event) + + # Collect result if output_key is specified + if self.output_key: + # Get the last event's content as the result + if item_results: + last_event = item_results[-1] + if hasattr(last_event, 'content') and last_event.content: + results.append(last_event.content) + else: + results.append(None) + else: + results.append(None) + + # Check for escalation + if item_results and any(event.actions.escalate for event in item_results): + logger.info(f"Sub-agent escalated on item {i+1}, stopping iteration") + break + + finally: + # Restore original item_key value + if original_item_value is not None: + ctx.session.state[self.item_key] = original_item_value + else: + # Remove the item key if it wasn't there originally + if self.item_key in ctx.session.state: + del ctx.session.state[self.item_key] + + # Store results if output_key is specified + if self.output_key and results is not None: + try: + # For simple keys, use direct assignment + if '.' not in self.output_key: + ctx.session.state[self.output_key] = results + else: + # For nested keys, we need to work with the state dict + state_dict = ctx.session.state.to_dict() + _set_nested_value(state_dict, self.output_key, results) + # Update the session state with the modified dict + ctx.session.state.update(state_dict) + logger.info(f"Stored {len(results)} results in '{self.output_key}'") + except (ValueError, TypeError) as e: + logger.error(f"Failed to store results in '{self.output_key}': {e}") + raise ValueError(f"Cannot store results in output key '{self.output_key}': {e}") + + @override + async def _run_live_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + """Live implementation for ArrayIteratorAgent.""" + raise NotImplementedError('Live mode is not supported for ArrayIteratorAgent yet.') + yield # AsyncGenerator requires having at least one yield statement \ No newline at end of file diff --git a/tests/unittests/agents/test_array_iterator_agent.py b/tests/unittests/agents/test_array_iterator_agent.py new file mode 100644 index 000000000..e37b1d4c4 --- /dev/null +++ b/tests/unittests/agents/test_array_iterator_agent.py @@ -0,0 +1,397 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for ArrayIteratorAgent.""" + +import pytest +from unittest.mock import AsyncMock, MagicMock +from typing import AsyncGenerator + +from google.adk.agents.array_iterator_agent import ArrayIteratorAgent, _get_nested_value, _set_nested_value +from google.adk.agents.base_agent import BaseAgent +from google.adk.events.event import Event +from google.adk.events.event_actions import EventActions +from google.adk.sessions.state import State + + +class MockAgent(BaseAgent): + """Mock agent for testing.""" + + def __init__(self, name: str): + super().__init__(name=name) + self.call_count = 0 + self.yielded_events = [] + + async def _run_async_impl(self, ctx) -> AsyncGenerator[Event, None]: + self.call_count += 1 + # Create a mock event + event = MagicMock(spec=Event) + event.content = f"processed_{ctx.session.state.get('current_item', 'unknown')}" + event.actions = MagicMock(spec=EventActions) + event.actions.escalate = False + self.yielded_events.append(event) + yield event + + async def _run_live_impl(self, ctx) -> AsyncGenerator[Event, None]: + yield # AsyncGenerator requires at least one yield + + +class MockEscalatingAgent(BaseAgent): + """Mock agent that escalates after processing first item.""" + + def __init__(self, name: str): + super().__init__(name=name) + self.call_count = 0 + + async def _run_async_impl(self, ctx) -> AsyncGenerator[Event, None]: + self.call_count += 1 + event = MagicMock(spec=Event) + event.content = f"processed_{ctx.session.state.get('current_item', 'unknown')}" + event.actions = MagicMock(spec=EventActions) + # Escalate if this is the second call + event.actions.escalate = self.call_count >= 2 + yield event + + async def _run_live_impl(self, ctx) -> AsyncGenerator[Event, None]: + yield # AsyncGenerator requires at least one yield + + +class TestNestedKeyUtils: + """Test the nested key utility functions.""" + + def test_get_nested_value_simple(self): + data = {"key": "value"} + assert _get_nested_value(data, "key") == "value" + + def test_get_nested_value_nested(self): + data = {"user": {"profile": {"name": "John"}}} + assert _get_nested_value(data, "user.profile.name") == "John" + + def test_get_nested_value_missing_key(self): + data = {"key": "value"} + with pytest.raises(KeyError, match="Key path 'missing' not found"): + _get_nested_value(data, "missing") + + def test_get_nested_value_missing_nested_key(self): + data = {"user": {"profile": {}}} + with pytest.raises(KeyError, match="Key path 'user.profile.name' not found"): + _get_nested_value(data, "user.profile.name") + + def test_get_nested_value_non_dict(self): + data = {"user": "not_a_dict"} + with pytest.raises(TypeError, match="Cannot access key 'profile' on non-dict value"): + _get_nested_value(data, "user.profile.name") + + def test_get_nested_value_empty_key(self): + data = {"key": "value"} + with pytest.raises(KeyError, match="Key path cannot be empty"): + _get_nested_value(data, "") + + def test_set_nested_value_simple(self): + data = {} + _set_nested_value(data, "key", "value") + assert data == {"key": "value"} + + def test_set_nested_value_nested(self): + data = {} + _set_nested_value(data, "user.profile.name", "John") + assert data == {"user": {"profile": {"name": "John"}}} + + def test_set_nested_value_existing_path(self): + data = {"user": {"profile": {"age": 30}}} + _set_nested_value(data, "user.profile.name", "John") + assert data == {"user": {"profile": {"age": 30, "name": "John"}}} 
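+
+  def test_set_nested_value_overwrite_existing(self):
+    # Supplementary check (not part of the original suite): assigning to an
+    # existing leaf path should overwrite the previous value in place.
+    data = {"user": {"profile": {"name": "Old"}}}
+    _set_nested_value(data, "user.profile.name", "New")
+    assert data == {"user": {"profile": {"name": "New"}}}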
+ + def test_set_nested_value_non_dict_conflict(self): + data = {"user": "not_a_dict"} + with pytest.raises(TypeError, match="Cannot set nested key on non-dict value"): + _set_nested_value(data, "user.profile.name", "John") + + def test_set_nested_value_empty_key(self): + data = {} + with pytest.raises(ValueError, match="Key path cannot be empty"): + _set_nested_value(data, "", "value") + + +class TestArrayIteratorAgent: + """Test the ArrayIteratorAgent class.""" + + def test_init_valid_single_agent(self): + """Test initialization with a single sub-agent.""" + sub_agent = MockAgent("sub_agent") + agent = ArrayIteratorAgent( + name="iterator", + array_key="items", + sub_agents=[sub_agent] + ) + assert agent.name == "iterator" + assert agent.array_key == "items" + assert agent.item_key == "current_item" # default + assert agent.output_key is None # default + assert len(agent.sub_agents) == 1 + + def test_init_no_sub_agents(self): + """Test initialization fails with no sub-agents.""" + with pytest.raises(ValueError, match="ArrayIteratorAgent requires exactly one sub-agent"): + ArrayIteratorAgent( + name="iterator", + array_key="items", + sub_agents=[] + ) + + def test_init_multiple_sub_agents(self): + """Test initialization fails with multiple sub-agents.""" + sub_agent1 = MockAgent("sub_agent1") + sub_agent2 = MockAgent("sub_agent2") + + with pytest.raises(ValueError, match="ArrayIteratorAgent accepts only one sub-agent, but 2 were provided"): + ArrayIteratorAgent( + name="iterator", + array_key="items", + sub_agents=[sub_agent1, sub_agent2] + ) + + def test_custom_configuration(self): + """Test custom item_key and output_key configuration.""" + sub_agent = MockAgent("sub_agent") + agent = ArrayIteratorAgent( + name="iterator", + array_key="data.items", + item_key="current_data", + output_key="results.processed", + sub_agents=[sub_agent] + ) + assert agent.array_key == "data.items" + assert agent.item_key == "current_data" + assert agent.output_key == "results.processed" + + @pytest.mark.asyncio + async def test_run_async_simple_array(self): + """Test basic array iteration.""" + sub_agent = MockAgent("sub_agent") + agent = ArrayIteratorAgent( + name="iterator", + array_key="items", + item_key="current_item", + output_key="results", + sub_agents=[sub_agent] + ) + + # Mock context + ctx = MagicMock() + state = State( + value={"items": ["item1", "item2", "item3"]}, + delta={} + ) + ctx.session.state = state + + # Run the agent + events = [] + async for event in agent._run_async_impl(ctx): + events.append(event) + + # Verify results + assert len(events) == 3 # One event per item + assert sub_agent.call_count == 3 + assert state["results"] == ["processed_item1", "processed_item2", "processed_item3"] + + @pytest.mark.asyncio + async def test_run_async_nested_array(self): + """Test nested array access.""" + sub_agent = MockAgent("sub_agent") + agent = ArrayIteratorAgent( + name="iterator", + array_key="data.items", + item_key="current_item", + sub_agents=[sub_agent] + ) + + # Mock context with nested data + ctx = MagicMock() + state = State( + value={"data": {"items": ["nested1", "nested2"]}}, + delta={} + ) + ctx.session.state = state + + # Run the agent + events = [] + async for event in agent._run_async_impl(ctx): + events.append(event) + + # Verify results + assert len(events) == 2 + assert sub_agent.call_count == 2 + + @pytest.mark.asyncio + async def test_run_async_missing_array_key(self): + """Test handling of missing array key.""" + sub_agent = MockAgent("sub_agent") + agent = 
ArrayIteratorAgent( + name="iterator", + array_key="missing_key", + sub_agents=[sub_agent] + ) + + # Mock context + ctx = MagicMock() + state = State(value={}, delta={}) + ctx.session.state = state + + # Should raise ValueError + with pytest.raises(ValueError, match="Array key 'missing_key' not found or invalid"): + async for event in agent._run_async_impl(ctx): + pass + + @pytest.mark.asyncio + async def test_run_async_non_array_value(self): + """Test handling of non-array value.""" + sub_agent = MockAgent("sub_agent") + agent = ArrayIteratorAgent( + name="iterator", + array_key="not_array", + sub_agents=[sub_agent] + ) + + # Mock context + ctx = MagicMock() + state = State(value={"not_array": "string_value"}, delta={}) + ctx.session.state = state + + # Should raise TypeError + with pytest.raises(TypeError, match="Value at 'not_array' is not a list"): + async for event in agent._run_async_impl(ctx): + pass + + @pytest.mark.asyncio + async def test_run_async_empty_array(self): + """Test handling of empty array.""" + sub_agent = MockAgent("sub_agent") + agent = ArrayIteratorAgent( + name="iterator", + array_key="items", + sub_agents=[sub_agent] + ) + + # Mock context + ctx = MagicMock() + state = State(value={"items": []}, delta={}) + ctx.session.state = state + + # Run the agent + events = [] + async for event in agent._run_async_impl(ctx): + events.append(event) + + # Should process no items + assert len(events) == 0 + assert sub_agent.call_count == 0 + + @pytest.mark.asyncio + async def test_run_async_escalation_handling(self): + """Test that escalation stops iteration.""" + sub_agent = MockEscalatingAgent("sub_agent") + agent = ArrayIteratorAgent( + name="iterator", + array_key="items", + sub_agents=[sub_agent] + ) + + # Mock context + ctx = MagicMock() + state = State( + value={"items": ["item1", "item2", "item3", "item4"]}, + delta={} + ) + ctx.session.state = state + + # Run the agent + events = [] + async for event in agent._run_async_impl(ctx): + events.append(event) + + # Should stop after second item due to escalation + assert len(events) == 2 + assert sub_agent.call_count == 2 + + @pytest.mark.asyncio + async def test_run_async_state_restoration(self): + """Test that item_key is properly restored.""" + sub_agent = MockAgent("sub_agent") + agent = ArrayIteratorAgent( + name="iterator", + array_key="items", + item_key="test_key", + sub_agents=[sub_agent] + ) + + # Mock context with existing value for item_key + ctx = MagicMock() + state = State( + value={"items": ["item1"], "test_key": "original_value"}, + delta={} + ) + ctx.session.state = state + + # Run the agent + events = [] + async for event in agent._run_async_impl(ctx): + events.append(event) + + # Original value should be restored + assert state["test_key"] == "original_value" + + @pytest.mark.asyncio + async def test_run_async_no_output_key(self): + """Test iteration without collecting results.""" + sub_agent = MockAgent("sub_agent") + agent = ArrayIteratorAgent( + name="iterator", + array_key="items", + # No output_key specified + sub_agents=[sub_agent] + ) + + # Mock context + ctx = MagicMock() + state = State( + value={"items": ["item1", "item2"]}, + delta={} + ) + ctx.session.state = state + + # Run the agent + events = [] + async for event in agent._run_async_impl(ctx): + events.append(event) + + # No results should be stored + assert len(events) == 2 + assert "results" not in state + + @pytest.mark.asyncio + async def test_run_live_impl_not_supported(self): + """Test that live mode raises NotImplementedError.""" + 
sub_agent = MockAgent("sub_agent") + agent = ArrayIteratorAgent( + name="iterator", + array_key="items", + sub_agents=[sub_agent] + ) + + ctx = MagicMock() + + with pytest.raises(NotImplementedError, match="Live mode is not supported"): + async for event in agent._run_live_impl(ctx): + pass \ No newline at end of file