Skip to content

Commit 8b81ff4

Browse files
authored
FIX: switch off native tools on Anthropic Claude Opus (#659)
Native tools do not work well on Opus. Chain of Thought prompting means it consumes enormous amounts of tokens and has poor latency. This commit introduce and XML stripper to remove various chain of thought XML islands from anthropic prompts when tools are involved. This mean Opus native tools is now functions (albeit slowly) From local testing XML just works better now. Also fixes enum support in Anthropic native tools
1 parent 7a64699 commit 8b81ff4

File tree

13 files changed

+440
-67
lines changed

13 files changed

+440
-67
lines changed

config/locales/server.en.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ en:
5151
ai_openai_embeddings_url: "Custom URL used for the OpenAI embeddings API. (in the case of Azure it can be: https://COMPANY.openai.azure.com/openai/deployments/DEPLOYMENT/embeddings?api-version=2023-05-15)"
5252
ai_openai_api_key: "API key for OpenAI API"
5353
ai_anthropic_api_key: "API key for Anthropic API"
54+
ai_anthropic_native_tool_call_models: "List of models that will use native tool calls vs legacy XML based tools."
5455
ai_cohere_api_key: "API key for Cohere API"
5556
ai_hugging_face_api_url: "Custom URL used for OpenSource LLM inference. Compatible with https://github.com/huggingface/text-generation-inference"
5657
ai_hugging_face_api_key: API key for Hugging Face API

config/settings.yml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,15 @@ discourse_ai:
111111
ai_anthropic_api_key:
112112
default: ""
113113
secret: true
114+
ai_anthropic_native_tool_call_models:
115+
type: list
116+
list_type: compact
117+
default: "claude-3-sonnet|claude-3-haiku"
118+
allow_any: false
119+
choices:
120+
- claude-3-opus
121+
- claude-3-sonnet
122+
- claude-3-haiku
114123
ai_cohere_api_key:
115124
default: ""
116125
secret: true

lib/completions/dialects/claude.rb

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ def initialize(system_prompt, messages, tools)
2222
@messages = messages
2323
@tools = tools
2424
end
25+
26+
def has_tools?
27+
tools.present?
28+
end
2529
end
2630

2731
def tokenizer
@@ -33,6 +37,10 @@ def translate
3337

3438
system_prompt = messages.shift[:content] if messages.first[:role] == "system"
3539

40+
if !system_prompt && !native_tool_support?
41+
system_prompt = tools_dialect.instructions.presence
42+
end
43+
3644
interleving_messages = []
3745
previous_message = nil
3846

@@ -48,11 +56,10 @@ def translate
4856
previous_message = message
4957
end
5058

51-
ClaudePrompt.new(
52-
system_prompt.presence,
53-
interleving_messages,
54-
tools_dialect.translated_tools,
55-
)
59+
tools = nil
60+
tools = tools_dialect.translated_tools if native_tool_support?
61+
62+
ClaudePrompt.new(system_prompt.presence, interleving_messages, tools)
5663
end
5764

5865
def max_prompt_tokens
@@ -62,18 +69,28 @@ def max_prompt_tokens
6269
200_000 # Claude-3 has a 200k context window for now
6370
end
6471

72+
def native_tool_support?
73+
SiteSetting.ai_anthropic_native_tool_call_models_map.include?(model_name)
74+
end
75+
6576
private
6677

6778
def tools_dialect
68-
@tools_dialect ||= DiscourseAi::Completions::Dialects::ClaudeTools.new(prompt.tools)
79+
if native_tool_support?
80+
@tools_dialect ||= DiscourseAi::Completions::Dialects::ClaudeTools.new(prompt.tools)
81+
else
82+
super
83+
end
6984
end
7085

7186
def tool_call_msg(msg)
72-
tools_dialect.from_raw_tool_call(msg)
87+
translated = tools_dialect.from_raw_tool_call(msg)
88+
{ role: "assistant", content: translated }
7389
end
7490

7591
def tool_msg(msg)
76-
tools_dialect.from_raw_tool(msg)
92+
translated = tools_dialect.from_raw_tool(msg)
93+
{ role: "user", content: translated }
7794
end
7895

7996
def model_msg(msg)

lib/completions/dialects/claude_tools.rb

Lines changed: 18 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,14 @@ def translated_tools
1515
required = []
1616

1717
if t[:parameters]
18-
properties =
19-
t[:parameters].each_with_object({}) do |param, h|
20-
h[param[:name]] = {
21-
type: param[:type],
22-
description: param[:description],
23-
}.tap { |hash| hash[:items] = { type: param[:item_type] } if param[:item_type] }
24-
end
18+
properties = {}
19+
20+
t[:parameters].each do |param|
21+
mapped = { type: param[:type], description: param[:description] }
22+
mapped[:items] = { type: param[:item_type] } if param[:item_type]
23+
mapped[:enum] = param[:enum] if param[:enum]
24+
properties[param[:name]] = mapped
25+
end
2526
required =
2627
t[:parameters].select { |param| param[:required] }.map { |param| param[:name] }
2728
end
@@ -39,37 +40,24 @@ def translated_tools
3940
end
4041

4142
def instructions
42-
"" # Noop. Tools are listed separate.
43+
""
4344
end
4445

4546
def from_raw_tool_call(raw_message)
4647
call_details = JSON.parse(raw_message[:content], symbolize_names: true)
4748
tool_call_id = raw_message[:id]
48-
49-
{
50-
role: "assistant",
51-
content: [
52-
{
53-
type: "tool_use",
54-
id: tool_call_id,
55-
name: raw_message[:name],
56-
input: call_details[:arguments],
57-
},
58-
],
59-
}
49+
[
50+
{
51+
type: "tool_use",
52+
id: tool_call_id,
53+
name: raw_message[:name],
54+
input: call_details[:arguments],
55+
},
56+
]
6057
end
6158

6259
def from_raw_tool(raw_message)
63-
{
64-
role: "user",
65-
content: [
66-
{
67-
type: "tool_result",
68-
tool_use_id: raw_message[:id],
69-
content: raw_message[:content],
70-
},
71-
],
72-
}
60+
[{ type: "tool_result", tool_use_id: raw_message[:id], content: raw_message[:content] }]
7361
end
7462

7563
private

lib/completions/dialects/xml_tools.rb

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,17 @@ def translated_tools
4141
def instructions
4242
return "" if raw_tools.blank?
4343

44-
has_arrays = raw_tools.any? { |tool| tool[:parameters]&.any? { |p| p[:type] == "array" } }
45-
46-
(<<~TEXT).strip
47-
#{tool_preamble(include_array_tip: has_arrays)}
48-
<tools>
49-
#{translated_tools}</tools>
50-
TEXT
44+
@instructions ||=
45+
begin
46+
has_arrays =
47+
raw_tools.any? { |tool| tool[:parameters]&.any? { |p| p[:type] == "array" } }
48+
49+
(<<~TEXT).strip
50+
#{tool_preamble(include_array_tip: has_arrays)}
51+
<tools>
52+
#{translated_tools}</tools>
53+
TEXT
54+
end
5155
end
5256

5357
def from_raw_tool(raw_message)

lib/completions/endpoints/anthropic.rb

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,12 @@ def default_options(dialect)
4545
raise "Unsupported model: #{model}"
4646
end
4747

48-
{ model: mapped_model, max_tokens: 3_000 }
48+
options = { model: mapped_model, max_tokens: 3_000 }
49+
50+
options[:stop_sequences] = ["</function_calls>"] if !dialect.native_tool_support? &&
51+
dialect.prompt.has_tools?
52+
53+
options
4954
end
5055

5156
def provider_id
@@ -54,6 +59,14 @@ def provider_id
5459

5560
private
5661

62+
def xml_tags_to_strip(dialect)
63+
if dialect.prompt.has_tools?
64+
%w[thinking search_quality_reflection search_quality_score]
65+
else
66+
[]
67+
end
68+
end
69+
5770
# this is an approximation, we will update it later if request goes through
5871
def prompt_size(prompt)
5972
tokenizer.size(prompt.system_prompt.to_s + " " + prompt.messages.to_s)
@@ -66,11 +79,13 @@ def model_uri
6679
end
6780

6881
def prepare_payload(prompt, model_params, dialect)
82+
@native_tool_support = dialect.native_tool_support?
83+
6984
payload = default_options(dialect).merge(model_params).merge(messages: prompt.messages)
7085

7186
payload[:system] = prompt.system_prompt if prompt.system_prompt.present?
7287
payload[:stream] = true if @streaming_mode
73-
payload[:tools] = prompt.tools if prompt.tools.present?
88+
payload[:tools] = prompt.tools if prompt.has_tools?
7489

7590
payload
7691
end
@@ -108,7 +123,7 @@ def final_log_update(log)
108123
end
109124

110125
def native_tool_support?
111-
true
126+
@native_tool_support
112127
end
113128

114129
def partials_from(decoded_chunk)

lib/completions/endpoints/aws_bedrock.rb

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,24 @@ def normalize_model_params(model_params)
3636

3737
def default_options(dialect)
3838
options = { max_tokens: 3_000, anthropic_version: "bedrock-2023-05-31" }
39+
40+
options[:stop_sequences] = ["</function_calls>"] if !dialect.native_tool_support? &&
41+
dialect.prompt.has_tools?
3942
options
4043
end
4144

4245
def provider_id
4346
AiApiAuditLog::Provider::Anthropic
4447
end
4548

49+
def xml_tags_to_strip(dialect)
50+
if dialect.prompt.has_tools?
51+
%w[thinking search_quality_reflection search_quality_score]
52+
else
53+
[]
54+
end
55+
end
56+
4657
private
4758

4859
def prompt_size(prompt)
@@ -79,9 +90,11 @@ def model_uri
7990
end
8091

8192
def prepare_payload(prompt, model_params, dialect)
93+
@native_tool_support = dialect.native_tool_support?
94+
8295
payload = default_options(dialect).merge(model_params).merge(messages: prompt.messages)
8396
payload[:system] = prompt.system_prompt if prompt.system_prompt.present?
84-
payload[:tools] = prompt.tools if prompt.tools.present?
97+
payload[:tools] = prompt.tools if prompt.has_tools?
8598

8699
payload
87100
end
@@ -169,7 +182,7 @@ def partials_from(decoded_chunks)
169182
end
170183

171184
def native_tool_support?
172-
true
185+
@native_tool_support
173186
end
174187

175188
def chunk_to_string(chunk)

lib/completions/endpoints/base.rb

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,27 @@ def use_ssl?
7878
end
7979
end
8080

81+
def xml_tags_to_strip(dialect)
82+
[]
83+
end
84+
8185
def perform_completion!(dialect, user, model_params = {}, feature_name: nil, &blk)
8286
allow_tools = dialect.prompt.has_tools?
8387
model_params = normalize_model_params(model_params)
88+
orig_blk = blk
8489

8590
@streaming_mode = block_given?
91+
to_strip = xml_tags_to_strip(dialect)
92+
@xml_stripper =
93+
DiscourseAi::Completions::XmlTagStripper.new(to_strip) if to_strip.present?
94+
95+
if @streaming_mode && @xml_stripper
96+
blk =
97+
lambda do |partial, cancel|
98+
partial = @xml_stripper << partial
99+
orig_blk.call(partial, cancel) if partial
100+
end
101+
end
86102

87103
prompt = dialect.translate
88104

@@ -270,6 +286,11 @@ def perform_completion!(dialect, user, model_params = {}, feature_name: nil, &bl
270286
blk.call(function_calls, cancel)
271287
end
272288

289+
if @xml_stripper
290+
leftover = @xml_stripper.finish
291+
orig_blk.call(leftover, cancel) if leftover.present?
292+
end
293+
273294
return response_data
274295
ensure
275296
if log

0 commit comments

Comments
 (0)