Skip to content

Commit 213ef53

Browse files
hinthornwaojah1
andcommitted
Use a more general name (#31)
Fixed tool call to work with vLLM and OCI DataScience Model Deployment API Co-authored-by: Anup Ojah <aojah1@yahoo.com>
1 parent 0c27fe5 commit 213ef53

File tree

5 files changed

+14
-13
lines changed

5 files changed

+14
-13
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ dependencies = [
1010
"jsonpatch<2.0,>=1.33",
1111
]
1212
name = "trustcall"
13-
version = "0.0.32"
13+
version = "0.0.34"
1414
description = "Tenacious & trustworthy tool calling built on LangGraph."
1515
readme = "README.md"
1616

tests/unit_tests/test_extraction.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -422,14 +422,13 @@ def test_validate_existing(existing, tools, is_valid):
422422
extractor._validate_existing(existing)
423423

424424

425-
@pytest.mark.asyncio
426425
@pytest.mark.parametrize("strict_mode", [True, False, "ignore"])
427426
async def test_e2e_existing_schema_policy_behavior(strict_mode):
428427
class MyRecognizedSchema(BaseModel):
429428
"""A recognized schema that the pipeline can handle."""
430429

431-
user_id: str
432-
notes: str
430+
user_id: str # type: ignore
431+
notes: str # type: ignore
433432

434433
# Our existing data includes 2 top-level keys: recognized, unknown
435434
existing_schemas = {
@@ -537,14 +536,13 @@ class MyRecognizedSchema(BaseModel):
537536
assert recognized_item.notes == "updated notes"
538537

539538

540-
@pytest.mark.asyncio
541539
@pytest.mark.parametrize("strict_mode", [True, False, "ignore"])
542540
async def test_e2e_existing_schema_policy_tuple_behavior(strict_mode):
543541
class MyRecognizedSchema(BaseModel):
544542
"""A recognized schema that the pipeline can handle."""
545543

546-
user_id: str
547-
notes: str
544+
user_id: str # type: ignore
545+
notes: str # type: ignore
548546

549547
existing_schemas = [
550548
(
@@ -655,7 +653,6 @@ class MyRecognizedSchema(BaseModel):
655653
assert recognized_item.notes == "updated notes"
656654

657655

658-
@pytest.mark.asyncio
659656
@pytest.mark.parametrize("enable_inserts", [True, False])
660657
async def test_enable_deletes_flow(enable_inserts: bool) -> None:
661658
class MySchema(BaseModel):

tests/unit_tests/test_strict_existing.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
from unittest.mock import patch
23

34
import pytest
45
from langchain_openai import ChatOpenAI
@@ -260,7 +261,8 @@ def test_validate_existing_strictness(
260261
):
261262
"""Test various scenarios of validation."""
262263
tools = {"DummySchema": DummySchema}
263-
llm = ChatOpenAI(model="gpt-4o-mini")
264+
with patch.dict("os.environ", {"OPENAI_API_KEY": "fake-api-key"}):
265+
llm = ChatOpenAI(model="gpt-4o-mini")
264266
extractor = _ExtractUpdates(
265267
llm=llm, # We won't actually call the LLM here but we need it for parsing.
266268
tools=tools,

trustcall/_base.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -664,7 +664,7 @@ def __init__(
664664
existing_schema_policy: bool | Literal["ignore"] = True,
665665
):
666666
new_tools: list = [PatchDoc]
667-
tool_choice = PatchDoc.__name__ if not enable_deletes else "any"
667+
tool_choice = "PatchDoc" if not enable_deletes else "any"
668668
if enable_inserts: # Also let the LLM know that we can extract NEW schemas.
669669
tools_ = [
670670
schema
@@ -1052,7 +1052,7 @@ def _tear_down(
10521052

10531053
async def ainvoke(
10541054
self, state: ExtendedExtractState, config: RunnableConfig
1055-
) -> dict:
1055+
) -> Command[Literal["sync", "__end__"]]:
10561056
"""Generate a JSONPatch to correct the validation error and heal the tool call.
10571057
10581058
Assumptions:
@@ -1075,7 +1075,9 @@ async def ainvoke(
10751075
goto=("sync",),
10761076
)
10771077

1078-
def invoke(self, state: ExtendedExtractState, config: RunnableConfig) -> dict:
1078+
def invoke(
1079+
self, state: ExtendedExtractState, config: RunnableConfig
1080+
) -> Command[Literal["sync", "__end__"]]:
10791081
try:
10801082
msg = self.bound.invoke(state.messages, config)
10811083
except Exception:

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)