Skip to content
4 changes: 3 additions & 1 deletion libs/langchain/langchain/agents/openai_assistant/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,7 @@ def invoke(
run_manager.on_chain_error(e)
raise
try:
# Use sync response handler in sync invoke
response = self._get_response(run)
except BaseException as e:
run_manager.on_chain_error(e, metadata=run.dict())
Expand Down Expand Up @@ -496,7 +497,8 @@ async def ainvoke(
run_manager.on_chain_error(e)
raise
try:
response = self._get_response(run)
# Use async response handler in async ainvoke
response = await self._aget_response(run)
except BaseException as e:
run_manager.on_chain_error(e, metadata=run.dict())
raise
Expand Down
78 changes: 77 additions & 1 deletion libs/langchain/tests/unit_tests/agents/test_openai_assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def _create_mock_client(*_: Any, use_async: bool = False, **__: Any) -> Any:

@pytest.mark.requires("openai")
def test_user_supplied_client() -> None:
import openai
openai = pytest.importorskip("openai")

client = openai.AzureOpenAI(
azure_endpoint="azure_endpoint",
Expand Down Expand Up @@ -48,6 +48,82 @@ def test_create_assistant() -> None:
assert isinstance(assistant, OpenAIAssistantRunnable)


@pytest.mark.requires("openai")
@patch(
"langchain.agents.openai_assistant.base._get_openai_async_client",
new=partial(_create_mock_client, use_async=True),
)
async def test_ainvoke_uses_async_response_completed() -> None:
# Arrange a runner with mocked async client and a completed run
assistant = OpenAIAssistantRunnable(
assistant_id="assistant_id",
client=_create_mock_client(),
async_client=_create_mock_client(use_async=True),
as_agent=False,
)
mock_run = MagicMock()
mock_run.id = "run-id"
mock_run.thread_id = "thread-id"
mock_run.status = "completed"

# await_for_run returns a completed run
await_for_run_mock = AsyncMock(return_value=mock_run)
# async messages list returns messages belonging to run
msg = MagicMock()
msg.run_id = "run-id"
msg.content = []
list_mock = AsyncMock(return_value=[msg])

with patch.object(assistant, "_await_for_run", await_for_run_mock), patch.object(
assistant.async_client.beta.threads.messages,
"list",
list_mock,
):
# Act
result = await assistant.ainvoke({"content": "hi"})

# Assert: returns messages list (non-agent path) and did not block
assert isinstance(result, list)
list_mock.assert_awaited()


@pytest.mark.requires("openai")
@patch(
"langchain.agents.openai_assistant.base._get_openai_async_client",
new=partial(_create_mock_client, use_async=True),
)
async def test_ainvoke_uses_async_response_requires_action_agent() -> None:
# Arrange a runner with mocked async client and requires_action run
assistant = OpenAIAssistantRunnable(
assistant_id="assistant_id",
client=_create_mock_client(),
async_client=_create_mock_client(use_async=True),
as_agent=True,
)
mock_run = MagicMock()
mock_run.id = "run-id"
mock_run.thread_id = "thread-id"
mock_run.status = "requires_action"

# Fake tool call structure
tool_call = MagicMock()
tool_call.id = "tool-id"
tool_call.function.name = "foo"
tool_call.function.arguments = '{\n "x": 1\n}'
mock_run.required_action.submit_tool_outputs.tool_calls = [tool_call]

await_for_run_mock = AsyncMock(return_value=mock_run)

# Act
with patch.object(assistant, "_await_for_run", await_for_run_mock):
result = await assistant.ainvoke({"content": "hi"})

# Assert: returns list of OpenAIAssistantAction
assert isinstance(result, list)
assert result
assert getattr(result[0], "tool", None) == "foo"


@pytest.mark.requires("openai")
@patch(
"langchain.agents.openai_assistant.base._get_openai_async_client",
Expand Down
Loading