Skip to content

Commit 4b0c6d1

Browse files
authored
fix: image gen tool causing error (#5445)
1 parent da7dc33 commit 4b0c6d1

File tree

3 files changed

+131
-1
lines changed

3 files changed

+131
-1
lines changed

.github/workflows/pr-external-dependency-unit-tests.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ env:
2323

2424
# LLMs
2525
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
26+
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
2627

2728
jobs:
2829
discover-test-dirs:

backend/onyx/tools/tool_constructor.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,9 @@ def construct_tools(
193193
custom_tool_config: CustomToolConfig | None = None,
194194
allowed_tool_ids: list[int] | None = None,
195195
) -> dict[int, list[Tool]]:
196-
"""Constructs tools based on persona configuration and available APIs"""
196+
"""Constructs tools based on persona configuration and available APIs.
197+
198+
Will simply skip tools that are not allowed/available."""
197199
tool_dict: dict[int, list[Tool]] = {}
198200

199201
mcp_tool_cache: dict[int, dict[int, MCPTool]] = {}
@@ -210,6 +212,21 @@ def construct_tools(
210212
if db_tool_model.in_code_tool_id:
211213
tool_cls = get_built_in_tool_by_id(db_tool_model.in_code_tool_id)
212214

215+
try:
216+
tool_is_available = tool_cls.is_available(db_session)
217+
except Exception:
218+
logger.exception(
219+
"Failed checking availability for tool %s", tool_cls.__name__
220+
)
221+
tool_is_available = False
222+
223+
if not tool_is_available:
224+
logger.debug(
225+
"Skipping tool %s because it is not available",
226+
tool_cls.__name__,
227+
)
228+
continue
229+
213230
# Handle Search Tool
214231
if (
215232
tool_cls.__name__ == SearchTool.__name__
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
from __future__ import annotations
2+
3+
import os
4+
from uuid import uuid4
5+
6+
from sqlalchemy.orm import Session
7+
8+
from onyx.chat.models import AnswerStreamPart
9+
from onyx.chat.models import MessageResponseIDInfo
10+
from onyx.chat.models import StreamingError
11+
from onyx.chat.process_message import stream_chat_message_objects
12+
from onyx.context.search.models import RetrievalDetails
13+
from onyx.db.chat import create_chat_session
14+
from onyx.db.llm import fetch_existing_llm_providers
15+
from onyx.db.llm import remove_llm_provider
16+
from onyx.db.llm import update_default_provider
17+
from onyx.db.llm import upsert_llm_provider
18+
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
19+
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
20+
from onyx.server.query_and_chat.models import CreateChatMessageRequest
21+
from onyx.server.query_and_chat.streaming_models import MessageDelta
22+
from onyx.server.query_and_chat.streaming_models import MessageStart
23+
from onyx.server.query_and_chat.streaming_models import Packet
24+
from tests.external_dependency_unit.conftest import create_test_user
25+
26+
27+
def test_answer_with_only_anthropic_provider(
28+
db_session: Session,
29+
full_deployment_setup: None,
30+
mock_external_deps: None,
31+
) -> None:
32+
"""Ensure chat still streams answers when only an Anthropic provider is configured."""
33+
34+
anthropic_api_key = os.environ.get("ANTHROPIC_API_KEY")
35+
assert anthropic_api_key, "ANTHROPIC_API_KEY environment variable must be set"
36+
37+
# Drop any existing providers so that only Anthropic is available.
38+
for provider in fetch_existing_llm_providers(db_session):
39+
remove_llm_provider(db_session, provider.id)
40+
41+
anthropic_model = "claude-3-5-sonnet-20240620"
42+
provider_name = f"anthropic-test-{uuid4().hex}"
43+
44+
anthropic_provider = upsert_llm_provider(
45+
LLMProviderUpsertRequest(
46+
name=provider_name,
47+
provider="anthropic",
48+
api_key=anthropic_api_key,
49+
default_model_name=anthropic_model,
50+
fast_default_model_name=anthropic_model,
51+
is_public=True,
52+
groups=[],
53+
model_configurations=[
54+
ModelConfigurationUpsertRequest(name=anthropic_model, is_visible=True)
55+
],
56+
api_key_changed=True,
57+
),
58+
db_session=db_session,
59+
)
60+
61+
try:
62+
update_default_provider(anthropic_provider.id, db_session)
63+
64+
test_user = create_test_user(db_session, email_prefix="anthropic_only")
65+
chat_session = create_chat_session(
66+
db_session=db_session,
67+
description="Anthropic only chat",
68+
user_id=test_user.id,
69+
persona_id=0,
70+
)
71+
72+
chat_request = CreateChatMessageRequest(
73+
chat_session_id=chat_session.id,
74+
parent_message_id=None,
75+
message="hello",
76+
file_descriptors=[],
77+
search_doc_ids=None,
78+
retrieval_options=RetrievalDetails(),
79+
)
80+
81+
response_stream: list[AnswerStreamPart] = []
82+
for packet in stream_chat_message_objects(
83+
new_msg_req=chat_request,
84+
user=test_user,
85+
db_session=db_session,
86+
):
87+
response_stream.append(packet)
88+
89+
assert response_stream, "Should receive streamed packets"
90+
assert not any(
91+
isinstance(packet, StreamingError) for packet in response_stream
92+
), "No streaming errors expected with Anthropic provider"
93+
94+
has_message_id = any(
95+
isinstance(packet, MessageResponseIDInfo) for packet in response_stream
96+
)
97+
assert has_message_id, "Should include reserved assistant message ID"
98+
99+
has_message_start = any(
100+
isinstance(packet, Packet) and isinstance(packet.obj, MessageStart)
101+
for packet in response_stream
102+
)
103+
assert has_message_start, "Stream should have a MessageStart packet"
104+
105+
has_message_delta = any(
106+
isinstance(packet, Packet) and isinstance(packet.obj, MessageDelta)
107+
for packet in response_stream
108+
)
109+
assert has_message_delta, "Stream should have a MessageDelta packet"
110+
111+
finally:
112+
remove_llm_provider(db_session, anthropic_provider.id)

0 commit comments

Comments
 (0)