From a0425a238fedcccdf9300c3e2c3a3191fdd808ff Mon Sep 17 00:00:00 2001 From: Carl Gabel Date: Wed, 5 Mar 2025 23:05:44 +0000 Subject: [PATCH 1/4] Added Gemin support - refactored due to file size --- trustcall/_base.py | 1684 +-------------------------------------- trustcall/extract.py | 900 +++++++++++++++++++++ trustcall/patch.py | 181 +++++ trustcall/schema.py | 524 ++++++++++++ trustcall/states.py | 146 ++++ trustcall/tools.py | 163 ++++ trustcall/types.py | 92 +++ trustcall/utils.py | 136 ++++ trustcall/validation.py | 97 +++ 9 files changed, 2270 insertions(+), 1653 deletions(-) create mode 100644 trustcall/extract.py create mode 100644 trustcall/patch.py create mode 100644 trustcall/schema.py create mode 100644 trustcall/states.py create mode 100644 trustcall/tools.py create mode 100644 trustcall/types.py create mode 100644 trustcall/utils.py create mode 100644 trustcall/validation.py diff --git a/trustcall/_base.py b/trustcall/_base.py index 63e73fc..f86dd2b 100644 --- a/trustcall/_base.py +++ b/trustcall/_base.py @@ -1,1664 +1,42 @@ -"""Utilities for tool calling and extraction with retries.""" +"""Facade module for the trustcall package. -from __future__ import annotations - -import functools -import inspect -import json -import logging -import operator -import uuid -from dataclasses import asdict, dataclass, field -from typing import ( - Any, - Callable, - Dict, - List, - Literal, - Mapping, - NamedTuple, - Optional, - Sequence, - Type, - Union, - cast, -) +This module re-exports the key functionality from the other modules in the package. +It serves as the main entry point to the library, providing a simplified interface +for users. +""" -import jsonpatch # type: ignore[import-untyped] -import langsmith as ls -from dydantic import create_model_from_schema -from langchain_core.language_models import BaseChatModel -from langchain_core.messages import ( - AIMessage, - AnyMessage, - BaseMessage, - HumanMessage, - MessageLikeRepresentation, - SystemMessage, - ToolCall, - ToolMessage, +# Re-export key functionality +from trustcall.extract import ( + create_extractor, ) -from langchain_core.prompt_values import PromptValue -from langchain_core.runnables import Runnable, RunnableConfig -from langchain_core.tools import BaseTool, InjectedToolArg, create_schema_from_function -from langgraph.constants import Send -from langgraph.graph import StateGraph, add_messages -from langgraph.prebuilt.tool_validator import ValidationNode, get_executor_for_config -from langgraph.types import Command -from langgraph.utils.runnable import RunnableCallable -from pydantic import ( - BaseModel, - ConfigDict, - Field, - StrictBool, - StrictFloat, - StrictInt, - create_model, - field_validator, +from trustcall.states import ( + ExtractionState, + ExtendedExtractState, + DeletionState, ) -from typing_extensions import Annotated, TypedDict, get_args, get_origin, is_typeddict - -logger = logging.getLogger("extraction") - - -TOOL_T = Union[BaseTool, Type[BaseModel], Callable, Dict[str, Any]] -DEFAULT_MAX_ATTEMPTS = 3 - -Message = Union[AnyMessage, MessageLikeRepresentation] - -Messages = Union[MessageLikeRepresentation, Sequence[MessageLikeRepresentation]] - - -class SchemaInstance(NamedTuple): - """Represents an instance of a schema with its associated metadata. - - This named tuple is used to store information about a specific schema instance, - including its unique identifier, the name of the schema it conforms to, - and the actual data of the record. - - Attributes: - record_id (str): A unique identifier for this schema instance. - schema_name (str): The name of the schema that this instance conforms to. - record (dict[str, Any]): The actual data of the record, stored as a dictionary. - """ - - record_id: str - schema_name: str | Literal["__any__"] - record: Dict[str, Any] - - -ExistingType = Union[ - Dict[str, Any], List[SchemaInstance], List[tuple[str, str, dict[str, Any]]] -] -"""Type for existing schemas. - -Can be one of: -- Dict[str, Any]: A dictionary mapping schema names to schema instances. -- List[SchemaInstance]: A list of SchemaInstance named tuples. -- List[tuple[str, str, dict[str, Any]]]: A list of tuples containing - (record_id, schema_name, record_dict). - -This type allows for flexibility in representing existing schemas, -supporting both single and multiple instances of each schema type. -""" - - -class ExtractionInputs(TypedDict, total=False): - messages: Union[Messages, PromptValue] - existing: Optional[ExistingType] - """Existing schemas. Key is the schema name, value is the schema instance. - If a list, supports duplicate schemas to update. - """ - - -InputsLike = Union[ExtractionInputs, List[AnyMessage], PromptValue, str] - - -class ExtractionOutputs(TypedDict): - messages: List[AIMessage] - responses: List[BaseModel] - response_metadata: List[dict[str, Any]] - attempts: int - - -def create_extractor( - llm: str | BaseChatModel, - *, - tools: Sequence[TOOL_T], - tool_choice: Optional[str] = None, - enable_inserts: bool = False, - enable_updates: bool = True, - enable_deletes: bool = False, - existing_schema_policy: bool | Literal["ignore"] = True, -) -> Runnable[InputsLike, ExtractionOutputs]: - """Create an extractor that generates validated structured outputs using an LLM. - - This function binds validators and retry logic to ensure the validity of - generated tool calls. It uses JSONPatch to correct validation errors caused - by incorrect or incomplete parameters in previous tool calls. - - Args: - llm (BaseChatModel): The language model that will generate the initial - messages and fallbacks. - tools (Sequence[TOOL_T]): The tools to bind to the LLM. Can be BaseTool, - Type[BaseModel], Callable, or Dict[str, Any]. - tool_choice (Optional[str]): The specific tool to use. If None, - the LLM chooses whether to use (or not use) a tool based - on the input messages. (default: None) - enable_inserts (bool): Whether to allow the LLM to extract new schemas - even if it receives existing schemas. (default: False) - enable_updates (bool): Whether to allow the LLM to update existing schemas - using the PatchDoc tool. (default: True) - enable_deletes (bool): Whether to allow the LLM to delete existing schemas - using the RemoveDoc tool. (default: False) - existing_schema_policy (bool | Literal["ignore"]): How to handle existing schemas - that don't match the provided tool. Useful for migrating or managing heterogenous - docs. (default: True) True means raise error. False means treat as dict. - "ignore" means ignore (drop any attempts to patch these) - - Returns: - Runnable[ExtractionInputs, ExtractionOutputs]: A runnable that - can be invoked with a list of messages and returns validated AI - messages and responses. - - Examples: - >>> from langchain_fireworks import ( - ... ChatFireworks, - ... ) - >>> from pydantic import ( - ... BaseModel, - ... Field, - ... ) - >>> - >>> class UserInfo(BaseModel): - ... name: str = Field(description="User's full name") - ... age: int = Field(description="User's age in years") - >>> - >>> llm = ChatFireworks(model="accounts/fireworks/models/firefunction-v2") - >>> extractor = create_extractor( - ... llm, - ... tools=[UserInfo], - ... ) - >>> result = extractor.invoke( - ... { - ... "messages": [ - ... ( - ... "human", - ... "My name is Alice and I'm 30 years old", - ... ) - ... ] - ... } - ... ) - >>> result["responses"][0] - UserInfo(name='Alice', age=30) - - Using multiple tools - >>> from typing import ( - ... List, - ... ) - >>> - >>> class Preferences(BaseModel): - ... foods: List[str] = Field(description="Favorite foods") - >>> - >>> extractor = create_extractor( - ... llm, - ... tools=[ - ... UserInfo, - ... Preferences, - ... ], - ... ) - >>> result = extractor.invoke( - ... { - ... "messages": [ - ... ( - ... "system", - ... "Extract all the user's information and preferences" - ... "from the conversation below using parallel tool calling.", - ... ), - ... ( - ... "human", - ... "I'm Bob, 25 years old, and I love pizza and sushi", - ... ), - ... ] - ... } - ... ) - >>> print(result["responses"]) - [UserInfo(name='Bob', age=25), Preferences(foods=['pizza', 'sushi'])] - >>> print(result["messages"]) # doctest: +SKIP - [ - AIMessage( - content='', tool_calls=[ - ToolCall(id='...', name='UserInfo', args={'name': 'Bob', 'age': 25}), - ToolCall(id='...', name='Preferences', args={'foods': ['pizza', 'sushi']} - )] - ) - ] - - Updating an existing schema: - >>> existing = { - ... "UserInfo": { - ... "name": "Alice", - ... "age": 30, - ... }, - ... "Preferences": { - ... "foods": [ - ... "pizza", - ... "sushi", - ... ] - ... }, - ... } - >>> extractor = create_extractor( - ... llm, - ... tools=[ - ... UserInfo, - ... Preferences, - ... ], - ... ) - >>> result = extractor.invoke( - ... { - ... "messages": [ - ... ( - ... "system", - ... "You are tasked with maintaining user info and preferences." - ... " Use the tools to update the schemas.", - ... ), - ... ( - ... "human", - ... "I'm Alice; just had my 31st birthday yesterday." - ... " We had spinach, which is my FAVORITE!", - ... ), - ... ], - ... "existing": existing, - ... } - ... ) - """ # noqa - if isinstance(llm, str): - try: - from langchain.chat_models import init_chat_model - except ImportError: - raise ImportError( - "Creating extractors from a string requires langchain>=0.3.0," - " as well as the provider-specific package" - " (like langchain-openai, langchain-anthropic, etc.)" - " Please install langchain to continue." - ) - llm = init_chat_model(llm) - builder = StateGraph(ExtractionState) - - def format_exception(error: BaseException, call: ToolCall, schema: Type[BaseModel]): - return ( - f"Error:\n\n```\n{str(error)}\n```\n" - "Expected Parameter Schema:\n\n" + f"```json\n{_get_schema(schema)}\n```\n" - f"Please use PatchFunctionErrors to fix all validation errors." - f" for json_doc_id=[{call['id']}]." - ) - - validator = _ExtendedValidationNode( - ensure_tools(tools) + [PatchDoc, PatchFunctionErrors], - format_error=format_exception, # type: ignore - enable_deletes=enable_deletes, - ) - _extract_tools = [ - schema - for name, schema in validator.schemas_by_name.items() - if name not in {PatchDoc.__name__, PatchFunctionErrors.__name__} - ] - tool_names = [getattr(t, "name", t.__name__) for t in _extract_tools] - builder.add_node( - _Extract( - llm, - _extract_tools, - tool_choice, - ).as_runnable() - ) - updater = _ExtractUpdates( - llm, - tools=validator.schemas_by_name.copy(), - enable_inserts=enable_inserts, # type: ignore - enable_updates=enable_updates, # type: ignore - enable_deletes=enable_deletes, # type: ignore - existing_schema_policy=existing_schema_policy, - ) - builder.add_node(updater.as_runnable()) - builder.add_node(_Patch(llm, valid_tool_names=tool_names).as_runnable()) - builder.add_node("validate", validator) - - def del_tool_call(state: DeletionState) -> dict: - return { - "messages": MessageOp(op="delete", target=state.deletion_target), - } - - builder.add_node(del_tool_call) - - def enter(state: ExtractionState) -> Literal["extract", "extract_updates"]: - if state.existing: - return "extract_updates" - return "extract" - - builder.add_conditional_edges("__start__", enter) - - def validate_or_retry( - state: ExtractionState, - ) -> Literal["validate", "extract_updates"]: - if state.messages[-1].type == "ai": - return "validate" - return "extract_updates" - - builder.add_edge("extract", "validate") - builder.add_conditional_edges("extract_updates", validate_or_retry) - - def handle_retries( - state: ExtractionState, config: RunnableConfig - ) -> Union[Literal["__end__"], list]: - """After validation, decide whether to retry or end the process.""" - max_attempts = config["configurable"].get("max_attempts", DEFAULT_MAX_ATTEMPTS) - if state.attempts >= max_attempts: - return "__end__" - # Only continue if we need to patch the tool call - to_send = [] - # We only increment the attempt count once, regardless of the fan-out - # degree. - bumped = False - for m in reversed(state.messages): - if isinstance(m, AIMessage): - break - if isinstance(m, ToolMessage): - if m.status == "error": - # Each fallback will fix at most 1 schema per time. - messages_for_fixing = _get_history_for_tool_call( - state.messages, m.tool_call_id - ) - to_send.append( - Send( - "patch", - ExtendedExtractState( - **{ - **asdict(state), - "messages": messages_for_fixing, - "tool_call_id": m.tool_call_id, - "bump_attempt": not bumped, - } - ), - ) - ) - bumped = True - else: - # We want to delete the validation tool calls - # anyway to avoid mixing branches during fan-in - to_send.append( - Send( - "del_tool_call", - DeletionState( - deletion_target=str(m.id), messages=state.messages - ), - ) - ) - return to_send - - builder.add_conditional_edges( - "validate", handle_retries, path_map=["__end__", "patch", "del_tool_call"] - ) - - def sync(state: ExtractionState, config: RunnableConfig) -> dict: - return {"messages": []} - - def validate_or_repatch( - state: ExtractionState, - ) -> Literal["validate", "patch"]: - if state.messages[-1].type == "ai": - return "validate" - return "patch" - - builder.add_node(sync) - - builder.add_conditional_edges( - "sync", validate_or_repatch, path_map=["validate", "patch", "__end__"] - ) - compiled = builder.compile(checkpointer=False) - compiled.name = "TrustCall" - - def filter_state(state: dict) -> ExtractionOutputs: - """Filter the state to only include the validated AIMessage + responses.""" - msg_id = state["msg_id"] - msg: Optional[AIMessage] = next( - ( - m - for m in state["messages"] - if m.id == msg_id and isinstance(m, AIMessage) - ), # type: ignore - None, - ) - if not msg: - return ExtractionOutputs( - messages=[], - responses=[], - attempts=state["attempts"], - response_metadata=[], - ) - responses = [] - response_metadata = [] - updated_docs = msg.additional_kwargs.get("updated_docs") or {} - existing = state.get("existing") - removal_schema = None - if enable_deletes and existing: - removal_schema = _create_remove_doc_from_existing(existing) - for tc in msg.tool_calls: - if removal_schema and tc["name"] == removal_schema.__name__: - sch = removal_schema - elif tc["name"] not in validator.schemas_by_name: - if existing_schema_policy in (False, "ignore"): - continue - sch = validator.schemas_by_name[tc["name"]] - else: - sch = validator.schemas_by_name[tc["name"]] - try: - responses.append( - sch.model_validate(tc["args"]) - if hasattr(sch, "model_validate") - else sch.parse_obj(tc["args"]) - ) - meta = {"id": tc["id"]} - if json_doc_id := updated_docs.get(tc["id"]): - meta["json_doc_id"] = json_doc_id - response_metadata.append(meta) - except Exception as e: - logger.error(e) - continue - - return { - "messages": [msg], - "responses": responses, - "response_metadata": response_metadata, - "attempts": state["attempts"], - } - - def coerce_inputs(state: InputsLike) -> Union[ExtractionInputs, dict]: - """Coerce inputs to the expected format.""" - if isinstance(state, list): - return {"messages": state} - if isinstance(state, str): - return {"messages": [{"role": "user", "content": state}]} - if isinstance(state, PromptValue): - return {"messages": state.to_messages()} - if isinstance(state, dict): - if isinstance(state.get("messages"), PromptValue): - state = {**state, "messages": state["messages"].to_messages()} # type: ignore - else: - if hasattr(state, "messages"): - state = {"messages": state.messages.to_messages()} # type: ignore - - return cast(dict, state) - - return coerce_inputs | compiled | filter_state - - -## Helper functions + reducers - - -def ensure_tools( - tools: Sequence[TOOL_T], -) -> List[Union[BaseTool, Type[BaseModel], Callable]]: - results: list = [] - for t in tools: - if isinstance(t, dict): - if all(k in t for k in ("name", "description", "parameters")): - schema = create_model_from_schema( - {"title": t["name"], **t["parameters"]} - ) - schema.__doc__ = (getattr(schema, __doc__, "") or "") + ( - t.get("description") or "" - ) - schema.__name__ = t["name"] - results.append(schema) - elif all(k in t for k in ("type", "function")): - # Already in openai format - resolved = ensure_tools([t["function"]]) - results.extend(resolved) - else: - model = create_model_from_schema(t) - if not model.__doc__: - model.__doc__ = t.get("description") or model.__name__ - results.append(model) - elif is_typeddict(t): - results.append(_convert_any_typed_dicts_to_pydantic(cast(type, t))) - elif isinstance(t, (BaseTool, type)): - results.append(t) - elif callable(t): - results.append(csff_(t)) - else: - raise ValueError(f"Invalid tool type: {type(t)}") - return list(results) - - -_MAX_TYPED_DICT_RECURSION = 25 - - -def _convert_any_typed_dicts_to_pydantic( - type_: type, - *, - visited: dict | None = None, - depth: int = 0, -) -> type: - visited = visited if visited is not None else {} - if type_ in visited: - return visited[type_] - elif depth >= _MAX_TYPED_DICT_RECURSION: - return type_ - elif is_typeddict(type_): - typed_dict = type_ - docstring = inspect.getdoc(typed_dict) - annotations_ = typed_dict.__annotations__ - fields: dict = {} - for arg, arg_type in annotations_.items(): - if get_origin(arg_type) is Annotated: - annotated_args = get_args(arg_type) - new_arg_type = _convert_any_typed_dicts_to_pydantic( - annotated_args[0], depth=depth + 1, visited=visited - ) - field_kwargs = dict(zip(("default", "description"), annotated_args[1:])) - if (field_desc := field_kwargs.get("description")) and not isinstance( - field_desc, str - ): - raise ValueError( - f"Invalid annotation for field {arg}. Third argument to " - f"Annotated must be a string description, received value of " - f"type {type(field_desc)}." - ) - else: - pass - fields[arg] = (new_arg_type, Field(**field_kwargs)) - else: - new_arg_type = _convert_any_typed_dicts_to_pydantic( - arg_type, depth=depth + 1, visited=visited - ) - field_kwargs = {"default": ...} - fields[arg] = (new_arg_type, Field(**field_kwargs)) - model = create_model(typed_dict.__name__, **fields) - model.__doc__ = docstring or "" - visited[typed_dict] = model - return model - elif (origin := get_origin(type_)) and (type_args := get_args(type_)): - type_args = tuple( - _convert_any_typed_dicts_to_pydantic(arg, depth=depth + 1, visited=visited) - for arg in type_args # type: ignore[index] - ) - return origin[type_args] # type: ignore[index] - else: - return type_ - - -def _exclude_none(d: Dict[str, Any]) -> Dict[str, Any]: - return { - k: v if not isinstance(v, dict) else _exclude_none(v) - for k, v in d.items() - if v is not None - } - - -def _get_schema(model: Type[BaseModel]) -> dict: - if hasattr(model, "model_json_schema"): - schema = model.model_json_schema() - else: - schema = model.schema() # type: ignore - return _exclude_none(schema) - - -class _Extract: - def __init__( - self, - llm: BaseChatModel, - tools: Sequence, - tool_choice: Optional[str] = None, - ): - self.bound_llm = llm.bind_tools( - [ - { - "type": "function", - "function": { - "name": getattr(t, "name", t.__name__), - "description": t.__doc__, - "parameters": _get_schema(t), - }, - } - for t in tools - ], - tool_choice=tool_choice, - ) - - @ls.traceable - def _tear_down(self, msg: AIMessage) -> dict: - if not msg.id: - msg.id = str(uuid.uuid4()) - return { - "messages": [msg], - "attempts": 1, - "msg_id": msg.id, - } - - async def ainvoke(self, state: ExtractionState, config: RunnableConfig) -> dict: - msg = await self.bound_llm.ainvoke(state.messages, config) - return self._tear_down(cast(AIMessage, msg)) - - def invoke(self, state: ExtractionState, config: RunnableConfig) -> dict: - msg = self.bound_llm.invoke(state.messages, config) - return self._tear_down(msg) - - def as_runnable(self): - return RunnableCallable(self.invoke, self.ainvoke, name="extract", trace=False) - - -class _ExtractUpdates: - """Prompt an LLM to patch an existing schema. - - We have found this to be prefereable to re-generating - the entire tool call from scratch in several ways: - - 1. Fewer output tokens. - 2. Less likely to introduce new errors or drop important information. - 3. Easier for the LLM to generate. - """ - - def __init__( - self, - llm: BaseChatModel, - tools: Mapping[str, Type[BaseModel]], - enable_inserts: bool = False, - enable_updates: bool = True, - enable_deletes: bool = False, - existing_schema_policy: bool | Literal["ignore"] = True, - ): - if not any((enable_inserts, enable_updates, enable_deletes)): - raise ValueError( - "At least one of enable_inserts, enable_updates," - " or enable_deletes must be True." - ) - new_tools: list = [PatchDoc] if enable_updates else [] - tool_choice = "PatchDoc" if not enable_deletes else "any" - if enable_inserts: # Also let the LLM know that we can extract NEW schemas. - tools_ = [ - schema - for name, schema in (tools or {}).items() - if name not in {PatchDoc.__name__, PatchFunctionErrors.__name__} - ] - new_tools.extend(tools_) - tool_choice = "any" - self.enable_inserts = enable_inserts - self.enable_updates = enable_updates - self.bound_tools = new_tools - self.tool_choice = tool_choice - self.bound = llm.bind_tools(new_tools, tool_choice=tool_choice) - self.enable_deletes = enable_deletes - self.tools = dict(tools) | {schema_.__name__: schema_ for schema_ in new_tools} - self.existing_schema_policy = existing_schema_policy - - @ls.traceable(tags=["langsmith:hidden"]) - def _setup(self, state: ExtractionState): - messages = state.messages - existing = state.existing - if not existing: - raise ValueError("No existing schemas provided.") - existing = self._validate_existing(existing) # type: ignore[assignment] - schema_strings = [] - if isinstance(existing, dict): - for k, v in existing.items(): - if k not in self.tools and self.existing_schema_policy is False: - schema_str = "object" - else: - schema = self.tools[k] - schema_json = schema.model_json_schema() - schema_str = f""" - - {schema_json} - -""" - schema_strings.append( - f"\n\n{v}\n" - f"{schema_str}" - ) - else: - for schema_id, tname, d in existing: - schema_strings.append( - f'\n{d}\n' - ) - - existing_schemas = "\n".join(schema_strings) - cmd = "Generate JSONPatches to update the existing schema instances." - if self.enable_inserts: - cmd += ( - " If you need to extract or insert *new* instances of the schemas" - ", call the relevant function(s)." - ) - - existing_msg = f"""{cmd} - -{existing_schemas} - -""" - if isinstance(messages[0], SystemMessage): - system_message = messages.pop(0) - if isinstance(system_message.content, str): - system_message.content += "\n\n" + existing_msg - else: - system_message.content = cast(list, system_message.content) + [ - "\n\n" + existing_msg - ] - else: - system_message = SystemMessage(content=existing_msg) - removal_schema = None - if self.enable_deletes and existing: - removal_schema = _create_remove_doc_from_existing(existing) - bound_model = self.bound.bound.bind_tools( # type: ignore - self.bound_tools + [removal_schema], - tool_choice=self.tool_choice, - ) - else: - bound_model = self.bound - - return [system_message] + messages, existing, removal_schema, bound_model - - @ls.traceable(tags=["langsmith:hidden"]) - def _teardown( - self, - msg: AIMessage, - existing: Union[Dict[str, Any], List[SchemaInstance]], - ): - resolved_tool_calls = [] - updated_docs = {} - rt = ls.get_current_run_tree() - for tc in msg.tool_calls: - if tc["name"] == PatchDoc.__name__: - json_doc_id = tc["args"]["json_doc_id"] - if isinstance(existing, dict): - target = existing.get(str(json_doc_id)) - tool_name = json_doc_id - else: - try: - _, tool_name, target = next( - (e for e in existing if e[0] == json_doc_id), - ) - if not tool_name: - raise ValueError( - "Could not find tool name " - f"for json_doc_id {json_doc_id}" - ) - except StopIteration: - logger.error( - f"Could not find existing schema in dict for {json_doc_id}" - ) - if rt: - rt.error = ( - f"Could not find existing schema for {json_doc_id}" - ) - continue - except (ValueError, IndexError, TypeError): - logger.error( - f"Could not find existing schema in list for {json_doc_id}" - ) - if rt: - rt.error = ( - f"Could not find existing schema for {json_doc_id}" - ) - continue - - if target: - try: - patches = _ensure_patches(tc["args"]) - if patches or self.tool_choice == "PatchDoc": - # The second condition is so that, when we are continuously - # updating a single doc, we will still include it in - # the output responses list; mainly for backwards - # compatibility - resolved_tool_calls.append( - ToolCall( - id=tc["id"], - name=tool_name, - args=jsonpatch.apply_patch(target, patches), - ) - ) - updated_docs[tc["id"]] = str(json_doc_id) - except Exception as e: - logger.error(f"Could not apply patch: {e}") - if rt: - rt.error = f"Could not apply patch: {repr(e)}" - else: - if rt: - rt.error = f"Could not find existing schema for {tool_name}" - logger.warning(f"Could not find existing schema for {tool_name}") - else: - resolved_tool_calls.append(tc) - ai_message = AIMessage( - content=msg.content, - tool_calls=resolved_tool_calls, - additional_kwargs={"updated_docs": updated_docs}, - ) - if not ai_message.id: - ai_message.id = str(uuid.uuid4()) - - return { - "messages": [ai_message], - "attempts": 1, - "msg_id": ai_message.id, - } - - @property - def _provided_tools(self): - return sorted(self.tools.keys() - {"PatchDoc", "PatchFunctionErrors"}) - - def _validate_existing( - self, existing: ExistingType - ) -> Union[Dict[str, Any], List[SchemaInstance]]: - """Check that all existing schemas match a known schema or '__any__'.""" - if isinstance(existing, dict): - # For each top-level key, see if it's recognized - validated = {} - for key, record in existing.items(): - if key in self.tools or key == "__any__": - validated[key] = record - else: - # Key does not match known schema - if self.existing_schema_policy is True: - raise ValueError( - f"Key '{key}' doesn't match any schema. " - f"Known schemas: {list(self.tools.keys())}" - ) - elif self.existing_schema_policy is False: - validated[key] = record - else: # "ignore" - logger.warning(f"Ignoring unknown schema: {key}") - return validated - - elif isinstance(existing, list): - # For list types, validate each item's schema_name - coerced = [] - for i, item in enumerate(existing): - if isinstance(item, SchemaInstance): - if ( - item.schema_name not in self.tools - and item.schema_name != "__any__" - ): - if self.existing_schema_policy is True: - raise ValueError( - f"Unknown schema '{item.schema_name}' at index {i}" - ) - elif self.existing_schema_policy is False: - coerced.append( - SchemaInstance( - item.record_id, item.schema_name, item.record - ) - ) - else: # "ignore" - logger.warning(f"Ignoring unknown schema at index {i}") - continue - else: - coerced.append(item) - elif isinstance(item, tuple) and len(item) == 3: - record_id, schema_name, record_dict = item - if isinstance(record_dict, BaseModel): - record_dict = record_dict.model_dump(mode="json") - if schema_name not in self.tools and schema_name != "__any__": - if self.existing_schema_policy is True: - raise ValueError( - f"Unknown schema '{schema_name}' at index {i}" - ) - elif self.existing_schema_policy is False: - coerced.append( - SchemaInstance(record_id, schema_name, record_dict) - ) - else: # "ignore" - logger.warning(f"Ignoring unknown schema '{schema_name}'") - continue - else: - coerced.append( - SchemaInstance(record_id, schema_name, record_dict) - ) - elif isinstance(item, tuple) and len(item) == 2: - # Assume record_ID, item - record_id, model = item - if hasattr(model, "__name__"): - schema_name = model.__name__ - else: - schema_name = model.__repr_name__() - - if schema_name not in self.tools and schema_name != "__any__": - if self.existing_schema_policy is True: - raise ValueError( - f"Unknown schema '{schema_name}' at index {i}" - ) - elif self.existing_schema_policy is False: - val = ( - model.model_dump(mode="json") - if isinstance(model, BaseModel) - else model - ) - coerced.append(SchemaInstance(record_id, schema_name, val)) - else: # "ignore" - logger.warning(f"Ignoring unknown schema '{schema_name}'") - continue - else: - val = ( - model.model_dump(mode="json") - if isinstance(model, BaseModel) - else model - ) - coerced.append(SchemaInstance(record_id, schema_name, val)) - elif isinstance(item, BaseModel): - if hasattr(item, "__name__"): - schema_name = item.__name__ - else: - schema_name = item.__repr_name__() - - if schema_name not in self.tools and schema_name != "__any__": - if self.existing_schema_policy is True: - raise ValueError( - f"Unknown schema '{schema_name}' at index {i}" - ) - elif self.existing_schema_policy is False: - coerced.append( - SchemaInstance( - str(uuid.uuid4()), - schema_name, - item.model_dump(mode="json"), - ) - ) - else: # "ignore" - logger.warning(f"Ignoring unknown schema '{schema_name}'") - continue - else: - coerced.append( - SchemaInstance( - str(uuid.uuid4()), - schema_name, - item.model_dump(mode="json"), - ) - ) - else: - raise ValueError( - f"Invalid item at index {i} in existing list." - f" Provided: {item}, Expected: SchemaInstance" - f" or Tuple[str, str, dict] or BaseModel" - ) - return coerced - else: - raise ValueError( - f"Invalid type for existing. Provided: {type(existing)}," - f" Expected: dict or list. Supported formats are:\n" - "1. Dict[str, Any] where keys are tool names\n" - "2. List[SchemaInstance]\n3. List[Tuple[str, str, Dict[str, Any]]]" - ) - - async def ainvoke(self, state: ExtractionState, config: RunnableConfig) -> dict: - """Generate a JSONPatch to simply update an existing schema. - - Returns a single AIMessage with the updated schema, as if - the schema were extracted from scratch. - """ - messages, existing, removal_schema, bound_model = self._setup(state) - try: - msg = await bound_model.ainvoke(messages, config) - return { - **self._teardown(cast(AIMessage, msg), existing), - "removal_schema": removal_schema, - } - except Exception as e: - return { - "messages": [ - HumanMessage( - content="Fix the validation error while" - f" also avoiding: {repr(str(e))}" - ) - ], - "attempts": 1, - } - - def invoke(self, state: ExtractionState, config: RunnableConfig) -> dict: - messages, existing, removal_schema, bound_model = self._setup(state) - try: - msg = bound_model.invoke(messages, config) - return {**self._teardown(msg, existing), "removal_schema": removal_schema} - except Exception as e: - return { - "messages": [ - HumanMessage( - content="Fix the validation error while" - f" also avoiding: {repr(str(e))}" - ) - ], - "attempts": 1, - } - - def as_runnable(self): - return RunnableCallable( - self.invoke, self.ainvoke, name="extract_updates", trace=False - ) - - -class _Patch: - """Prompt an LLM to patch an invalid schema after it receives a ValidationError. - - We have found this to be more reliable and more token-efficient than - re-creating the entire tool call from scratch. - """ - - def __init__( - self, llm: BaseChatModel, valid_tool_names: Optional[List[str]] = None - ): - self.bound = llm.bind_tools( - [PatchFunctionErrors, _create_patch_function_name_schema(valid_tool_names)], - tool_choice="any", - ) - - @ls.traceable(tags=["patch", "langsmith:hidden"]) - def _tear_down( - self, - msg: AIMessage, - messages: List[AnyMessage], - target_id: str, - bump_attempt: bool, - ): - if not msg.id: - msg.id = str(uuid.uuid4()) - # We will directly update the messages in the state before validation. - msg_ops = _infer_patch_message_ops(messages, msg.tool_calls, target_id) - return { - "messages": msg_ops, - "attempts": 1 if bump_attempt else 0, - } - - async def ainvoke( - self, state: ExtendedExtractState, config: RunnableConfig - ) -> Command[Literal["sync", "__end__"]]: - """Generate a JSONPatch to correct the validation error and heal the tool call. - - Assumptions: - - We only support a single tool call to be patched. - - State's message list's last AIMessage contains the actual schema to fix. - - The last ToolMessage contains the tool call to fix. - - """ - try: - msg = await self.bound.ainvoke(state.messages, config) - except Exception: - return Command(goto="__end__") - return Command( - update=self._tear_down( - cast(AIMessage, msg), - state.messages, - state.tool_call_id, - state.bump_attempt, - ), - goto=("sync",), - ) - - def invoke( - self, state: ExtendedExtractState, config: RunnableConfig - ) -> Command[Literal["sync", "__end__"]]: - try: - msg = self.bound.invoke(state.messages, config) - except Exception: - return Command(goto="__end__") - return Command( - update=self._tear_down( - cast(AIMessage, msg), - state.messages, - state.tool_call_id, - state.bump_attempt, - ), - goto=("sync",), - ) - - def as_runnable(self): - return RunnableCallable(self.invoke, self.ainvoke, name="patch", trace=False) - - -# We COULD just say Any for the value below, but Fireworks and some other -# providers don't support untyped arrays and dicts... -_JSON_PRIM_TYPES = Union[str, StrictInt, StrictBool, StrictFloat, None] -_JSON_TYPES = Union[ - _JSON_PRIM_TYPES, List[_JSON_PRIM_TYPES], Dict[str, _JSON_PRIM_TYPES] -] - - -class JsonPatch(BaseModel): - """A JSON Patch document represents an operation to be performed on a JSON document. - - Note that the op and path are ALWAYS required. Value is required for ALL operations except 'remove'. - """ # noqa - - op: Literal["add", "remove", "replace"] = Field( - ..., - description="The operation to be performed. Must be one" - " of 'add', 'remove', 'replace'.", - ) - path: str = Field( - ..., - description="A JSON Pointer path that references a location within the" - " target document where the operation is performed." - " Note: patches are applied sequentially. If you remove a value, the collection" - " size changes before the next patch is applied.", - ) - value: Union[_JSON_TYPES, List[_JSON_TYPES], Dict[str, _JSON_TYPES]] = Field( - ..., - description="The value to be used within the operation. REQUIRED for" - " 'add', 'replace', and 'test' operations." - " Pay close attention to the json schema to ensure" - " patched document will be valid.", - ) - model_config = ConfigDict( - json_schema_extra={ - "examples": [ - { - "op": "replace", - "path": "/path/to/my_array/1", - "value": "the newer value to be patched", - }, - { - "op": "replace", - "path": "/path/to/broken_object", - "value": {"new": "object"}, - }, - { - "op": "add", - "path": "/path/to/my_array/-", - "value": ["some", "values"], - }, - { - "op": "add", - "path": "/path/to/my_array/-", - "value": ["newer"], - }, - { - "op": "remove", - "path": "/path/to/my_array/1", - }, - ] - } - ) - - -def _create_remove_doc_from_existing(existing: Union[dict, list]): - if isinstance(existing, dict): - existing_ids = set(existing) - else: - existing_ids = set() - for schema_id, *_ in existing: - existing_ids.add(schema_id) - return _create_remove_doc_schema(tuple(sorted(existing_ids))) - - -@functools.lru_cache(maxsize=10) -def _create_remove_doc_schema(allowed_ids: tuple[str]) -> Type[BaseModel]: - """Create a RemoveDoc schema that validates against a set of allowed IDs.""" - - class RemoveDoc(BaseModel): - """Use this tool to remove (delete) a doc by its ID.""" - - json_doc_id: str = Field( - ..., - description=f"ID of the document to remove. Must be one of: {allowed_ids}", - ) - - @field_validator("json_doc_id") - @classmethod - def validate_doc_id(cls, v: str) -> str: - if v not in allowed_ids: - raise ValueError( - f"Document ID '{v}' not found. Available IDs: {sorted(allowed_ids)}" - ) - return v - - RemoveDoc.__name__ = "RemoveDoc" - return RemoveDoc - - -# Used for fixing validation errors -class PatchFunctionErrors(BaseModel): - """Respond with all JSONPatch operations required to update the previous invalid function call. - - Use to correct all validation errors in non-compliant function calls, or to extend or update existing structured data in the presence of new information. Closely analyze - the parameters from the original JSONSchema to ensure the patched document will be valid - and that you avoid repeating the same errors. - """ # noqa - - json_doc_id: str = Field( - ..., - description="The ID of the function you are patching.", - ) - planned_edits: str = Field( - ..., - description="Write a bullet-point list of each ValidationError you encountered" - " and the corresponding JSONPatch operation needed to heal it." - " For each operation, write why your initial guess was incorrect, " - " citing the corresponding types(s) from the JSONSchema" - " that will be used the validate the resultant patched document." - " Think step-by-step to ensure no error is overlooked.", - ) - patches: list[JsonPatch] = Field( - ..., - description="A list of JSONPatch operations to be applied to the" - " previous tool call's response arguments. If none are required, return" - " an empty list. This field is REQUIRED." - " Multiple patches in the list are applied sequentially in the order provided," - " with each patch building upon the result of the previous one.", - ) - - -def _create_patch_function_name_schema(valid_tool_names: Optional[List[str]] = None): - if valid_tool_names: - namestr = ", ".join(valid_tool_names) - vname = f" Must be one of {namestr}" - else: - vname = "" - - class PatchFunctionName(BaseModel): - """Call this if the tool message indicates that you previously invoked an invalid tool, (e.g., "Unrecognized tool name" error), do so here.""" # noqa - - json_doc_id: str = Field( - ..., - description="The ID of the function you are patching.", - ) - reasoning: list[str] = Field( - ..., - description="At least 2 logical reasons why this action ought to be taken." - "Cite the specific error(s) mentioned to motivate the fix.", - ) - fixed_name: Optional[str] = Field( - ..., - description="If you need to change the name of the function (e.g., " - f'from an "Unrecognized tool name" error), do so here.{vname}', - ) - - return PatchFunctionName - - -# Used for updating existing documents -class PatchDoc(BaseModel): - """Respond with JSONPatch operations to update the existing JSON document based on the provided text and schema.""" # noqa - - json_doc_id: str = Field( - ..., - description="The json_doc_id of the document you are patching.", - ) - planned_edits: str = Field( - ..., - description="Think step-by-step, reasoning over each required" - " update and the corresponding JSONPatch operation to accomplish it." - " Cite the fields in the JSONSchema you referenced in developing this plan." - " Address each path as a group; don't switch between paths.\n" - " Plan your patches in the following order:" - "1. replace - this keeps collection size the same.\n" - "2. remove - BE CAREFUL ABOUT ORDER OF OPERATIONS." - " Each operation is applied sequentially." - " For arrays, remove the highest indexed value first to avoid shifting" - " indices. This ensures subsequent remove operations remain valid.\n" - " 3. add (for arrays, use /- to efficiently append to end).", - ) - patches: list[JsonPatch] = Field( - ..., - description="A list of JSONPatch operations to be applied to the" - " previous tool call's response arguments. If none are required, return" - " an empty list. This field is REQUIRED." - " Multiple patches in the list are applied sequentially in the order provided," - " with each patch building upon the result of the previous one." - " Take care to respect array bounds. Order patches as follows:\n" - " 1. replace - this keeps collection size the same\n" - " 2. remove - BE CAREFUL about order of operations. For arrays, remove" - " the highest indexed value first to avoid shifting indices.\n" - " 3. add - for arrays, use /- to efficiently append to end.", - ) - - -class MessageOp(TypedDict): - op: Literal["delete", "update_tool_call", "update_tool_name"] - target: Union[str, ToolCall] - - -def _get_history_for_tool_call(messages: List[AnyMessage], tool_call_id: str): - results = [] - seen_ai_message = False - for m in reversed(messages): - if isinstance(m, AIMessage): - if not seen_ai_message: - tool_calls = [tc for tc in m.tool_calls if tc["id"] == tool_call_id] - if hasattr(m, "model_dump"): - d = m.model_dump(exclude={"tool_calls", "content"}) - else: - d = m.dict(exclude={"tool_calls", "content"}) - m = AIMessage( - **d, - # Frequently have partial_json blocks that are - # invalid if sent back to the API - content=str(m.content), - tool_calls=tool_calls, - ) - seen_ai_message = True - if isinstance(m, ToolMessage): - if m.tool_call_id != tool_call_id and not seen_ai_message: - continue - results.append(m) - return list(reversed(results)) - - -def _apply_message_ops( - messages: Sequence[AnyMessage], message_ops: Sequence[MessageOp] -) -> List[AnyMessage]: - # Apply operations to the messages - messages = list(messages) - for message_op in message_ops: - if message_op["op"] == "delete": - t = cast(str, message_op["target"]) - messages_ = [m for m in messages if cast(str, getattr(m, "id")) != t] - messages = messages_ - elif message_op["op"] == "update_tool_call": - targ = cast(ToolCall, message_op["target"]) - messages_ = [] - for m in messages: - if isinstance(m, AIMessage): - old = m.tool_calls.copy() - new = [ - targ if tc["id"] == targ["id"] else tc for tc in m.tool_calls - ] - if old != new: - m = m.model_copy() - m.tool_calls = new - if m.additional_kwargs.get("tool_calls"): - m.additional_kwargs["tool_calls"] = new - messages_.append(m) - else: - messages_.append(m) - messages = messages_ - elif message_op["op"] == "update_tool_name": - update_targ = cast(dict, message_op["target"]) - messages_ = [] - for m in messages: - if isinstance(m, AIMessage): - new = [] - for tc in m.tool_calls: - if tc["id"] == update_targ["id"]: - new.append( - { - "id": update_targ["id"], - "name": update_targ[ - "name" - ], # Just updating the name - "args": tc["args"], - } - ) - else: - new.append(tc) - if m.tool_calls != new: - m = m.model_copy() - m.tool_calls = new - messages_.append(m) - messages = messages_ - - else: - raise ValueError(f"Invalid operation: {message_op['op']}") - return messages - - -def _reduce_messages( - left: Optional[List[AnyMessage]], - right: Union[ - AnyMessage, - List[Union[AnyMessage, MessageOp]], - List[BaseMessage], - PromptValue, - MessageOp, - ], -) -> Messages: - if not left: - left = [] - if isinstance(right, PromptValue): - right = right.to_messages() - message_ops = [] - if isinstance(right, dict) and right.get("op"): - message_ops = [right] - right = [] - if isinstance(right, list): - right_ = [] - for r in right: - if isinstance(r, dict) and r.get("op"): - message_ops.append(r) - else: - right_.append(r) - right = right_ # type: ignore[assignment] - messages = cast(Sequence[AnyMessage], add_messages(left, right)) # type: ignore[arg-type] - if message_ops: - messages = _apply_message_ops(messages, message_ops) - return messages - - -def _get_message_op( - messages: Sequence[AnyMessage], tool_call: dict, tool_call_name: str, target_id: str -) -> List[MessageOp]: - msg_ops: List[MessageOp] = [] - rt = ls.get_current_run_tree() - for m in messages: - if isinstance(m, AIMessage): - for tc in m.tool_calls: - if tc["id"] == target_id: - if tool_call_name == "PatchFunctionName": - if not tool_call.get("fixed_name"): - continue - msg_ops.append( - { - "op": "update_tool_name", - "target": { # type: ignore[arg-type,typeddict-item] - "id": target_id, - "name": str(tool_call["fixed_name"]), - }, - } - ) - elif tool_call_name in ("PatchFunctionErrors", "PatchDoc"): - try: - patches = _ensure_patches(tool_call) - if patches: - patched_args = jsonpatch.apply_patch( - tc["args"], patches - ) - msg_ops.append( - { - "op": "update_tool_call", - "target": { - "id": target_id, - "name": tc["name"], - "args": patched_args, - }, - } - ) - except Exception as e: - if rt: - rt.error = f"Could not apply patch: {repr(e)}" - logger.error(f"Could not apply patch: {repr(e)}") - else: - if rt: - rt.error = f"Unrecognized function call {tool_call_name}" - logger.error(f"Unrecognized function call {tool_call_name}") - if isinstance(m, ToolMessage): - if m.tool_call_id == target_id: - msg_ops.append(MessageOp(op="delete", target=m.id or "")) - return msg_ops - - -@ls.traceable(tags=["langsmith:hidden"]) -def _infer_patch_message_ops( - messages: Sequence[AnyMessage], tool_calls: List[ToolCall], target_id: str -): - return [ - op - for tool_call in tool_calls - for op in _get_message_op( - messages, tool_call["args"], tool_call["name"], target_id=target_id - ) - ] - - -def csff_(function: Callable) -> Type[BaseModel]: - fn = _strip_injected(function) - schema = create_schema_from_function(function.__name__, fn) - schema.__name__ = function.__name__ - return schema - - -def _keep_first(left: Any, right: Any): - return left or right - - -@dataclass(kw_only=True) -class ExtractionState: - messages: Annotated[List[AnyMessage], _reduce_messages] = field( - default_factory=list - ) - attempts: Annotated[int, operator.add] = field(default=0) - msg_id: Annotated[str, _keep_first] = field(default="") - """Set once and never changed. The ID of the message to be patched.""" - existing: Optional[Dict[str, Any]] = field(default=None) - """If you're updating an existing schema, provide the existing schema here.""" - - -@dataclass(kw_only=True) -class ExtendedExtractState(ExtractionState): - tool_call_id: str = field(default="") - """The ID of the tool call to be patched.""" - bump_attempt: bool = field(default=False) - - -@dataclass(kw_only=True) -class DeletionState(ExtractionState): - deletion_target: str = field(default="") - - -class _ExtendedValidationNode(ValidationNode): - def __init__(self, *args, enable_deletes: bool = False, **kwargs): - super().__init__(*args, **kwargs) - self.enable_deletes = enable_deletes - - def _func(self, input: ExtractionState, config: RunnableConfig) -> Any: # type: ignore - """Validate and run tool calls synchronously.""" - output_type, message = self._get_message(asdict(input)) - removal_schema = None - if self.enable_deletes and input.existing: - removal_schema = _create_remove_doc_from_existing(input.existing) - - def run_one(call: ToolCall): - try: - if removal_schema and call["name"] == removal_schema.__name__: - schema = removal_schema - else: - schema = self.schemas_by_name[call["name"]] - output = schema.model_validate(call["args"]) - return ToolMessage( - content=output.model_dump_json(), - name=call["name"], - tool_call_id=cast(str, call["id"]), - ) - except KeyError: - valid_names = ", ".join(self.schemas_by_name.keys()) - return ToolMessage( - content=f'Unrecognized tool name: "{call["name"]}". You only have' - f" access to the following tools: {valid_names}." - " Please call PatchFunctionName with the *correct* tool name" - f" to fix json_doc_id=[{call['id']}].", - name=call["name"], - tool_call_id=cast(str, call["id"]), - status="error", - ) - except Exception as e: - return ToolMessage( - content=self._format_error(e, call, schema), - name=call["name"], - tool_call_id=cast(str, call["id"]), - status="error", - ) - - with get_executor_for_config(config) as executor: - outputs = [*executor.map(run_one, message.tool_calls)] - if output_type == "list": - return outputs - else: - return {"messages": outputs} - - -def _is_injected_arg_type(type_: Type) -> bool: - return any( - isinstance(arg, InjectedToolArg) - or (isinstance(arg, type) and issubclass(arg, InjectedToolArg)) - for arg in get_args(type_)[1:] - ) - - -def _curry(func: Callable, **fixed_kwargs: Any) -> Callable: - """Bind parameters to a function, removing those parameters from the signature. - - Useful for exposing a narrower interface than what the the original function - provides. - """ - - @functools.wraps(func) - def wrapper(*args: Any, **kwargs: Any) -> Any: - new_kwargs = {**fixed_kwargs, **kwargs} - return func(*args, **new_kwargs) - - sig = inspect.signature(func) - # Check that fixed_kwargs are all valid parameters of the function - invalid_kwargs = set(fixed_kwargs) - set(sig.parameters) - if invalid_kwargs: - raise ValueError(f"Invalid parameters: {invalid_kwargs}") - - new_params = [p for name, p in sig.parameters.items() if name not in fixed_kwargs] - wrapper.__signature__ = sig.replace(parameters=new_params) # type: ignore - return wrapper - - -def _strip_injected(fn: Callable) -> Callable: - """Strip injected arguments from a function's signature.""" - injected = [ - p.name - for p in inspect.signature(fn).parameters.values() - if _is_injected_arg_type(p.annotation) - ] - return _curry(fn, **{k: None for k in injected}) - - -def _ensure_patches(args: dict) -> list[JsonPatch]: - patches = args.get("patches") - if isinstance(patches, list): - return patches - - if isinstance(patches, str): - try: - parsed = json.loads(patches) - if isinstance(parsed, list): - return parsed - except Exception: - pass - - bracket_depth = 0 - first_list_str = None - start = patches.find("[") - if start != -1: - for i in range(start, len(patches)): - if patches[i] == "[": - bracket_depth += 1 - elif patches[i] == "]": - bracket_depth -= 1 - if bracket_depth == 0: - first_list_str = patches[start : i + 1] - break - if first_list_str: - try: - parsed = json.loads(first_list_str) - if isinstance(parsed, list): - return parsed - except Exception: - pass - - return [] +from trustcall.tools import ensure_tools, _convert_any_typed_dicts_to_pydantic +from trustcall.types import ExtractionInputs, ExtractionOutputs, SchemaInstance +from trustcall.schema import _create_patch_doc_schema, _create_patch_function_errors_schema +from trustcall.extract import _Extract, _ExtractUpdates +from trustcall.patch import _Patch +# Create default versions of PatchDoc and PatchFunctionErrors for backward compatibility +PatchDoc = _create_patch_doc_schema(for_gemini=False) +PatchFunctionErrors = _create_patch_function_errors_schema(for_gemini=False) __all__ = [ "create_extractor", "ensure_tools", "ExtractionInputs", "ExtractionOutputs", -] + "ExtractionState", + "ExtendedExtractState", + "DeletionState", + "SchemaInstance", + "PatchDoc", + "PatchFunctionErrors", + "_ExtractUpdates", + "_Extract", + "_Patch", + "_convert_any_typed_dicts_to_pydantic", +] \ No newline at end of file diff --git a/trustcall/extract.py b/trustcall/extract.py new file mode 100644 index 0000000..5e1124d --- /dev/null +++ b/trustcall/extract.py @@ -0,0 +1,900 @@ +"""Extraction-related functionality for the trustcall package.""" + +from __future__ import annotations + +import functools +import logging +import operator +import uuid +from dataclasses import asdict +from typing import ( + Any, + Callable, + Dict, + List, + Literal, + Optional, + Sequence, + Type, + Union, + cast, +) + +import jsonpatch # type: ignore[import-untyped] +import langsmith as ls +from langchain_core.language_models import BaseChatModel +from langchain_core.messages import ( + AIMessage, + AnyMessage, + HumanMessage, + SystemMessage, + ToolCall, + ToolMessage, +) +from langchain_core.prompt_values import PromptValue +from langchain_core.runnables import Runnable, RunnableConfig +from langgraph.constants import Send +from langgraph.graph import StateGraph +from langgraph.utils.runnable import RunnableCallable +from pydantic import BaseModel +from typing_extensions import TypedDict + +from trustcall.patch import _Patch +from trustcall.schema import ( + _create_remove_doc_from_existing, + _get_schema, + _create_patch_function_errors_schema, + _create_patch_doc_schema, +) +from trustcall.tools import TOOL_T, ensure_tools +from trustcall.types import ( + ExistingType, + ExtractionInputs, + ExtractionOutputs, + InputsLike, + Messages, + SchemaInstance, +) +from trustcall.utils import is_gemini_model, _get_history_for_tool_call +from trustcall.validation import _ExtendedValidationNode +from trustcall.states import ExtractionState, ExtendedExtractState, DeletionState, MessageOp + +logger = logging.getLogger("extraction") + +DEFAULT_MAX_ATTEMPTS = 3 + +class _Extract: + def __init__( + self, + llm: BaseChatModel, + tools: Sequence, + tool_choice: Optional[str] = None, + for_gemini: bool = False, + ): + # Create proper tool schemas based on the model type + tool_schemas = [] + for t in tools: + schema = _get_schema(t, for_gemini) + tool_dict = { + "type": "function", + "function": { + "name": getattr(t, "name", t.__name__), + "description": t.__doc__, + "parameters": schema, + } + } + tool_schemas.append(tool_dict) + + self.bound_llm = llm.bind_tools(tool_schemas, tool_choice=tool_choice) + + @ls.traceable + def _tear_down(self, msg: AIMessage) -> dict: + if not msg.id: + msg.id = str(uuid.uuid4()) + return { + "messages": [msg], + "attempts": 1, + "msg_id": msg.id, + } + + async def ainvoke(self, state: ExtractionState, config: RunnableConfig) -> dict: + """Extract entities from the input messages.""" + msg = await self.bound_llm.ainvoke(state.messages, config) + return self._tear_down(cast(AIMessage, msg)) + + def invoke(self, state: ExtractionState, config: RunnableConfig) -> dict: + """Extract entities from the input messages.""" + msg = self.bound_llm.invoke(state.messages, config) + return self._tear_down(msg) + + def as_runnable(self): + return RunnableCallable(self.invoke, self.ainvoke, name="extract", trace=False) + + +class _ExtractUpdates: + """Prompt an LLM to patch an existing schema. + + We have found this to be prefereable to re-generating + the entire tool call from scratch in several ways: + + 1. Fewer output tokens. + 2. Less likely to introduce new errors or drop important information. + 3. Easier for the LLM to generate. + """ + + def __init__( + self, + llm: BaseChatModel, + tools: Dict[str, Type[BaseModel]], + enable_inserts: bool = False, + enable_updates: bool = True, + enable_deletes: bool = False, + existing_schema_policy: bool | Literal["ignore"] = True, + ): + if not any((enable_inserts, enable_updates, enable_deletes)): + raise ValueError( + "At least one of enable_inserts, enable_updates," + " or enable_deletes must be True." + ) + + # Get the appropriate patching tools - Gemini supports simpler JSON schemas, so requires different tools + using_gemini = is_gemini_model(llm) + patch_doc = _create_patch_doc_schema(using_gemini) + patch_function_errors = _create_patch_function_errors_schema(using_gemini) + + new_tools: list = [patch_doc] if enable_updates else [] + tool_choice = "PatchDoc" if not enable_deletes else "any" + if enable_inserts: + tools_ = [ + schema + for name, schema in (tools or {}).items() + if name not in {patch_doc.__name__, patch_function_errors.__name__} + ] + new_tools.extend(tools_) + tool_choice = "any" + + self.enable_inserts = enable_inserts + self.enable_updates = enable_updates + self.bound_tools = new_tools + self.tool_choice = tool_choice + self.bound = llm.bind_tools(new_tools, tool_choice=tool_choice) + self.enable_deletes = enable_deletes + self.tools = dict(tools) | {schema_.__name__: schema_ for schema_ in new_tools} + self.existing_schema_policy = existing_schema_policy + self.using_gemini = using_gemini + + + @ls.traceable(tags=["langsmith:hidden"]) + def _setup(self, state: ExtractionState): + messages = state.messages + existing = state.existing + if not existing: + raise ValueError("No existing schemas provided.") + existing = self._validate_existing(existing) # type: ignore[assignment] + schema_strings = [] + if isinstance(existing, dict): + for k, v in existing.items(): + if k not in self.tools and self.existing_schema_policy is False: + schema_str = "object" + else: + schema = self.tools[k] + schema_json = _get_schema(schema, self.using_gemini) + schema_str = f""" + + {schema_json} + +""" + schema_strings.append( + f"\n\n{v}\n" + f"{schema_str}" + ) + else: + for schema_id, tname, d in existing: + schema_strings.append( + f'\n{d}\n' + ) + + existing_schemas = "\n".join(schema_strings) + cmd = "Generate JSONPatches to update the existing schema instances." + if self.enable_inserts: + cmd += ( + " If you need to extract or insert *new* instances of the schemas" + ", call the relevant function(s)." + ) + + existing_msg = f"""{cmd} + +{existing_schemas} + +""" + if isinstance(messages[0], SystemMessage): + system_message = messages.pop(0) + if isinstance(system_message.content, str): + system_message.content += "\n\n" + existing_msg + else: + system_message.content = cast(list, system_message.content) + [ + "\n\n" + existing_msg + ] + else: + system_message = SystemMessage(content=existing_msg) + removal_schema = None + if self.enable_deletes and existing: + removal_schema = _create_remove_doc_from_existing(existing) + bound_model = self.bound.bound.bind_tools( # type: ignore + self.bound_tools + [removal_schema], + tool_choice=self.tool_choice, + ) + else: + bound_model = self.bound + + return [system_message] + messages, existing, removal_schema, bound_model + + @ls.traceable(tags=["langsmith:hidden"]) + def _teardown( + self, + msg: AIMessage, + existing: Union[Dict[str, Any], List[Any]], + ): + resolved_tool_calls = [] + updated_docs = {} + + # Try to get trace ID from langfuse if available, otherwise continue without it + try: + from langfuse.decorators import langfuse_context + rt = langfuse_context.get_current_trace_id() + except (ImportError, AttributeError): + # Langfuse not available, try langsmith + try: + rt = ls.get_current_run_tree() + except (ImportError, AttributeError): + # Neither available, continue without tracing + pass + + for tc in msg.tool_calls: + if tc["name"] == "PatchDoc": + json_doc_id = tc["args"]["json_doc_id"] + if isinstance(existing, dict): + target = existing.get(str(json_doc_id)) + tool_name = json_doc_id + else: + try: + _, tool_name, target = next( + (e for e in existing if e[0] == json_doc_id), + ) + if not tool_name: + raise ValueError( + "Could not find tool name " + f"for json_doc_id {json_doc_id}" + ) + except StopIteration: + logger.error( + f"Could not find existing schema in dict for {json_doc_id}" + ) + if rt: + rt.error = ( + f"Could not find existing schema for {json_doc_id}" + ) + continue + except (ValueError, IndexError, TypeError): + logger.error( + f"Could not find existing schema in list for {json_doc_id}" + ) + if rt: + rt.error = ( + f"Could not find existing schema for {json_doc_id}" + ) + continue + + if target: + try: + from trustcall.schema import _ensure_patches + patches = _ensure_patches(tc["args"]) + if patches or self.tool_choice == "PatchDoc": + # The second condition is so that, when we are continuously + # updating a single doc, we will still include it in + # the output responses list; mainly for backwards + # compatibility + resolved_tool_calls.append( + ToolCall( + id=tc["id"], + name=tool_name, + args=jsonpatch.apply_patch(target, patches), + ) + ) + updated_docs[tc["id"]] = str(json_doc_id) + except Exception as e: + logger.error(f"Could not apply patch: {e}") + if rt: + rt.error = f"Could not apply patch: {repr(e)}" + else: + if rt: + rt.error = f"Could not find existing schema for {tool_name}" + logger.warning(f"Could not find existing schema for {tool_name}") + else: + resolved_tool_calls.append(tc) + ai_message = AIMessage( + content=msg.content, + tool_calls=resolved_tool_calls, + additional_kwargs={"updated_docs": updated_docs}, + ) + if not ai_message.id: + ai_message.id = str(uuid.uuid4()) + + return { + "messages": [ai_message], + "attempts": 1, + "msg_id": ai_message.id, + } + + @property + def _provided_tools(self): + return sorted(self.tools.keys() - {"PatchDoc", "PatchFunctionErrors"}) + + def _validate_existing( + self, existing: ExistingType + ) -> Union[Dict[str, Any], List[Any]]: + """Check that all existing schemas match a known schema or '__any__'.""" + if isinstance(existing, dict): + # For each top-level key, see if it's recognized + validated = {} + for key, record in existing.items(): + if key in self.tools or key == "__any__": + validated[key] = record + else: + # Key does not match known schema + if self.existing_schema_policy is True: + raise ValueError( + f"Key '{key}' doesn't match any schema. " + f"Known schemas: {list(self.tools.keys())}" + ) + elif self.existing_schema_policy is False: + validated[key] = record + else: # "ignore" + logger.warning(f"Ignoring unknown schema: {key}") + return validated + + elif isinstance(existing, list): + # For list types, validate each item's schema_name + coerced = [] + for i, item in enumerate(existing): + if hasattr(item, "record_id") and hasattr(item, "schema_name") and hasattr(item, "record"): + if ( + item.schema_name not in self.tools + and item.schema_name != "__any__" + ): + if self.existing_schema_policy is True: + raise ValueError( + f"Unknown schema '{item.schema_name}' at index {i}" + ) + elif self.existing_schema_policy is False: + coerced.append( + SchemaInstance( + item.record_id, item.schema_name, item.record + ) + ) + else: # "ignore" + logger.warning(f"Ignoring unknown schema at index {i}") + continue + else: + coerced.append(item) + elif isinstance(item, tuple) and len(item) == 3: + record_id, schema_name, record_dict = item + if isinstance(record_dict, BaseModel): + record_dict = record_dict.model_dump(mode="json") + if schema_name not in self.tools and schema_name != "__any__": + if self.existing_schema_policy is True: + raise ValueError( + f"Unknown schema '{schema_name}' at index {i}" + ) + elif self.existing_schema_policy is False: + coerced.append( + SchemaInstance(record_id, schema_name, record_dict) + ) + else: # "ignore" + logger.warning(f"Ignoring unknown schema '{schema_name}'") + continue + else: + coerced.append( + SchemaInstance(record_id, schema_name, record_dict) + ) + elif isinstance(item, tuple) and len(item) == 2: + # Assume record_ID, item + record_id, model = item + if hasattr(model, "__name__"): + schema_name = model.__name__ + else: + schema_name = model.__repr_name__() + + if schema_name not in self.tools and schema_name != "__any__": + if self.existing_schema_policy is True: + raise ValueError( + f"Unknown schema '{schema_name}' at index {i}" + ) + elif self.existing_schema_policy is False: + val = ( + model.model_dump(mode="json") + if isinstance(model, BaseModel) + else model + ) + coerced.append(SchemaInstance(record_id, schema_name, val)) + else: # "ignore" + logger.warning(f"Ignoring unknown schema '{schema_name}'") + continue + else: + val = ( + model.model_dump(mode="json") + if isinstance(model, BaseModel) + else model + ) + coerced.append(SchemaInstance(record_id, schema_name, val)) + elif isinstance(item, BaseModel): + if hasattr(item, "__name__"): + schema_name = item.__name__ + else: + schema_name = item.__repr_name__() + + if schema_name not in self.tools and schema_name != "__any__": + if self.existing_schema_policy is True: + raise ValueError( + f"Unknown schema '{schema_name}' at index {i}" + ) + elif self.existing_schema_policy is False: + coerced.append( + SchemaInstance( + str(uuid.uuid4()), + schema_name, + item.model_dump(mode="json"), + ) + ) + else: # "ignore" + logger.warning(f"Ignoring unknown schema '{schema_name}'") + continue + else: + coerced.append( + SchemaInstance( + str(uuid.uuid4()), + schema_name, + item.model_dump(mode="json"), + ) + ) + else: + raise ValueError( + f"Invalid item at index {i} in existing list." + f" Provided: {item}, Expected: SchemaInstance" + f" or Tuple[str, str, dict] or BaseModel" + ) + return coerced + else: + raise ValueError( + f"Invalid type for existing. Provided: {type(existing)}," + f" Expected: dict or list. Supported formats are:\n" + "1. Dict[str, Any] where keys are tool names\n" + "2. List[SchemaInstance]\n3. List[Tuple[str, str, Dict[str, Any]]]" + ) + + async def ainvoke(self, state: ExtractionState, config: RunnableConfig) -> dict: + """Generate a JSONPatch to simply update an existing schema. + + Returns a single AIMessage with the updated schema, as if + the schema were extracted from scratch. + """ + messages, existing, removal_schema, bound_model = self._setup(state) + try: + msg = await bound_model.ainvoke(messages, config) + return { + **self._teardown(cast(AIMessage, msg), existing), + "removal_schema": removal_schema, + } + except Exception as e: + return { + "messages": [ + HumanMessage( + content="Fix the validation error while" + f" also avoiding: {repr(str(e))}" + ) + ], + "attempts": 1, + } + + def invoke(self, state: ExtractionState, config: RunnableConfig) -> dict: + messages, existing, removal_schema, bound_model = self._setup(state) + try: + msg = bound_model.invoke(messages, config) + return {**self._teardown(msg, existing), "removal_schema": removal_schema} + except Exception as e: + return { + "messages": [ + HumanMessage( + content="Fix the validation error while" + f" also avoiding: {repr(str(e))}" + ) + ], + "attempts": 1, + } + + def as_runnable(self): + return RunnableCallable( + self.invoke, self.ainvoke, name="extract_updates", trace=False + ) + + +def create_extractor( + llm: str | BaseChatModel, + *, + tools: Sequence[TOOL_T], + tool_choice: Optional[str] = None, + enable_inserts: bool = False, + enable_updates: bool = True, + enable_deletes: bool = False, + existing_schema_policy: bool | Literal["ignore"] = True, +) -> Runnable[InputsLike, ExtractionOutputs]: + """Create an extractor that generates validated structured outputs using an LLM. + + This function binds validators and retry logic to ensure the validity of + generated tool calls. It uses JSONPatch to correct validation errors caused + by incorrect or incomplete parameters in previous tool calls. + + Args: + llm (BaseChatModel): The language model that will generate the initial + messages and fallbacks. + tools (Sequence[TOOL_T]): The tools to bind to the LLM. Can be BaseTool, + Type[BaseModel], Callable, or Dict[str, Any]. + tool_choice (Optional[str]): The specific tool to use. If None, + the LLM chooses whether to use (or not use) a tool based + on the input messages. (default: None) + enable_inserts (bool): Whether to allow the LLM to extract new schemas + even if it receives existing schemas. (default: False) + enable_updates (bool): Whether to allow the LLM to update existing schemas + using the PatchDoc tool. (default: True) + enable_deletes (bool): Whether to allow the LLM to delete existing schemas + using the RemoveDoc tool. (default: False) + existing_schema_policy (bool | Literal["ignore"]): How to handle existing schemas + that don't match the provided tool. Useful for migrating or managing heterogenous + docs. (default: True) True means raise error. False means treat as dict. + "ignore" means ignore (drop any attempts to patch these) + + Returns: + Runnable[ExtractionInputs, ExtractionOutputs]: A runnable that + can be invoked with a list of messages and returns validated AI + messages and responses. + + Examples: + >>> from langchain_fireworks import ( + ... ChatFireworks, + ... ) + >>> from pydantic import ( + ... BaseModel, + ... Field, + ... ) + >>> + >>> class UserInfo(BaseModel): + ... name: str = Field(description="User's full name") + ... age: int = Field(description="User's age in years") + >>> + >>> llm = ChatFireworks(model="accounts/fireworks/models/firefunction-v2") + >>> extractor = create_extractor( + ... llm, + ... tools=[UserInfo], + ... ) + >>> result = extractor.invoke( + ... { + ... "messages": [ + ... ( + ... "human", + ... "My name is Alice and I'm 30 years old", + ... ) + ... ] + ... } + ... ) + >>> result["responses"][0] + UserInfo(name='Alice', age=30) + + Using multiple tools + >>> from typing import ( + ... List, + ... ) + >>> + >>> class Preferences(BaseModel): + ... foods: List[str] = Field(description="Favorite foods") + >>> + >>> extractor = create_extractor( + ... llm, + ... tools=[ + ... UserInfo, + ... Preferences, + ... ], + ... ) + >>> result = extractor.invoke( + ... { + ... "messages": [ + ... ( + ... "system", + ... "Extract all the user's information and preferences" + ... "from the conversation below using parallel tool calling.", + ... ), + ... ( + ... "human", + ... "I'm Bob, 25 years old, and I love pizza and sushi", + ... ), + ... ] + ... } + ... ) + >>> print(result["responses"]) + [UserInfo(name='Bob', age=25), Preferences(foods=['pizza', 'sushi'])] + >>> print(result["messages"]) # doctest: +SKIP + [ + AIMessage( + content='', tool_calls=[ + ToolCall(id='...', name='UserInfo', args={'name': 'Bob', 'age': 25}), + ToolCall(id='...', name='Preferences', args={'foods': ['pizza', 'sushi']} + )] + ) + ] + + Updating an existing schema: + >>> existing = { + ... "UserInfo": { + ... "name": "Alice", + ... "age": 30, + ... }, + ... "Preferences": { + ... "foods": [ + ... "pizza", + ... "sushi", + ... ] + ... }, + ... } + >>> extractor = create_extractor( + ... llm, + ... tools=[ + ... UserInfo, + ... Preferences, + ... ], + ... ) + >>> result = extractor.invoke( + ... { + ... "messages": [ + ... ( + ... "system", + ... "You are tasked with maintaining user info and preferences." + ... " Use the tools to update the schemas.", + ... ), + ... ( + ... "human", + ... "I'm Alice; just had my 31st birthday yesterday." + ... " We had spinach, which is my FAVORITE!", + ... ), + ... ], + ... "existing": existing, + ... } + ... ) + """ # noqa + # Convert string to model if needed + if isinstance(llm, str): + try: + from langchain.chat_models import init_chat_model + llm = init_chat_model(llm) + except ImportError: + raise ImportError( + "Creating extractors from a string requires langchain>=0.3.0," + " as well as the provider-specific package" + " (like langchain-openai, langchain-anthropic, etc.)" + " Please install langchain to continue." + ) + builder = StateGraph(ExtractionState) + + # Check if the model is a Gemini model - this affects the schema generation and patching + using_gemini = is_gemini_model(llm) + + # Define error formatting + def format_exception(error: BaseException, call: ToolCall, schema: Type[BaseModel]) -> str: + return ( + f"Error:\n\n```\n{str(error)}\n```\n" + "Expected Parameter Schema:\n\n" + f"```json\n{_get_schema(schema, using_gemini)}\n```\n" + f"Please use PatchFunctionErrors to fix all validation errors." + f" for json_doc_id=[{call['id']}]." + ) + + # Get the appropriate patching tools - Gemini supports simpler JSON schemas, so requires different tools + patch_doc = _create_patch_doc_schema(using_gemini) + patch_function_errors = _create_patch_function_errors_schema(using_gemini) + + # Create validator with appropriate tools + validator = _ExtendedValidationNode( + ensure_tools(tools) + [patch_doc, patch_function_errors], + format_error=format_exception, # type: ignore + enable_deletes=enable_deletes, + ) + _extract_tools = [ + schema + for name, schema in validator.schemas_by_name.items() + if name not in {patch_doc.__name__, patch_function_errors.__name__} + ] + tool_names = [getattr(t, "name", t.__name__) for t in _extract_tools] + builder.add_node( + _Extract( + llm, + _extract_tools, + tool_choice, + for_gemini=using_gemini, + ).as_runnable() + ) + updater = _ExtractUpdates( + llm, + tools=validator.schemas_by_name.copy(), + enable_inserts=enable_inserts, # type: ignore + enable_updates=enable_updates, # type: ignore + enable_deletes=enable_deletes, # type: ignore + existing_schema_policy=existing_schema_policy, + ) + builder.add_node(updater.as_runnable()) + builder.add_node(_Patch(llm, valid_tool_names=tool_names).as_runnable()) + builder.add_node("validate", validator) + + def del_tool_call(state: DeletionState) -> dict: + return { + "messages": MessageOp(op="delete", target=state.deletion_target), + } + + builder.add_node(del_tool_call) + + def enter(state: ExtractionState) -> Literal["extract", "extract_updates"]: + if state.existing: + return "extract_updates" + return "extract" + + builder.add_conditional_edges("__start__", enter) + + def validate_or_retry( + state: ExtractionState, + ) -> Literal["validate", "extract_updates"]: + if state.messages[-1].type == "ai": + return "validate" + return "extract_updates" + + builder.add_edge("extract", "validate") + builder.add_conditional_edges("extract_updates", validate_or_retry) + + def handle_retries( + state: ExtractionState, config: RunnableConfig + ) -> Union[Literal["__end__"], list]: + """After validation, decide whether to retry or end the process.""" + max_attempts = config["configurable"].get("max_attempts", DEFAULT_MAX_ATTEMPTS) + if state.attempts >= max_attempts: + return "__end__" + # Only continue if we need to patch the tool call + to_send = [] + # We only increment the attempt count once, regardless of the fan-out + # degree. + bumped = False + for m in reversed(state.messages): + if isinstance(m, AIMessage): + break + if isinstance(m, ToolMessage): + if m.status == "error": + # Each fallback will fix at most 1 schema per time. + messages_for_fixing = _get_history_for_tool_call( + state.messages, m.tool_call_id + ) + to_send.append( + Send( + "patch", + ExtendedExtractState( + **{ + **asdict(state), + "messages": messages_for_fixing, + "tool_call_id": m.tool_call_id, + "bump_attempt": not bumped, + } + ), + ) + ) + bumped = True + else: + # We want to delete the validation tool calls + # anyway to avoid mixing branches during fan-in + to_send.append( + Send( + "del_tool_call", + DeletionState( + deletion_target=str(m.id), messages=state.messages + ), + ) + ) + return to_send + + builder.add_conditional_edges( + "validate", handle_retries, path_map=["__end__", "patch", "del_tool_call"] + ) + + def sync(state: ExtractionState, config: RunnableConfig) -> dict: + return {"messages": []} + + def validate_or_repatch( + state: ExtractionState, + ) -> Literal["validate", "patch"]: + if state.messages[-1].type == "ai": + return "validate" + return "patch" + + builder.add_node(sync) + + builder.add_conditional_edges( + "sync", validate_or_repatch, path_map=["validate", "patch", "__end__"] + ) + compiled = builder.compile(checkpointer=False) + compiled.name = "TrustCall" + + def filter_state(state: dict) -> ExtractionOutputs: + """Filter the state to only include the validated AIMessage + responses.""" + msg_id = state["msg_id"] + msg: Optional[AIMessage] = next( + ( + m + for m in state["messages"] + if m.id == msg_id and isinstance(m, AIMessage) + ), # type: ignore + None, + ) + if not msg: + return ExtractionOutputs( + messages=[], + responses=[], + attempts=state["attempts"], + response_metadata=[], + ) + responses = [] + response_metadata = [] + updated_docs = msg.additional_kwargs.get("updated_docs") or {} + existing = state.get("existing") + removal_schema = None + if enable_deletes and existing: + removal_schema = _create_remove_doc_from_existing(existing) + for tc in msg.tool_calls: + if removal_schema and tc["name"] == removal_schema.__name__: + sch = removal_schema + elif tc["name"] not in validator.schemas_by_name: + if existing_schema_policy in (False, "ignore"): + continue + sch = validator.schemas_by_name[tc["name"]] + else: + sch = validator.schemas_by_name[tc["name"]] + try: + responses.append( + sch.model_validate(tc["args"]) + if hasattr(sch, "model_validate") + else sch.parse_obj(tc["args"]) + ) + meta = {"id": tc["id"]} + if json_doc_id := updated_docs.get(tc["id"]): + meta["json_doc_id"] = json_doc_id + response_metadata.append(meta) + except Exception as e: + logger.error(e) + continue + + return { + "messages": [msg], + "responses": responses, + "response_metadata": response_metadata, + "attempts": state["attempts"], + } + + def coerce_inputs(state: InputsLike) -> Union[ExtractionInputs, dict]: + """Coerce inputs to the expected format.""" + if isinstance(state, list): + return {"messages": state} + if isinstance(state, str): + return {"messages": [{"role": "user", "content": state}]} + if isinstance(state, PromptValue): + return {"messages": state.to_messages()} + if isinstance(state, dict): + if isinstance(state.get("messages"), PromptValue): + state = {**state, "messages": state["messages"].to_messages()} # type: ignore + else: + if hasattr(state, "messages"): + state = {"messages": state.messages.to_messages()} # type: ignore + + return cast(dict, state) + + return coerce_inputs | compiled | filter_state \ No newline at end of file diff --git a/trustcall/patch.py b/trustcall/patch.py new file mode 100644 index 0000000..a0c34ec --- /dev/null +++ b/trustcall/patch.py @@ -0,0 +1,181 @@ +"""Patching-related functionality for the trustcall package.""" + +from __future__ import annotations + +import logging +import uuid +from typing import ( + Any, + Dict, + List, + Literal, + Sequence, + Union, + Optional, + cast, + +) + +import jsonpatch # type: ignore[import-untyped] +import langsmith as ls +from langchain_core.messages import ( + AIMessage, + AnyMessage, + ToolCall, + ToolMessage, +) +from langchain_core.runnables import RunnableConfig +from langgraph.constants import Send +from langgraph.types import Command +from langgraph.utils.runnable import RunnableCallable + +from trustcall.schema import _ensure_patches, _create_patch_function_errors_schema, _create_patch_function_name_schema +from trustcall.states import ExtendedExtractState, MessageOp +from trustcall.utils import is_gemini_model +from langchain_core.language_models import BaseChatModel + +logger = logging.getLogger("extraction") + + +class _Patch: + """Prompt an LLM to patch an invalid schema after it receives a ValidationError. + + We have found this to be more reliable and more token-efficient than + re-creating the entire tool call from scratch. + """ + + def __init__( + self, llm: BaseChatModel, valid_tool_names: Optional[List[str]] = None + ): + # Get the appropriate patching tools based on LLM type + using_gemini = is_gemini_model(llm) + self.bound = llm.bind_tools( + [ + _create_patch_function_errors_schema(using_gemini), + _create_patch_function_name_schema(valid_tool_names, using_gemini) + ], + tool_choice="any", + ) + + @ls.traceable(tags=["patch", "langsmith:hidden"]) + def _tear_down( + self, + msg: AIMessage, + messages: List[AnyMessage], + target_id: str, + bump_attempt: bool, + ): + if not msg.id: + msg.id = str(uuid.uuid4()) + # We will directly update the messages in the state before validation. + msg_ops = _infer_patch_message_ops(messages, msg.tool_calls, target_id) + return { + "messages": msg_ops, + "attempts": 1 if bump_attempt else 0, + } + + async def ainvoke( + self, state: ExtendedExtractState, config: RunnableConfig + ) -> Command[Literal["sync", "__end__"]]: + """Generate a JSONPatch to correct the validation error and heal the tool call. + + Assumptions: + - We only support a single tool call to be patched. + - State's message list's last AIMessage contains the actual schema to fix. + - The last ToolMessage contains the tool call to fix. + + """ + try: + msg = await self.bound.ainvoke(state.messages, config) + except Exception: + return Command(goto="__end__") + return Command( + update=self._tear_down( + cast(AIMessage, msg), + state.messages, + state.tool_call_id, + state.bump_attempt, + ), + goto=("sync",), + ) + + def invoke( + self, state: ExtendedExtractState, config: RunnableConfig + ) -> Command[Literal["sync", "__end__"]]: + try: + msg = self.bound.invoke(state.messages, config) + except Exception: + return Command(goto="__end__") + return Command( + update=self._tear_down( + cast(AIMessage, msg), + state.messages, + state.tool_call_id, + state.bump_attempt, + ), + goto=("sync",), + ) + + def as_runnable(self): + return RunnableCallable(self.invoke, self.ainvoke, name="patch", trace=False) + + +def _get_message_op( + messages: Sequence[AnyMessage], tool_call: dict, tool_call_name: str, target_id: str +) -> List[MessageOp]: + msg_ops: List[MessageOp] = [] + + # Process each message + for m in messages: + if isinstance(m, AIMessage): + for tc in m.tool_calls: + if tc["id"] == target_id: + # Handle PatchFunctionName + if tool_call_name == "PatchFunctionName": + if not tool_call.get("fixed_name"): + continue + msg_ops.append({ + "op": "update_tool_name", + "target": { + "id": target_id, + "name": str(tool_call["fixed_name"]), + }, + }) + # Handle any patch function - cover all cases using name check instead of type check + elif "PatchFunctionErrors" in tool_call_name or tool_call_name == "PatchDoc": + try: + patches = _ensure_patches(tool_call) + if patches: + patched_args = jsonpatch.apply_patch(tc["args"], patches) + msg_ops.append({ + "op": "update_tool_call", + "target": { + "id": target_id, + "name": tc["name"], + "args": patched_args, + }, + }) + except Exception as e: + logger.error(f"Could not apply patch: {repr(e)}") + else: + logger.error(f"Unrecognized function call {tool_call_name}") + + # Add delete operations for tool messages + if isinstance(m, ToolMessage) and m.tool_call_id == target_id: + msg_ops.append(MessageOp(op="delete", target=m.id or "")) + + return msg_ops + + +@ls.traceable(tags=["langsmith:hidden"]) +def _infer_patch_message_ops( + messages: Sequence[AnyMessage], tool_calls: List[ToolCall], target_id: str +): + ops = [ + op + for tool_call in tool_calls + for op in _get_message_op( + messages, tool_call["args"], tool_call["name"], target_id=target_id + ) + ] + return ops \ No newline at end of file diff --git a/trustcall/schema.py b/trustcall/schema.py new file mode 100644 index 0000000..0b75d64 --- /dev/null +++ b/trustcall/schema.py @@ -0,0 +1,524 @@ +""" +Handles the creation, conversion, and management of schemas used for tool calling, +validation, and patching, ensuring that trustcall can work with different +LLMs (including Gemini) and various schema formats. +""" + +from __future__ import annotations + +import functools +import json +import logging +from typing import ( + Any, + Dict, + List, + Literal, + Optional, + Type, + Union, + get_args, + get_origin, +) + +from pydantic import ( + BaseModel, + ConfigDict, + Field, + StrictBool, + StrictFloat, + StrictInt, + field_validator, +) + +from trustcall.utils import _exclude_none + +logger = logging.getLogger("extraction") + + +def create_gemini_compatible_schema(model_class): + """ + Create a Gemini-compatible schema from a Pydantic model. + + Args: + model_class: The Pydantic model class + + Returns: + A Gemini-compatible schema dictionary + """ + # Start with basic model info + gemini_schema = { + "type": "OBJECT", + "title": model_class.__name__, + "description": model_class.__doc__ or f"A {model_class.__name__} object", + "properties": {}, + "required": [] + } + + # Get the field names in the order they were defined + # This is crucial for Gemini's expected property ordering + field_names = list(model_class.model_fields.keys()) + + # Process all model fields + for field_name in field_names: + field = model_class.model_fields[field_name] + + # Add to required list if appropriate + if field.is_required(): + gemini_schema["required"].append(field_name) + + # Get field description + field_desc = field.description or f"The {field_name} field" + + # Convert field type to Gemini format + gemini_schema["properties"][field_name] = convert_field_to_gemini(field, field_desc) + + return gemini_schema + + +def convert_field_to_gemini(field, description): + """Convert a Pydantic field to Gemini-compatible schema format.""" + annotation = field.annotation + + # Handle basic types + if annotation is str: + return {"type": "STRING", "description": description} + elif annotation is int: + return {"type": "INTEGER", "description": description} + elif annotation is float: + return {"type": "NUMBER", "description": description} + elif annotation is bool: + return {"type": "BOOLEAN", "description": description} + + # Handle container types + origin = get_origin(annotation) + if origin is list: + item_type = get_args(annotation)[0] + return { + "type": "ARRAY", + "description": description, + "items": convert_type_to_gemini(item_type) + } + elif origin is dict: + return {"type": "OBJECT", "description": description} + elif origin is Union or origin is Optional: + # For Union/Optional types, use the first type as primary + # and add nullable if None is an option + types = get_args(annotation) + primary_type = next((t for t in types if t is not type(None)), types[0]) + result = convert_type_to_gemini(primary_type) + if type(None) in types: + # Unlike JSON Schema which uses nullable, Gemini might need a different approach + # Just adding nullable true seems most reasonable + result["nullable"] = True + return result + + # Handle nested Pydantic models + if isinstance(annotation, type) and issubclass(annotation, BaseModel): + # For nested models, recursively generate the schema + # Gemini may not support references, so include the full schema + return create_gemini_compatible_schema(annotation) + + # Default to string for unknown types + return {"type": "STRING", "description": description} + + +def convert_type_to_gemini(type_annotation): + """Convert a Python type to a Gemini schema type definition.""" + if type_annotation is str: + return {"type": "STRING"} + elif type_annotation is int: + return {"type": "INTEGER"} + elif type_annotation is float: + return {"type": "NUMBER"} + elif type_annotation is bool: + return {"type": "BOOLEAN"} + elif isinstance(type_annotation, type) and issubclass(type_annotation, BaseModel): + return create_gemini_compatible_schema(type_annotation) + + # Default to string + return {"type": "STRING"} + + +def _get_schema(model: Type[BaseModel], for_gemini: bool) -> dict: + if for_gemini: + return create_gemini_compatible_schema(model) + else: + if hasattr(model, "model_json_schema"): + schema = model.model_json_schema() + else: + schema = model.schema() # type: ignore + return _exclude_none(schema) + + +# JSON Patch related classes + +# We COULD just say Any for the value below, but Fireworks and some other +# providers don't support untyped arrays and dicts... +_JSON_PRIM_TYPES = Union[str, StrictInt, StrictBool, StrictFloat, None] +_JSON_TYPES = Union[ + _JSON_PRIM_TYPES, List[_JSON_PRIM_TYPES], Dict[str, _JSON_PRIM_TYPES] +] + + +class BasePatch(BaseModel): + """Base class for all patch types.""" + op: Literal["add", "remove", "replace"] = Field( + ..., + description="A JSON Pointer path that references a location within the" + " target document where the operation is performed." + " Note: patches are applied sequentially. If you remove a value, the collection" + " size changes before the next patch is applied.", + ) + path: str = Field( + ..., + description="A JSON Pointer path that references a location within the" + " target document where the operation is performed." + " Note: patches are applied sequentially. If you remove a value, the collection" + " size changes before the next patch is applied.", + ) + + +class FullPatch(BasePatch): + """A JSON Patch document represents an operation to be performed on a JSON document. + + Note that the op and path are ALWAYS required. Value is required for ALL operations except 'remove'. + This supports OpenAI and other LLMs with full JSON support (not Gemini). + """ # noqa + value: Union[_JSON_TYPES, List[_JSON_TYPES], Dict[str, _JSON_TYPES]] = Field( + ..., + description="The value to be used within the operation." + ) + model_config = ConfigDict( + json_schema_extra={ + "examples": [ + { + "op": "replace", + "path": "/path/to/my_array/1", + "value": "the newer value to be patched", + }, + { + "op": "replace", + "path": "/path/to/broken_object", + "value": {"new": "object"}, + }, + { + "op": "add", + "path": "/path/to/my_array/-", + "value": ["some", "values"], + }, + { + "op": "add", + "path": "/path/to/my_array/-", + "value": ["newer"], + }, + { + "op": "remove", + "path": "/path/to/my_array/1", + }, + ] + } + ) + +class GeminiJsonPatch(BasePatch): + """A JSON Patch document represents an operation to be performed on a JSON document. + + Note that the op and path are ALWAYS required. Value is required for ALL operations except 'remove'. + This supports Gemini with it's more limited JSON compatibility. + """ # noqa + # Similar to JsonPatch but with Gemini-compatible schema definition + # Instead of using a string-only value, use Union types that match Gemini's schema + value: Optional[Union[str, int, float, bool, List, Dict]] = Field( + default=None, + description="The value to be used within the operation. Required for" + " 'add' and 'replace' operations, not needed for 'remove'." + ) + + # For Gemini, we'll use a string value but with clear documentation that it can be complex + value: Optional[str] = Field( + default=None, + description="The value to be used within the operation. For complex values (objects, arrays), " + "provide valid JSON as a string. Required for 'add' and 'replace' operations." + ) + + @field_validator('value') + @classmethod + def validate_value(cls, v, info): + """Automatically convert complex values to JSON strings and handle remove operations.""" + values = info.data + + # Allow None for remove operations + if v is None and values.get("op") == "remove": + return v + + # Convert objects and arrays to JSON strings + if isinstance(v, (dict, list)): + return json.dumps(v) + + # Convert primitive types to strings + if v is not None and not isinstance(v, str): + return str(v) + + return v + + model_config = ConfigDict( + json_schema_extra={ + "type": "OBJECT", + "properties": { + "op": { + "type": "STRING", + "enum": ["add", "remove", "replace"], + "description": "The operation to be performed." + }, + "path": { + "type": "STRING", + "description": "JSON Pointer path where the operation is performed." + }, + "value": { + "type": "STRING", + "description": "Value to use in the operation. For complex values, use JSON strings." + } + }, + "required": ["op", "path"] + } + ) + +def get_patch_class(for_gemini: bool) -> Type[BasePatch]: + """Return the appropriate patch class based on the LLM type.""" + return GeminiJsonPatch if for_gemini else FullPatch + +def _create_patch_function_errors_schema(for_gemini: bool = False) -> Type[BaseModel]: + """Create the appropriate PatchFunctionErrors model based on the LLM type.""" + # Choose the appropriate patch type + patch_class = get_patch_class(for_gemini) + + class PatchFunctionErrors(BaseModel): + """Respond with all JSONPatch operations required to update the previous invalid function call.""" + + json_doc_id: str = Field( + ..., + description="The ID of the function you are patching.", + ) + planned_edits: str = Field( + ..., + description="Write a bullet-point list of each ValidationError you encountered" + " and the corresponding JSONPatch operation needed to heal it." + " For each operation, write why your initial guess was incorrect, " + " citing the corresponding types(s) from the JSONSchema" + " that will be used the validate the resultant patched document." + " Think step-by-step to ensure no error is overlooked.", + ) + patches: list[patch_class] = Field( + ..., + description="A list of JSONPatch operations to be applied to the" + " previous tool call's response arguments. If none are required, return" + " an empty list. This field is REQUIRED." + " Multiple patches in the list are applied sequentially in the order provided," + " with each patch building upon the result of the previous one.", + ) + + return PatchFunctionErrors + +def _create_patch_doc_schema(for_gemini: bool = False) -> Type[BaseModel]: + """Create the appropriate PatchDoc model based on the LLM type.""" + + patch_class = get_patch_class(for_gemini) + + + class PatchDoc(BaseModel): + """Respond with JSONPatch operations to update the existing JSON document based on the provided text and schema.""" + + json_doc_id: str = Field( + ..., + description="The json_doc_id of the document you are patching.", + ) + planned_edits: str = Field( + ..., + description="Think step-by-step, reasoning over each required" + " update and the corresponding JSONPatch operation to accomplish it." + " Cite the fields in the JSONSchema you referenced in developing this plan." + " Address each path as a group; don't switch between paths.\n" + " Plan your patches in the following order:" + "1. replace - this keeps collection size the same.\n" + "2. remove - BE CAREFUL ABOUT ORDER OF OPERATIONS." + " Each operation is applied sequentially." + " For arrays, remove the highest indexed value first to avoid shifting" + " indices. This ensures subsequent remove operations remain valid.\n" + " 3. add (for arrays, use /- to efficiently append to end).", + ) + # For Gemini, we use a list of simple dictionaries instead of complex models + patches: List[patch_class] = Field( + ..., + description="A list of JSONPatch operations to be applied to the" + " previous tool call's response arguments. If none are required, return" + " an empty list. This field is REQUIRED." + " Multiple patches in the list are applied sequentially in the order provided," + " with each patch building upon the result of the previous one." + " Take care to respect array bounds. Order patches as follows:\n" + " 1. replace - this keeps collection size the same\n" + " 2. remove - BE CAREFUL about order of operations. For arrays, remove" + " the highest indexed value first to avoid shifting indices.\n" + " 3. add - for arrays, use /- to efficiently append to end.", + ) + + return PatchDoc + +def _create_patch_function_name_schema(valid_tool_names: Optional[List[str]] = None, for_gemini: bool = False): + if valid_tool_names: + namestr = ", ".join(valid_tool_names) + vname = f" Must be one of {namestr}" + else: + vname = "" + + class PatchFunctionName(BaseModel): + """Call this if the tool message indicates that you previously invoked an invalid tool, (e.g., "Unrecognized tool name" error), do so here.""" # noqa + + json_doc_id: str = Field( + ..., + description="The ID of the function you are patching.", + ) + reasoning: list[str] = Field( + ..., + description="At least 2 logical reasons why this action ought to be taken." + "Cite the specific error(s) mentioned to motivate the fix.", + ) + fixed_name: Optional[str] = Field( + ..., + description="If you need to change the name of the function (e.g., " + f'from an "Unrecognized tool name" error), do so here.{vname}', + ) + + # If using Gemini, ensure the schema is Gemini-compatible + if for_gemini: + # Set a Gemini-compatible schema for the model + PatchFunctionName.model_config = ConfigDict( + json_schema_extra={ + "type": "OBJECT", + "properties": { + "json_doc_id": { + "type": "STRING", + "description": "The ID of the function you are patching." + }, + "reasoning": { + "type": "ARRAY", + "items": {"type": "STRING"}, + "description": "At least 2 logical reasons why this action ought to be taken." + }, + "fixed_name": { + "type": "STRING", + "description": f"The corrected function name.{vname}" + } + }, + "required": ["json_doc_id", "reasoning"] + } + ) + + return PatchFunctionName + + +def _create_remove_doc_from_existing(existing: Union[dict, list]): + if isinstance(existing, dict): + existing_ids = set(existing) + else: + existing_ids = set() + for schema_id, *_ in existing: + existing_ids.add(schema_id) + return _create_remove_doc_schema(tuple(sorted(existing_ids))) + + +@functools.lru_cache(maxsize=10) +def _create_remove_doc_schema(allowed_ids: tuple[str]) -> Type[BaseModel]: + """Create a RemoveDoc schema that validates against a set of allowed IDs.""" + + class RemoveDoc(BaseModel): + """Use this tool to remove (delete) a doc by its ID.""" + + json_doc_id: str = Field( + ..., + description=f"ID of the document to remove. Must be one of: {allowed_ids}", + ) + + @field_validator("json_doc_id") + @classmethod + def validate_doc_id(cls, v: str) -> str: + if v not in allowed_ids: + raise ValueError( + f"Document ID '{v}' not found. Available IDs: {sorted(allowed_ids)}" + ) + return v + + RemoveDoc.__name__ = "RemoveDoc" + return RemoveDoc + +def _ensure_patches(args: dict) -> list[Dict[str, Any]]: + """Process patches from different formats and ensure they're valid JsonPatch objects.""" + patches = args.get("patches", []) + + # If already a list, process it + if isinstance(patches, list): + processed_patches = [] + + for patch in patches: + if isinstance(patch, (dict, BaseModel)): + # Extract required fields + if isinstance(patch, BaseModel): + patch = patch.model_dump() if hasattr(patch, 'model_dump') else patch.dict() + + op = patch.get("op") + path = patch.get("path") + value = patch.get("value") + + # Verify required fields + if op and path: + # For remove operations, value can be None + if op == "remove": + processed_patches.append({"op": op, "path": path}) + # For add/replace operations, value is required + elif value is not None: + # Try to parse string values as JSON for complex values + parsed_value = value + if isinstance(value, str) and (value.startswith('{') or value.startswith('[')): + try: + parsed_value = json.loads(value) + except json.JSONDecodeError: + # If parsing fails, use value as is + parsed_value = value + + processed_patches.append({ + "op": op, + "path": path, + "value": parsed_value + }) + + return processed_patches + + # Handle string format + if isinstance(patches, str): + try: + # Direct JSON parsing attempt + parsed = json.loads(patches) + if isinstance(parsed, list): + return _ensure_patches({"patches": parsed}) + except json.JSONDecodeError: + # Fallback: Try to find a complete JSON array within the string + bracket_depth = 0 + first_list_str = None + start = patches.find("[") + if start != -1: + for i in range(start, len(patches)): + if patches[i] == "[": + bracket_depth += 1 + elif patches[i] == "]": + bracket_depth -= 1 + if bracket_depth == 0: + first_list_str = patches[start : i + 1] + break + if first_list_str: + try: + parsed = json.loads(first_list_str) + if isinstance(parsed, list): + return _ensure_patches({"patches": parsed}) + except json.JSONDecodeError: + pass + + return [] \ No newline at end of file diff --git a/trustcall/states.py b/trustcall/states.py new file mode 100644 index 0000000..267b3cc --- /dev/null +++ b/trustcall/states.py @@ -0,0 +1,146 @@ +from dataclasses import asdict, field, dataclass +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Sequence, + Type, + Union, + Literal, + cast, + get_args, + get_origin, +) +from langchain_core.messages import ( + AIMessage, + AnyMessage, + BaseMessage, + MessageLikeRepresentation, + ToolMessage, +) +from langchain_core.prompt_values import PromptValue +from typing_extensions import Annotated, TypedDict +import operator + +class MessageOp(TypedDict): + op: Literal["delete", "update_tool_call", "update_tool_name"] + target: Union[str, Any] # ToolCall + +def _apply_message_ops( + messages: Sequence[AnyMessage], message_ops: Sequence[MessageOp] +) -> List[AnyMessage]: + """Apply operations to messages.""" + # Apply operations to the messages + messages = list(messages) + for message_op in message_ops: + if message_op["op"] == "delete": + t = cast(str, message_op["target"]) + messages_ = [m for m in messages if cast(str, getattr(m, "id")) != t] + messages = messages_ + elif message_op["op"] == "update_tool_call": + targ = cast(Any, message_op["target"]) + messages_ = [] + for m in messages: + if isinstance(m, AIMessage): + old = m.tool_calls.copy() + new = [ + targ if tc["id"] == targ["id"] else tc for tc in m.tool_calls + ] + if old != new: + m = m.model_copy() + m.tool_calls = new + if m.additional_kwargs.get("tool_calls"): + m.additional_kwargs["tool_calls"] = new + messages_.append(m) + else: + messages_.append(m) + messages = messages_ + elif message_op["op"] == "update_tool_name": + update_targ = cast(dict, message_op["target"]) + messages_ = [] + for m in messages: + if isinstance(m, AIMessage): + new = [] + for tc in m.tool_calls: + if tc["id"] == update_targ["id"]: + new.append( + { + "id": update_targ["id"], + "name": update_targ[ + "name" + ], # Just updating the name + "args": tc["args"], + } + ) + else: + new.append(tc) + if m.tool_calls != new: + m = m.model_copy() + m.tool_calls = new + messages_.append(m) + messages = messages_ + else: + raise ValueError(f"Invalid operation: {message_op['op']}") + return messages + +def _reduce_messages( + left: Optional[List[AnyMessage]], + right: Union[ + AnyMessage, + List[Union[AnyMessage, MessageOp]], + List[BaseMessage], + PromptValue, + MessageOp, + ], +) -> Sequence[MessageLikeRepresentation]: + """Combine two message sequences, handling message operations.""" + if not left: + left = [] + if isinstance(right, PromptValue): + right = right.to_messages() + message_ops = [] + if isinstance(right, dict) and right.get("op"): + message_ops = [right] + right = [] + if isinstance(right, list): + right_ = [] + for r in right: + if isinstance(r, dict) and r.get("op"): + message_ops.append(r) + else: + right_.append(r) + right = right_ # type: ignore[assignment] + from langgraph.graph import add_messages + messages = cast(Sequence[AnyMessage], add_messages(left, right)) # type: ignore[arg-type] + if message_ops: + messages = _apply_message_ops(messages, message_ops) + return messages + +def _keep_first(left: Any, right: Any): + """Keep the first non-empty value.""" + return left or right + +@dataclass(kw_only=True) +class ExtractionState: + messages: Annotated[List[AnyMessage], _reduce_messages] = field( + default_factory=list + ) + attempts: Annotated[int, operator.add] = field(default=0) + msg_id: Annotated[str, _keep_first] = field(default="") + """Set once and never changed. The ID of the message to be patched.""" + existing: Optional[Dict[str, Any]] = field(default=None) + """If you're updating an existing schema, provide the existing schema here.""" + + +@dataclass(kw_only=True) +class ExtendedExtractState(ExtractionState): + tool_call_id: str = field(default="") + """The ID of the tool call to be patched.""" + bump_attempt: bool = field(default=False) + + +@dataclass(kw_only=True) +class DeletionState(ExtractionState): + deletion_target: str = field(default="") diff --git a/trustcall/tools.py b/trustcall/tools.py new file mode 100644 index 0000000..41098e4 --- /dev/null +++ b/trustcall/tools.py @@ -0,0 +1,163 @@ +"""Tool-related functionality for the trustcall package.""" + +from __future__ import annotations + +import inspect +from typing import ( + Any, + Callable, + Dict, + List, + Sequence, + Type, + Union, + cast, + get_args, +) + +from dydantic import create_model_from_schema +from langchain_core.tools import BaseTool, create_schema_from_function +from pydantic import BaseModel +from typing_extensions import get_origin, is_typeddict, Annotated + +from trustcall.utils import _strip_injected + +TOOL_T = Union[BaseTool, Type[BaseModel], Callable, Dict[str, Any]] +"""Type for tools that can be used with the extractor. + +Can be one of: +- BaseTool: A LangChain tool +- Type[BaseModel]: A Pydantic model class +- Callable: A function +- Dict[str, Any]: A dictionary representing a schema +""" + + +def ensure_tools( + tools: Sequence[TOOL_T], +) -> List[Union[BaseTool, Type[BaseModel], Callable]]: + """Convert various tool formats to a consistent format. + + Args: + tools: A sequence of tools in various formats + + Returns: + A list of tools in a consistent format + + Raises: + ValueError: If a tool is in an invalid format + """ + results: list = [] + for t in tools: + if isinstance(t, dict): + if all(k in t for k in ("name", "description", "parameters")): + schema = create_model_from_schema( + {"title": t["name"], **t["parameters"]} + ) + schema.__doc__ = (getattr(schema, __doc__, "") or "") + ( + t.get("description") or "" + ) + schema.__name__ = t["name"] + results.append(schema) + elif all(k in t for k in ("type", "function")): + # Already in openai format + resolved = ensure_tools([t["function"]]) + results.extend(resolved) + else: + model = create_model_from_schema(t) + if not model.__doc__: + model.__doc__ = t.get("description") or model.__name__ + results.append(model) + elif is_typeddict(t): + results.append(_convert_any_typed_dicts_to_pydantic(cast(type, t))) + elif isinstance(t, (BaseTool, type)): + results.append(t) + elif callable(t): + results.append(csff_(t)) + else: + raise ValueError(f"Invalid tool type: {type(t)}") + return list(results) + + +def csff_(function: Callable) -> Type[BaseModel]: + """Create a schema from a function. + + Args: + function: The function to create a schema from + + Returns: + A Pydantic model class representing the function's schema + """ + fn = _strip_injected(function) + schema = create_schema_from_function(function.__name__, fn) + schema.__name__ = function.__name__ + return schema + + +_MAX_TYPED_DICT_RECURSION = 25 + + +def _convert_any_typed_dicts_to_pydantic( + type_: type, + *, + visited: dict | None = None, + depth: int = 0, +) -> type: + """Convert TypedDict to Pydantic model. + + Args: + type_: The type to convert + visited: A dictionary of already visited types + depth: The current recursion depth + + Returns: + The converted type + """ + from pydantic import Field, create_model + + visited = visited if visited is not None else {} + if type_ in visited: + return visited[type_] + elif depth >= _MAX_TYPED_DICT_RECURSION: + return type_ + elif is_typeddict(type_): + typed_dict = type_ + docstring = inspect.getdoc(typed_dict) + annotations_ = typed_dict.__annotations__ + fields: dict = {} + for arg, arg_type in annotations_.items(): + if get_origin(arg_type) is Annotated: + annotated_args = get_args(arg_type) + new_arg_type = _convert_any_typed_dicts_to_pydantic( + annotated_args[0], depth=depth + 1, visited=visited + ) + field_kwargs = dict(zip(("default", "description"), annotated_args[1:])) + if (field_desc := field_kwargs.get("description")) and not isinstance( + field_desc, str + ): + raise ValueError( + f"Invalid annotation for field {arg}. Third argument to " + f"Annotated must be a string description, received value of " + f"type {type(field_desc)}." + ) + else: + pass + fields[arg] = (new_arg_type, Field(**field_kwargs)) + else: + new_arg_type = _convert_any_typed_dicts_to_pydantic( + arg_type, depth=depth + 1, visited=visited + ) + field_kwargs = {"default": ...} + fields[arg] = (new_arg_type, Field(**field_kwargs)) + model = create_model(typed_dict.__name__, **fields) + model.__doc__ = docstring or "" + visited[typed_dict] = model + return model + elif (origin := get_origin(type_)) and (type_args := get_args(type_)): + type_args = tuple( + _convert_any_typed_dicts_to_pydantic(arg, depth=depth + 1, visited=visited) + for arg in type_args # type: ignore[index] + ) + return origin[type_args] # type: ignore[index] + else: + return type_ \ No newline at end of file diff --git a/trustcall/types.py b/trustcall/types.py new file mode 100644 index 0000000..853591f --- /dev/null +++ b/trustcall/types.py @@ -0,0 +1,92 @@ +"""Type definitions for the trustcall package.""" + +from __future__ import annotations + +from typing import ( + Any, + Dict, + List, + Literal, + Optional, + Sequence, + Union, +) + +from langchain_core.messages import ( + AnyMessage, + MessageLikeRepresentation, +) +from langchain_core.prompt_values import PromptValue +from typing_extensions import TypedDict + + +class SchemaInstance(tuple): + """Represents an instance of a schema with its associated metadata. + + This named tuple is used to store information about a specific schema instance, + including its unique identifier, the name of the schema it conforms to, + and the actual data of the record. + + Attributes: + record_id (str): A unique identifier for this schema instance. + schema_name (str): The name of the schema that this instance conforms to. + record (dict[str, Any]): The actual data of the record, stored as a dictionary. + """ + + record_id: str + schema_name: str | Literal["__any__"] + record: Dict[str, Any] + + def __new__(cls, record_id, schema_name, record): + return tuple.__new__(cls, (record_id, schema_name, record)) + + @property + def record_id(self) -> str: + return self[0] + + @property + def schema_name(self) -> str | Literal["__any__"]: + return self[1] + + @property + def record(self) -> Dict[str, Any]: + return self[2] + + +ExistingType = Union[ + Dict[str, Any], List[SchemaInstance], List[tuple[str, str, dict[str, Any]]] +] +"""Type for existing schemas. + +Can be one of: +- Dict[str, Any]: A dictionary mapping schema names to schema instances. +- List[SchemaInstance]: A list of SchemaInstance named tuples. +- List[tuple[str, str, dict[str, Any]]]: A list of tuples containing + (record_id, schema_name, record_dict). + +This type allows for flexibility in representing existing schemas, +supporting both single and multiple instances of each schema type. +""" + + +class ExtractionInputs(TypedDict, total=False): + messages: Union[Union[MessageLikeRepresentation, Sequence[MessageLikeRepresentation]], PromptValue] + existing: Optional[ExistingType] + """Existing schemas. Key is the schema name, value is the schema instance. + If a list, supports duplicate schemas to update. + """ + + +InputsLike = Union[ExtractionInputs, List[AnyMessage], PromptValue, str] + + +class ExtractionOutputs(TypedDict): + messages: List[Any] # AIMessage + responses: List[Any] # BaseModel + response_metadata: List[dict[str, Any]] + attempts: int + + +Message = Union[AnyMessage, MessageLikeRepresentation] + +Messages = Union[MessageLikeRepresentation, Sequence[MessageLikeRepresentation]] \ No newline at end of file diff --git a/trustcall/utils.py b/trustcall/utils.py new file mode 100644 index 0000000..e46c62c --- /dev/null +++ b/trustcall/utils.py @@ -0,0 +1,136 @@ +"""Utility functions for the trustcall package.""" + +from __future__ import annotations + +import functools +import inspect +import json +import logging +from typing import ( + Any, + Callable, + Dict, + List, + Type, + get_args, +) + +from langchain_core.language_models import BaseChatModel +from langchain_core.messages import ( + AIMessage, + AnyMessage, + BaseMessage, + MessageLikeRepresentation, + ToolMessage, +) +from langchain_core.prompt_values import PromptValue +from langchain_core.tools import InjectedToolArg + +logger = logging.getLogger("extraction") + + +def is_gemini_model(llm: BaseChatModel) -> bool: + """Determine if the provided LLM is a Google Vertex AI Gemini model.""" + # Check based on class module path + if hasattr(llm, "__class__") and hasattr(llm.__class__, "__module__"): + module_path = llm.__class__.__module__.lower() + is_gemini_by_module = any(term in module_path for term in ["vertex", "google", "gemini"]) + if is_gemini_by_module: + return True + + # Check based on model name, if available + model_name = getattr(llm, "model_name", "") or "" + is_gemini_by_name = isinstance(model_name, str) and "gemini" in model_name.lower() + if is_gemini_by_name: + return True + + return False + + +def _exclude_none(d: Dict[str, Any]) -> Dict[str, Any]: + """Remove None values from a dictionary recursively.""" + return { + k: v if not isinstance(v, dict) else _exclude_none(v) + for k, v in d.items() + if v is not None + } + + + +def _is_injected_arg_type(type_: Type) -> bool: + """Check if a type is an injected argument type.""" + return any( + isinstance(arg, InjectedToolArg) + or (isinstance(arg, type) and issubclass(arg, InjectedToolArg)) + for arg in get_args(type_)[1:] + ) + + +def _curry(func: Callable, **fixed_kwargs: Any) -> Callable: + """Bind parameters to a function, removing those parameters from the signature. + + Useful for exposing a narrower interface than what the the original function + provides. + """ + + @functools.wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + new_kwargs = {**fixed_kwargs, **kwargs} + return func(*args, **new_kwargs) + + sig = inspect.signature(func) + # Check that fixed_kwargs are all valid parameters of the function + invalid_kwargs = set(fixed_kwargs) - set(sig.parameters) + if invalid_kwargs: + raise ValueError(f"Invalid parameters: {invalid_kwargs}") + + new_params = [p for name, p in sig.parameters.items() if name not in fixed_kwargs] + wrapper.__signature__ = sig.replace(parameters=new_params) # type: ignore + return wrapper + + +def _strip_injected(fn: Callable) -> Callable: + """Strip injected arguments from a function's signature.""" + injected = [ + p.name + for p in inspect.signature(fn).parameters.values() + if _is_injected_arg_type(p.annotation) + ] + return _curry(fn, **{k: None for k in injected}) + + +def _try_parse_json_value(value): + """Try to parse a string value as JSON if it looks like JSON.""" + if isinstance(value, str) and (value.startswith('{') or value.startswith('[')): + try: + return json.loads(value) + except json.JSONDecodeError: + pass + return value + + +def _get_history_for_tool_call(messages: List[AnyMessage], tool_call_id: str): + """Get the history of messages related to a specific tool call.""" + results = [] + seen_ai_message = False + for m in reversed(messages): + if isinstance(m, AIMessage): + if not seen_ai_message: + tool_calls = [tc for tc in m.tool_calls if tc["id"] == tool_call_id] + if hasattr(m, "model_dump"): + d = m.model_dump(exclude={"tool_calls", "content"}) + else: + d = m.dict(exclude={"tool_calls", "content"}) + m = AIMessage( + **d, + # Frequently have partial_json blocks that are + # invalid if sent back to the API + content=str(m.content), + tool_calls=tool_calls, + ) + seen_ai_message = True + if isinstance(m, ToolMessage): + if m.tool_call_id != tool_call_id and not seen_ai_message: + continue + results.append(m) + return list(reversed(results)) \ No newline at end of file diff --git a/trustcall/validation.py b/trustcall/validation.py new file mode 100644 index 0000000..d65ef6d --- /dev/null +++ b/trustcall/validation.py @@ -0,0 +1,97 @@ +"""Validation-related functionality for the trustcall package.""" + +from __future__ import annotations + +import logging +from typing import Any, cast + +from langchain_core.messages import ( + AIMessage, + AnyMessage, + ToolCall, + ToolMessage, +) +from langchain_core.runnables import RunnableConfig +from langgraph.prebuilt.tool_validator import ValidationNode, get_executor_for_config +from dataclasses import asdict + +logger = logging.getLogger("extraction") + + +class _ExtendedValidationNode(ValidationNode): + """Extended validation node with support for deletion.""" + + def __init__(self, *args, enable_deletes: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.enable_deletes = enable_deletes + + def _func(self, input: Any, config: RunnableConfig) -> Any: # type: ignore + """Validate and run tool calls synchronously.""" + output_type, message = self._get_message(asdict(input)) + removal_schema = None + if self.enable_deletes and hasattr(input, "existing") and input.existing: + from trustcall.schema import _create_remove_doc_from_existing + removal_schema = _create_remove_doc_from_existing(input.existing) + + # ADDED: Get the current attempt count from the state + attempt_count = input.attempts if hasattr(input, 'attempts') else 1 + logger.debug(f"Current validation attempt: {attempt_count}") + + def run_one(call: ToolCall): # type: ignore + logger.debug(f"Validating tool call: {call['name']} with args: {call['args']}") + try: + if removal_schema and call["name"] == removal_schema.__name__: + schema = removal_schema + logger.debug(f"Using removal schema: {removal_schema.__name__}") + else: + schema = self.schemas_by_name[call["name"]] + logger.debug(f"Using schema: {call['name']}") + + try: + # ADDED: Create validation context with attempt count + validation_context = {"attempt_count": attempt_count} + logger.debug(f"Created validation context: {validation_context}") + + # MODIFIED: Pass context to model_validate + output = schema.model_validate(call["args"], context=validation_context) + logger.debug(f"Validation successful: {output}") + # output = schema.model_validate(call["args"]) + return ToolMessage( + content=output.model_dump_json(), + name=call["name"], + tool_call_id=cast(str, call["id"]), + ) + except Exception as validation_error: + # Add detailed logging about validation failures + logger.debug(f"Validation error in schema {call['name']}: {str(validation_error)}") + raise validation_error + + except KeyError: + valid_names = ", ".join(self.schemas_by_name.keys()) + logger.debug(f"Unknown tool name: {call['name']}. Valid names: {valid_names}") + return ToolMessage( + content=f'Unrecognized tool name: "{call["name"]}". You only have' + f" access to the following tools: {valid_names}." + " Please call PatchFunctionName with the *correct* tool name" + f" to fix json_doc_id=[{call['id']}].", + name=call["name"], + tool_call_id=cast(str, call["id"]), + status="error", + ) + except Exception as e: + logger.debug(f"Exception during validation: {type(e).__name__}: {str(e)}") + error_message = self._format_error(e, call, schema) + logger.debug(f"Formatted error message: {error_message}") + return ToolMessage( + content=error_message, + name=call["name"], + tool_call_id=cast(str, call["id"]), + status="error", + ) + + with get_executor_for_config(config) as executor: + outputs = [*executor.map(run_one, message.tool_calls)] + if output_type == "list": + return outputs + else: + return {"messages": outputs} \ No newline at end of file From 8422f1225f77d2373b94b75b0bef88cbd2297db8 Mon Sep 17 00:00:00 2001 From: Carl Gabel Date: Wed, 5 Mar 2025 23:53:28 +0000 Subject: [PATCH 2/4] Removed duplicate value definition in GeminiJsonPatch --- trustcall/schema.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/trustcall/schema.py b/trustcall/schema.py index 0b75d64..d23303d 100644 --- a/trustcall/schema.py +++ b/trustcall/schema.py @@ -226,13 +226,6 @@ class GeminiJsonPatch(BasePatch): Note that the op and path are ALWAYS required. Value is required for ALL operations except 'remove'. This supports Gemini with it's more limited JSON compatibility. """ # noqa - # Similar to JsonPatch but with Gemini-compatible schema definition - # Instead of using a string-only value, use Union types that match Gemini's schema - value: Optional[Union[str, int, float, bool, List, Dict]] = Field( - default=None, - description="The value to be used within the operation. Required for" - " 'add' and 'replace' operations, not needed for 'remove'." - ) # For Gemini, we'll use a string value but with clear documentation that it can be complex value: Optional[str] = Field( @@ -276,7 +269,8 @@ def validate_value(cls, v, info): }, "value": { "type": "STRING", - "description": "Value to use in the operation. For complex values, use JSON strings." + "description": "The value to be used within the operation. For complex values (objects, arrays), " + "provide valid JSON as a string. Required for 'add' and 'replace' operations." } }, "required": ["op", "path"] From 0f9b54bcd5a8832312096c104c438851f480e062 Mon Sep 17 00:00:00 2001 From: Carl Gabel Date: Thu, 6 Mar 2025 00:05:42 +0000 Subject: [PATCH 3/4] Removed some excess debug logging. --- trustcall/validation.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/trustcall/validation.py b/trustcall/validation.py index d65ef6d..7218dad 100644 --- a/trustcall/validation.py +++ b/trustcall/validation.py @@ -35,40 +35,28 @@ def _func(self, input: Any, config: RunnableConfig) -> Any: # type: ignore # ADDED: Get the current attempt count from the state attempt_count = input.attempts if hasattr(input, 'attempts') else 1 - logger.debug(f"Current validation attempt: {attempt_count}") def run_one(call: ToolCall): # type: ignore - logger.debug(f"Validating tool call: {call['name']} with args: {call['args']}") try: if removal_schema and call["name"] == removal_schema.__name__: schema = removal_schema - logger.debug(f"Using removal schema: {removal_schema.__name__}") else: schema = self.schemas_by_name[call["name"]] - logger.debug(f"Using schema: {call['name']}") - try: # ADDED: Create validation context with attempt count validation_context = {"attempt_count": attempt_count} - logger.debug(f"Created validation context: {validation_context}") - # MODIFIED: Pass context to model_validate output = schema.model_validate(call["args"], context=validation_context) - logger.debug(f"Validation successful: {output}") - # output = schema.model_validate(call["args"]) return ToolMessage( content=output.model_dump_json(), name=call["name"], tool_call_id=cast(str, call["id"]), ) except Exception as validation_error: - # Add detailed logging about validation failures - logger.debug(f"Validation error in schema {call['name']}: {str(validation_error)}") raise validation_error except KeyError: valid_names = ", ".join(self.schemas_by_name.keys()) - logger.debug(f"Unknown tool name: {call['name']}. Valid names: {valid_names}") return ToolMessage( content=f'Unrecognized tool name: "{call["name"]}". You only have' f" access to the following tools: {valid_names}." @@ -79,9 +67,7 @@ def run_one(call: ToolCall): # type: ignore status="error", ) except Exception as e: - logger.debug(f"Exception during validation: {type(e).__name__}: {str(e)}") error_message = self._format_error(e, call, schema) - logger.debug(f"Formatted error message: {error_message}") return ToolMessage( content=error_message, name=call["name"], From 2bba6653f406d27b6917c25ac10f596b5f2dac5f Mon Sep 17 00:00:00 2001 From: Carl Gabel Date: Mon, 10 Mar 2025 03:42:05 +0000 Subject: [PATCH 4/4] Handling for rare "'ExtractionState' object has no attribute 'tool_call_id'" error. --- trustcall/extract.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/trustcall/extract.py b/trustcall/extract.py index 5e1124d..bf056e0 100644 --- a/trustcall/extract.py +++ b/trustcall/extract.py @@ -755,27 +755,35 @@ def validate_or_retry( builder.add_edge("extract", "validate") builder.add_conditional_edges("extract_updates", validate_or_retry) - def handle_retries( - state: ExtractionState, config: RunnableConfig - ) -> Union[Literal["__end__"], list]: + def handle_retries(state: ExtractionState, config: RunnableConfig) -> Union[Literal["__end__"], list]: """After validation, decide whether to retry or end the process.""" max_attempts = config["configurable"].get("max_attempts", DEFAULT_MAX_ATTEMPTS) if state.attempts >= max_attempts: return "__end__" # Only continue if we need to patch the tool call to_send = [] - # We only increment the attempt count once, regardless of the fan-out - # degree. bumped = False + + # Add defensive check - ensure there's at least one AIMessage in history + has_ai_message = any(isinstance(m, AIMessage) for m in state.messages) + if not has_ai_message: + logger.warning("No AIMessage found in state.messages, ending processing") + return "__end__" + for m in reversed(state.messages): if isinstance(m, AIMessage): break if isinstance(m, ToolMessage): if m.status == "error": - # Each fallback will fix at most 1 schema per time. messages_for_fixing = _get_history_for_tool_call( state.messages, m.tool_call_id ) + + # Ensure tool_call_id is properly set + if not hasattr(m, "tool_call_id") or not m.tool_call_id: + logger.warning(f"Missing tool_call_id on message {m}, skipping") + continue + to_send.append( Send( "patch", @@ -791,6 +799,10 @@ def handle_retries( ) bumped = True else: + # Safe deletion handling + if not hasattr(m, "id") or not m.id: + logger.warning(f"Missing id on message {m}, skipping deletion") + continue # We want to delete the validation tool calls # anyway to avoid mixing branches during fan-in to_send.append(