Skip to content
This repository was archived by the owner on Jul 22, 2025. It is now read-only.

Commit 564d2de

Browse files
authored
FEATURE: Add native Cohere tool support (#655)
Add native Cohere tool support:
- Introduce CohereTools class for tool translation and result processing
- Update Command dialect to integrate with CohereTools
- Modify Cohere endpoint to support passing tools and processing tool calls
- Add spec for testing tool triggering with Cohere endpoint
1 parent 97afda2 commit 564d2de

File tree

5 files changed

+271
-27
lines changed

5 files changed

+271
-27
lines changed

lib/ai_bot/bot.rb

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ def reply(context, &update_blk)
6363
llm_kwargs[:temperature] = persona.temperature if persona.temperature
6464
llm_kwargs[:top_p] = persona.top_p if persona.top_p
6565

66+
needs_newlines = false
67+
6668
while total_completions <= MAX_COMPLETIONS && ongoing_chain
6769
tool_found = false
6870

@@ -72,11 +74,18 @@ def reply(context, &update_blk)
7274

7375
if (tools.present?)
7476
tool_found = true
77+
# a bit hacky, but extra newlines do no harm
78+
if needs_newlines
79+
update_blk.call("\n\n", cancel, nil)
80+
needs_newlines = false
81+
end
82+
7583
tools[0..MAX_TOOLS].each do |tool|
7684
process_tool(tool, raw_context, llm, cancel, update_blk, prompt, context)
7785
ongoing_chain &&= tool.chain_next_response?
7886
end
7987
else
88+
needs_newlines = true
8089
update_blk.call(partial, cancel, nil)
8190
end
8291
end
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# frozen_string_literal: true
2+
3+
module DiscourseAi
  module Completions
    module Dialects
      # Converts tool definitions and tool call/result messages into the hash
      # structures this dialect sends to Cohere (`tools` / `tool_results`).
      class CohereTools
        # @param tools [Array<Hash>, nil] tool signatures with :name,
        #   :description and an optional :parameters array of hashes.
        def initialize(tools)
          @raw_tools = tools
        end

        # Pairs each :tool_call message with the :tool message that directly
        # follows it and converts every pair into a Cohere tool result.
        #
        # @param messages [Array<Hash>] messages carrying :type, :name, :id and
        #   a JSON string in :content.
        # @return [Array<Hash>] one `{ call:, outputs: }` hash per pair.
        # @raise [JSON::ParserError] when a paired message's :content is not JSON.
        def tool_results(messages)
          pending_call = nil

          pairs =
            messages.each_with_object([]) do |msg, acc|
              case msg[:type]
              when :tool_call
                # A newer call always replaces an unanswered one, so an orphan
                # call cannot swallow the valid call/result pair that follows.
                pending_call = msg
              when :tool
                acc << [pending_call, msg] if pending_call
                pending_call = nil
              else
                pending_call = nil
              end
            end

          pairs.map do |call, result|
            {
              call: {
                name: cohere_name(call[:name]),
                parameters: JSON.parse(call[:content])["arguments"],
                generation_id: call[:id],
              },
              outputs: [JSON.parse(result[:content])],
            }
          end
        end

        # Builds the tool list, turning each tool's :parameters array into a
        # `parameter_definitions` hash keyed by parameter name.
        #
        # @return [Array<Hash>] empty when no tools were supplied.
        def translated_tools
          raw_tools.to_a.map do |tool|
            definitions =
              tool[:parameters].to_a.each_with_object({}) do |param, memo|
                definition = {
                  description: param[:description],
                  type: cohere_type(param[:type], param[:item_type]),
                  required: param[:required],
                }

                # Compare against nil so an explicit `false` default is kept.
                definition[:default] = param[:default] unless param[:default].nil?
                memo[param[:name]] = definition
              end

            {
              name: cohere_name(tool[:name]),
              description: tool[:description],
              parameter_definitions: definitions,
            }
          end
        end

        # @return [String] always empty; tools are sent separately from the prompt text.
        def instructions
          "" # Noop. Tools are listed separate.
        end

        private

        attr_reader :raw_tools

        # The "search" tool is exposed to Cohere as "search_local".
        def cohere_name(name)
          name == "search" ? "search_local" : name
        end

        # Maps a parameter type name (String or Symbol) to the Python-style
        # type name used in parameter definitions. Unknown types are returned
        # unchanged.
        def cohere_type(type, item_type)
          case type.to_s
          when "string"
            "str"
          when "integer"
            "int"
          when "number"
            item_type.to_s == "integer" ? "int" : "float"
          when "boolean"
            "bool"
          when "object"
            item_type ? "Dict[#{item_type}]" : "Dict"
          when "array"
            item_type ? "List[#{item_type}]" : "List"
          else
            type
          end
        end
      end
    end
  end
end

lib/completions/dialects/command.rb

Lines changed: 46 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,43 @@ def translate
2424
system_message = messages.shift[:message] if messages.first[:role] == "SYSTEM"
2525

2626
prompt = { preamble: +"#{system_message}" }
27-
prompt[:chat_history] = messages if messages.present?
2827

29-
messages.reverse_each do |msg|
30-
if msg[:role] == "USER"
31-
prompt[:message] = msg[:message]
32-
messages.delete(msg)
33-
break
28+
if messages.present?
29+
with_mapped_tools = []
30+
31+
current_pair = nil
32+
messages.each do |msg|
33+
if current_pair == nil && msg[:type] == :tool_call
34+
current_pair = [msg]
35+
elsif current_pair && msg[:type] == :tool
36+
current_pair << msg
37+
tool_results = tools_dialect.tool_results(current_pair)
38+
with_mapped_tools << { role: "TOOL", message: "", tool_results: tool_results }
39+
current_pair = nil
40+
else
41+
with_mapped_tools << msg
42+
current_pair = nil
43+
end
44+
end
45+
46+
messages = with_mapped_tools
47+
prompt[:chat_history] = messages
48+
end
49+
50+
tools = tools_dialect.translated_tools
51+
prompt[:tools] = tools if tools.present?
52+
53+
tool_results =
54+
messages.last && messages.last[:role] == "TOOL" && messages.last[:tool_results]
55+
prompt[:tool_results] = tool_results if tool_results.present?
56+
57+
if tool_results.blank?
58+
messages.reverse_each do |msg|
59+
if msg[:role] == "USER"
60+
prompt[:message] = msg[:message]
61+
messages.delete(msg)
62+
break
63+
end
3464
end
3565
end
3666

@@ -54,8 +84,16 @@ def max_prompt_tokens
5484
end
5585
end
5686

87+
# Always returns true: this dialect reports native tool support.
def native_tool_support?
88+
true
89+
end
90+
5791
private
5892

93+
# Memoized CohereTools instance built from the prompt's tools.
def tools_dialect
94+
@tools_dialect ||= DiscourseAi::Completions::Dialects::CohereTools.new(prompt.tools)
95+
end
96+
5997
def per_message_overhead
6098
0
6199
end
@@ -83,11 +121,11 @@ def model_msg(msg)
83121
end
84122

85123
def tool_call_msg(msg)
86-
{ role: "CHATBOT", message: tools_dialect.from_raw_tool_call(msg) }
124+
msg
87125
end
88126

89127
def tool_msg(msg)
90-
{ role: "USER", message: tools_dialect.from_raw_tool(msg) }
128+
msg
91129
end
92130

93131
def user_msg(msg)

lib/completions/endpoints/cohere.rb

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,7 @@ def normalize_model_params(model_params)
2929
end
3030

3131
def default_options(dialect)
32-
options = { model: "command-r-plus" }
33-
34-
options[:stop_sequences] = ["</function_calls>"] if dialect.prompt.has_tools?
35-
options
32+
{ model: "command-r-plus" }
3633
end
3734

3835
def provider_id
@@ -49,7 +46,11 @@ def model_uri
4946

5047
def prepare_payload(prompt, model_params, dialect)
5148
payload = default_options(dialect).merge(model_params).merge(prompt)
52-
49+
if prompt[:tools].present?
50+
payload[:tools] = prompt[:tools]
51+
payload[:force_single_step] = false
52+
end
53+
payload[:tool_results] = prompt[:tool_results] if prompt[:tool_results].present?
5354
payload[:stream] = true if @streaming_mode
5455

5556
payload
@@ -70,6 +71,14 @@ def extract_completion_from(response_raw)
7071
if @streaming_mode
7172
if parsed[:event_type] == "text-generation"
7273
parsed[:text]
74+
elsif parsed[:event_type] == "tool-calls-generation"
75+
# could just be random thinking...
76+
if parsed.dig(:tool_calls).present?
77+
@has_tool = true
78+
parsed.dig(:tool_calls).to_json
79+
else
80+
""
81+
end
7382
else
7483
if parsed[:event_type] == "stream-end"
7584
@input_tokens = parsed.dig(:response, :meta, :billed_units, :input_tokens)
@@ -84,6 +93,38 @@ def extract_completion_from(response_raw)
8493
end
8594
end
8695

96+
# Returns @has_tool, which extract_completion_from sets to true when a
# "tool-calls-generation" event carries tool calls. The argument is ignored.
def has_tool?(_ignored)
97+
@has_tool
98+
end
99+
100+
# Always returns true: this endpoint reports native tool support.
def native_tool_support?
101+
true
102+
end
103+
104+
# Writes the tool calls encoded in `partial` into `function_buffer` as
# <invoke> entries (tool name plus one child element per parameter).
#
# @param function_buffer a Nokogiri node containing a <function_calls>
#   element (presumably prepared by the caller; confirm against the base endpoint)
# @param partial [String, nil] JSON array of tool calls, each with "name" and
#   "parameters"; when nil the buffer is returned untouched
# @param payload unused by this implementation
# @return the same function_buffer
def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
  return function_buffer unless partial

  JSON.parse(partial).each do |tool|
    name = tool["name"]
    # Guard against a tool call that carries no "parameters" key, which
    # would otherwise raise NoMethodError on nil.
    parameters = tool["parameters"] || {}
    # NOTE(review): keys and values are interpolated without escaping, so a
    # value containing markup is parsed as elements and not kept as text.
    xml_params = parameters.map { |k, v| "<#{k}>#{v}</#{k}>\n" }.join

    # Reuse the first <invoke> while its tool name is still blank; otherwise
    # append a fresh placeholder invocation.
    current_function = function_buffer.at("invoke")
    if current_function.nil? || current_function.at("tool_name").content.present?
      current_function =
        function_buffer.at("function_calls").add_child(
          Nokogiri::HTML5::DocumentFragment.parse(noop_function_call_text + "\n"),
        )
    end

    # Cohere is given "search_local"; map it back to "search".
    current_function.at("tool_name").content = name == "search_local" ? "search" : name
    current_function.at("parameters").children =
      Nokogiri::HTML5::DocumentFragment.parse(xml_params)
  end

  function_buffer
end
127+
87128
def final_log_update(log)
88129
log.request_tokens = @input_tokens if @input_tokens
89130
log.response_tokens = @output_tokens if @output_tokens

spec/lib/completions/endpoints/cohere_spec.rb

Lines changed: 77 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,83 @@
5959

6060
before { SiteSetting.ai_cohere_api_key = "ABC" }
6161

62+
it "is able to trigger a tool" do
63+
body = (<<~TEXT).strip
64+
{"is_finished":false,"event_type":"stream-start","generation_id":"1648206e-1fe4-4bb6-90cf-360dd55f575b"}
65+
{"is_finished":false,"event_type":"tool-calls-generation","text":"I will search for 'who is sam saffron' and relay the information to the user.","tool_calls":[{"name":"google","parameters":{"query":"who is sam saffron"}}]}
66+
{"is_finished":true,"event_type":"stream-end","response":{"response_id":"71d8c9e1-1138-4d70-80d1-10ddec41c989","text":"I will search for 'who is sam saffron' and relay the information to the user.","generation_id":"1648206e-1fe4-4bb6-90cf-360dd55f575b","chat_history":[{"role":"USER","message":"sam: who is sam saffron?"},{"role":"CHATBOT","message":"I will search for 'who is sam saffron' and relay the information to the user.","tool_calls":[{"name":"google","parameters":{"query":"who is sam saffron"}}]}],"finish_reason":"COMPLETE","meta":{"api_version":{"version":"1"},"billed_units":{"input_tokens":460,"output_tokens":27},"tokens":{"input_tokens":1227,"output_tokens":27}},"tool_calls":[{"name":"google","parameters":{"query":"who is sam saffron"}}]},"finish_reason":"COMPLETE"}
67+
TEXT
68+
69+
parsed_body = nil
70+
result = +""
71+
72+
sig = {
73+
name: "google",
74+
description: "Will search using Google",
75+
parameters: [
76+
{ name: "query", description: "The search query", type: "string", required: true },
77+
],
78+
}
79+
80+
prompt.tools = [sig]
81+
82+
EndpointMock.with_chunk_array_support do
83+
stub_request(:post, "https://api.cohere.ai/v1/chat").with(
84+
body:
85+
proc do |req_body|
86+
parsed_body = JSON.parse(req_body, symbolize_names: true)
87+
true
88+
end,
89+
headers: {
90+
"Content-Type" => "application/json",
91+
"Authorization" => "Bearer ABC",
92+
},
93+
).to_return(status: 200, body: body.split("|"))
94+
95+
result = llm.generate(prompt, user: user) { |partial, cancel| result << partial }
96+
end
97+
98+
expected = <<~TEXT
99+
<function_calls>
100+
<invoke>
101+
<tool_name>google</tool_name>
102+
<parameters><query>who is sam saffron</query>
103+
</parameters>
104+
<tool_id>tool_0</tool_id>
105+
</invoke>
106+
</function_calls>
107+
TEXT
108+
109+
expect(result.strip).to eq(expected.strip)
110+
111+
expected = {
112+
model: "command-r-plus",
113+
preamble: "You are hello bot",
114+
chat_history: [
115+
{ role: "USER", message: "user1: hello" },
116+
{ role: "CHATBOT", message: "hi user" },
117+
],
118+
message: "user1: thanks",
119+
tools: [
120+
{
121+
name: "google",
122+
description: "Will search using Google",
123+
parameter_definitions: {
124+
query: {
125+
description: "The search query",
126+
type: "str",
127+
required: true,
128+
},
129+
},
130+
},
131+
],
132+
force_single_step: false,
133+
stream: true,
134+
}
135+
136+
expect(parsed_body).to eq(expected)
137+
end
138+
62139
it "is able to run tools" do
63140
body = {
64141
response_id: "0a90275b-273d-4690-abce-8018edcec7d0",
@@ -99,20 +176,6 @@
99176
result = llm.generate(prompt_with_tool_results, user: user)
100177

101178
expect(parsed_body[:preamble]).to include("You are weather bot")
102-
expect(parsed_body[:preamble]).to include("<tools>")
103-
104-
expected_message = <<~MESSAGE
105-
<function_results>
106-
<result>
107-
<tool_name>weather</tool_name>
108-
<json>
109-
{"weather":"22c"}
110-
</json>
111-
</result>
112-
</function_results>
113-
MESSAGE
114-
115-
expect(parsed_body[:message].strip).to eq(expected_message.strip)
116179

117180
expect(result).to eq("Sydney is 22c")
118181
audit = AiApiAuditLog.order("id desc").first

0 commit comments

Comments
 (0)