Skip to content

Commit e36e25f

Browse files
authored
feat(langchain): support PEP604 ( | union) in tool node error handlers (#32861)
This allows to use PEP604 syntax for `ToolNode` error handlers ```python def error_handler(e: ValueError | ToolException) -> str: return "error" ToolNode(my_tool, handle_tool_errors=error_handler).invoke(...) ``` Without this change, this fails with `AttributeError: 'types.UnionType' object has no attribute '__mro__'`
1 parent cc3b5af commit e36e25f

File tree

3 files changed

+16
-7
lines changed

3 files changed

+16
-7
lines changed

libs/langchain_v1/langchain/agents/tool_node.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def my_tool(x: int) -> str:
4040
import json
4141
from copy import copy, deepcopy
4242
from dataclasses import replace
43+
from types import UnionType
4344
from typing import (
4445
TYPE_CHECKING,
4546
Annotated,
@@ -246,7 +247,7 @@ def _infer_handled_types(handler: Callable[..., str]) -> tuple[type[Exception],
246247
type_hints = get_type_hints(handler)
247248
if first_param.name in type_hints:
248249
origin = get_origin(first_param.annotation)
249-
if origin is Union:
250+
if origin in [Union, UnionType]:
250251
args = get_args(first_param.annotation)
251252
if all(issubclass(arg, Exception) for arg in args):
252253
return tuple(args)

libs/langchain_v1/tests/unit_tests/agents/test_react_agent.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import dataclasses
22
import inspect
3+
from types import UnionType
34
from typing import (
45
Annotated,
56
Union,
@@ -343,16 +344,19 @@ def handle(e) -> str: # type: ignore
343344
def handle2(e: Exception) -> str:
344345
return ""
345346

346-
def handle3(e: Union[ValueError, ToolException]) -> str:
347+
def handle3(e: ValueError | ToolException) -> str:
348+
return ""
349+
350+
def handle4(e: Union[ValueError, ToolException]) -> str:
347351
return ""
348352

349353
class Handler:
350354
def handle(self, e: ValueError) -> str:
351355
return ""
352356

353-
handle4 = Handler().handle
357+
handle5 = Handler().handle
354358

355-
def handle5(e: Union[Union[TypeError, ValueError], ToolException]) -> str:
359+
def handle6(e: Union[Union[TypeError, ValueError], ToolException]) -> str:
356360
return ""
357361

358362
expected: tuple = (Exception,)
@@ -367,14 +371,18 @@ def handle5(e: Union[Union[TypeError, ValueError], ToolException]) -> str:
367371
actual = _infer_handled_types(handle3)
368372
assert expected == actual
369373

370-
expected = (ValueError,)
374+
expected = (ValueError, ToolException)
371375
actual = _infer_handled_types(handle4)
372376
assert expected == actual
373377

374-
expected = (TypeError, ValueError, ToolException)
378+
expected = (ValueError,)
375379
actual = _infer_handled_types(handle5)
376380
assert expected == actual
377381

382+
expected = (TypeError, ValueError, ToolException)
383+
actual = _infer_handled_types(handle6)
384+
assert expected == actual
385+
378386
with pytest.raises(ValueError):
379387

380388
def handler(e: str) -> str:

libs/langchain_v1/tests/unit_tests/agents/test_tool_node.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ def test_tool_node_error_handling_default_exception() -> None:
272272

273273

274274
async def test_tool_node_error_handling() -> None:
275-
def handle_all(e: Union[ValueError, ToolException, ToolInvocationError]):
275+
def handle_all(e: ValueError | ToolException | ToolInvocationError):
276276
return TOOL_CALL_ERROR_TEMPLATE.format(error=repr(e))
277277

278278
# test catching all exceptions, via:

0 commit comments

Comments
 (0)