Add logger interface support #857

Merged · 10 commits · Mar 16, 2025
Changes from 7 commits
20 changes: 17 additions & 3 deletions src/codegen/agents/code_agent.py
@@ -8,6 +8,8 @@
from langgraph.graph.graph import CompiledGraph
from langsmith import Client

from codegen.agents.loggers import ExternalLogger
from codegen.agents.tracer import MessageStreamTracer
from codegen.extensions.langchain.agent import create_codebase_agent
from codegen.extensions.langchain.utils.get_langsmith_url import (
find_and_print_langsmith_run_url,
@@ -30,6 +32,7 @@ class CodeAgent:
run_id: str | None = None
instance_id: str | None = None
difficulty: int | None = None
logger: Optional[ExternalLogger] = None

def __init__(
self,
@@ -42,6 +45,7 @@ def __init__(
metadata: Optional[dict] = {},
agent_config: Optional[AgentConfig] = None,
thread_id: Optional[str] = None,
logger: Optional[ExternalLogger] = None,
**kwargs,
):
"""Initialize a CodeAgent.
@@ -92,6 +96,9 @@ def __init__(
# Initialize tags for agent trace
self.tags = [*tags, self.model_name]

# set logger if provided
self.logger = logger

# Initialize metadata for agent trace
self.metadata = {
"project": self.project_name,
@@ -123,19 +130,26 @@ def run(self, prompt: str) -> str:

config = RunnableConfig(configurable={"thread_id": self.thread_id}, tags=self.tags, metadata=self.metadata, recursion_limit=200)
# we stream the steps instead of invoke because it allows us to access intermediate nodes

stream = self.agent.stream(input, config=config, stream_mode="values")

_tracer = MessageStreamTracer(logger=self.logger)

# Process the stream with the tracer
traced_stream = _tracer.process_stream(stream)

# Keep track of run IDs from the stream
run_ids = []

for s in stream:
for s in traced_stream:
if len(s["messages"]) == 0 or isinstance(s["messages"][-1], HumanMessage):
message = HumanMessage(content=prompt)
else:
message = s["messages"][-1]

if isinstance(message, tuple):
print(message)
# print(message)
pass
else:
if isinstance(message, AIMessage) and isinstance(message.content, list) and len(message.content) > 0 and "text" in message.content[0]:
AIMessage(message.content[0]["text"]).pretty_print()
@@ -149,7 +163,7 @@ def run(self, prompt: str) -> str:
# Get the last message content
result = s["final_answer"]

# Try to find run IDs in the LangSmith client's recent runs
# # Try to find run IDs in the LangSmith client's recent runs
try:
# Find and print the LangSmith run URL
find_and_print_langsmith_run_url(self.langsmith_client, self.project_name)
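
Note: taken together, these changes thread an optional ExternalLogger from the constructor through run(): each streamed step is wrapped by MessageStreamTracer, converted into a structured message, and forwarded to the logger. A minimal sketch of the wiring, assuming the codebase is passed positionally as in the scratch notebook below (the PrintLogger class is illustrative and not part of this diff):

from codegen.agents.code_agent import CodeAgent
from codegen.agents.data import AgentRunMessage
from codegen.sdk.core.codebase import Codebase


class PrintLogger:
    # Illustrative logger; satisfies the ExternalLogger protocol via structural typing.
    def log(self, data: AgentRunMessage) -> None:
        print(f"[{data.type}] {data.content[:120]}")


codebase = Codebase.from_repo("codegen-sh/Kevin-s-Adventure-Game")  # repo used in the scratch notebook
agent = CodeAgent(codebase, logger=PrintLogger())
agent.run("What is the main character's name?")
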
71 changes: 71 additions & 0 deletions src/codegen/agents/data.py
@@ -0,0 +1,71 @@
from dataclasses import dataclass, field
from datetime import datetime
from typing import Literal, Optional, Union


# Base dataclass for all message types
@dataclass
class BaseMessage:
"""Base class for all message types."""

type: str
timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
content: str = ""


@dataclass
class UserMessage(BaseMessage):
"""Represents a message from the user."""

type: Literal["user"] = field(default="user")


@dataclass
class SystemMessageData(BaseMessage):
"""Represents a system message."""

type: Literal["system"] = field(default="system")


@dataclass
class ToolCall:
"""Represents a tool call within an assistant message."""

name: Optional[str] = None
arguments: Optional[str] = None
id: Optional[str] = None


@dataclass
class AssistantMessage(BaseMessage):
"""Represents a message from the assistant."""

type: Literal["assistant"] = field(default="assistant")
tool_calls: list[ToolCall] = field(default_factory=list)


@dataclass
class ToolMessageData(BaseMessage):
"""Represents a tool response message."""

type: Literal["tool"] = field(default="tool")
tool_name: Optional[str] = None
tool_response: Optional[str] = None
tool_id: Optional[str] = None


@dataclass
class FunctionMessageData(BaseMessage):
"""Represents a function message."""

type: Literal["function"] = field(default="function")


@dataclass
class UnknownMessage(BaseMessage):
"""Represents an unknown message type."""

type: Literal["unknown"] = field(default="unknown")


type AgentRunMessage = Union[UserMessage, SystemMessageData, AssistantMessage, ToolMessageData, FunctionMessageData, UnknownMessage]
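
For illustration (not part of the diff), the dataclasses above compose directly; an assistant message carrying one tool call might look like this, with type and timestamp filled in by their defaults. The tool name and arguments here are made up for the example:

from codegen.agents.data import AssistantMessage, ToolCall

msg = AssistantMessage(
    content="Searching the codebase for the main character.",
    tool_calls=[ToolCall(name="search_files", arguments='{"query": "main character"}', id="call_1")],
)
print(msg.type, msg.timestamp, len(msg.tool_calls))  # assistant <ISO timestamp> 1
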
16 changes: 16 additions & 0 deletions src/codegen/agents/loggers.py
@@ -0,0 +1,16 @@
from typing import Protocol

from .data import AgentRunMessage


# Define the interface for ExternalLogger
class ExternalLogger(Protocol):
"""Protocol defining the interface for external loggers."""

def log(self, data: AgentRunMessage) -> None:
"""Log structured data to an external system.

Args:
data: The structured message to log (an AgentRunMessage dataclass instance)
"""
pass
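
Because ExternalLogger is a typing.Protocol, any object exposing a matching log method satisfies it through structural typing; implementations do not need to inherit from it. A hedged sketch of a file-backed logger (the file path is an arbitrary choice for the example):

import json
from dataclasses import asdict

from codegen.agents.data import AgentRunMessage


class JsonlLogger:
    # Appends each structured message as one JSON line; satisfies ExternalLogger structurally.
    def __init__(self, path: str = "agent_run.jsonl") -> None:
        self.path = path

    def log(self, data: AgentRunMessage) -> None:
        with open(self.path, "a") as f:
            f.write(json.dumps(asdict(data)) + "\n")

An instance could then be passed as CodeAgent(codebase, logger=JsonlLogger()).
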
91 changes: 91 additions & 0 deletions src/codegen/agents/scratch.ipynb
@@ -0,0 +1,91 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from codegen.agents.code_agent import CodeAgent\n",
"\n",
"\n",
"CodeAgent"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from codegen.sdk.core.codebase import Codebase\n",
"\n",
"\n",
"codebase = Codebase.from_repo(\"codegen-sh/Kevin-s-Adventure-Game\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from typing import Any, Dict, Union\n",
"from codegen.agents.data import BaseMessage\n",
"from codegen.agents.loggers import ExternalLogger\n",
"\n",
"\n",
"class ConsoleLogger(ExternalLogger):\n",
" def log(self, data: Union[Dict[str, Any], BaseMessage]) -> None:\n",
" print(data.content)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"agent = CodeAgent(codebase)\n",
"agent.run(\"What is the main character's name? also show the source code where you find the answer\", logger=ConsoleLogger())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"agent.run(\"What is the main character's name?\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
136 changes: 136 additions & 0 deletions src/codegen/agents/tracer.py
@@ -0,0 +1,136 @@
from collections.abc import Generator
from typing import Any, Optional

from langchain.schema import AIMessage, HumanMessage
from langchain.schema import FunctionMessage as LCFunctionMessage
from langchain.schema import SystemMessage as LCSystemMessage
from langchain_core.messages import ToolMessage as LCToolMessage

from .data import AssistantMessage, BaseMessage, FunctionMessageData, SystemMessageData, ToolCall, ToolMessageData, UnknownMessage, UserMessage
from .loggers import ExternalLogger


class MessageStreamTracer:
def __init__(self, logger: Optional[ExternalLogger] = None):
self.traces = []
self.logger = logger

def process_stream(self, message_stream: Generator) -> Generator:
"""Process the stream of messages from the LangGraph agent,
extract structured data, and pass through the messages.
"""
for chunk in message_stream:
# Process the chunk
structured_data = self.extract_structured_data(chunk)

# Log the structured data
if structured_data:
self.traces.append(structured_data)

# If there's an external logger, send the data there
if self.logger:
self.logger.log(structured_data)

# Pass through the chunk to maintain the original stream behavior
yield chunk

def extract_structured_data(self, chunk: dict[str, Any]) -> Optional[BaseMessage]:
"""Extract structured data from a message chunk.
Returns None if the chunk doesn't contain useful information.
Returns a BaseMessage subclass instance based on the message type.
"""
# Get the messages from the chunk if available
messages = chunk.get("messages", [])
if not messages and isinstance(chunk, dict):
# Sometimes the message might be in a different format
for key, value in chunk.items():
if isinstance(value, list) and all(hasattr(item, "type") for item in value if hasattr(item, "__dict__")):
messages = value
break

if not messages:
return None

# Get the latest message
latest_message = messages[-1] if messages else None

if not latest_message:
return None

# Determine message type
message_type = self._get_message_type(latest_message)
content = self._get_message_content(latest_message)

# Create the appropriate message type
if message_type == "user":
return UserMessage(type=message_type, content=content)
elif message_type == "system":
return SystemMessageData(type=message_type, content=content)
elif message_type == "assistant":
tool_calls_data = self._extract_tool_calls(latest_message)
tool_calls = [ToolCall(name=tc.get("name"), arguments=tc.get("arguments"), id=tc.get("id")) for tc in tool_calls_data]
return AssistantMessage(type=message_type, content=content, tool_calls=tool_calls)
elif message_type == "tool":
return ToolMessageData(type=message_type, content=content, tool_name=getattr(latest_message, "name", None), tool_response=content, tool_id=getattr(latest_message, "tool_call_id", None))
elif message_type == "function":
return FunctionMessageData(type=message_type, content=content)
else:
return UnknownMessage(type=message_type, content=content)

def _get_message_type(self, message) -> str:
"""Determine the type of message."""
if isinstance(message, HumanMessage):
return "user"
elif isinstance(message, AIMessage):
return "assistant"
elif isinstance(message, LCSystemMessage):
return "system"
elif isinstance(message, LCFunctionMessage):
return "function"
elif isinstance(message, LCToolMessage):
return "tool"
elif hasattr(message, "type") and message.type:
return message.type
else:
return "unknown"

def _get_message_content(self, message) -> str:
"""Extract content from a message."""
if hasattr(message, "content"):
return message.content
elif hasattr(message, "message") and hasattr(message.message, "content"):
return message.message.content
else:
return str(message)

def _extract_tool_calls(self, message) -> list[dict[str, Any]]:
"""Extract tool calls from an assistant message."""
tool_calls = []

# Check different possible locations for tool calls
if hasattr(message, "additional_kwargs") and "tool_calls" in message.additional_kwargs:
raw_tool_calls = message.additional_kwargs["tool_calls"]
for tc in raw_tool_calls:
tool_calls.append({"name": tc.get("function", {}).get("name"), "arguments": tc.get("function", {}).get("arguments"), "id": tc.get("id")})

# Also check for function_call which is used in some models
elif hasattr(message, "additional_kwargs") and "function_call" in message.additional_kwargs:
fc = message.additional_kwargs["function_call"]
if isinstance(fc, dict):
tool_calls.append(
{
"name": fc.get("name"),
"arguments": fc.get("arguments"),
"id": "function_call_1", # Assigning a default ID
}
)

return tool_calls

def get_traces(self) -> list[BaseMessage]:
"""Get all collected traces."""
return self.traces

def clear_traces(self) -> None:
"""Clear all traces."""
self.traces = []
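
For completeness, the tracer can be exercised without a live agent by feeding it chunks shaped like LangGraph's "values" stream. The chunks below are fabricated purely for illustration:

from langchain.schema import AIMessage, HumanMessage

from codegen.agents.tracer import MessageStreamTracer

tracer = MessageStreamTracer()  # no external logger; traces are still collected internally

fake_stream = iter(
    [
        {"messages": [HumanMessage(content="What is the main character's name?")]},
        {"messages": [AIMessage(content="The main character appears to be Kevin.")]},
    ]
)

for chunk in tracer.process_stream(fake_stream):
    pass  # original chunks pass through unchanged

for trace in tracer.get_traces():
    print(trace.type, ":", trace.content)  # expect a "user" entry followed by an "assistant" entry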