Commit 7f16d3a

FEATURE: Cohere Command R support (#558)
- Added Cohere Command models (Command, Command Light, Command R, Command R Plus) to the available model list
- Added a new site setting `ai_cohere_api_key` for configuring the Cohere API key
- Implemented a new `DiscourseAi::Completions::Endpoints::Cohere` class to handle interactions with the Cohere API, including:
  - Translating request parameters to the Cohere API format
  - Parsing Cohere API responses
  - Supporting streaming and non-streaming completions
  - Supporting "tools", which allow the model to call back to Discourse to look up additional information
- Implemented a new `DiscourseAi::Completions::Dialects::Command` class to translate between the generic Discourse AI prompt format and the Cohere Command format
- Added specs covering the new Cohere endpoint and dialect classes
- Updated `DiscourseAi::AiBot::Bot.guess_model` to map the new Cohere model to the appropriate bot user

In summary, this PR adds support for using the Cohere Command family of models with the Discourse AI plugin. It handles configuring API keys, making requests to the Cohere API, and translating between Discourse's generic prompt format and Cohere's specific format. Thorough test coverage was added for the new functionality.
1 parent eb93b21 commit 7f16d3a
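The prompt translation described above can be sketched independently of the plugin. The following is a minimal standalone sketch, assuming a simplified `{ type:, content: }` message schema rather than the plugin's actual `Dialect` API; it shows how a generic conversation maps onto Cohere's `preamble` / `chat_history` / `message` shape:

```ruby
# Minimal sketch: translate a generic message list into Cohere's chat
# format. The input schema here is a simplified assumption, not the
# plugin's real Dialect interface.
def to_cohere_prompt(messages)
  prompt = { chat_history: [] }

  messages.each do |msg|
    case msg[:type]
    when :system
      prompt[:preamble] = msg[:content]
    when :model
      prompt[:chat_history] << { role: "CHATBOT", message: msg[:content] }
    when :user
      prompt[:chat_history] << { role: "USER", message: msg[:content] }
    end
  end

  # Cohere takes the latest user turn as :message, not as part of
  # chat_history, so pull the last USER entry out of the history.
  last_user = prompt[:chat_history].rindex { |m| m[:role] == "USER" }
  prompt[:message] = prompt[:chat_history].delete_at(last_user)[:message] if last_user

  prompt
end

messages = [
  { type: :system, content: "You are a helpful bot" },
  { type: :user, content: "Hello" },
  { type: :model, content: "Hi there" },
  { type: :user, content: "What is Discourse?" },
]

p to_cohere_prompt(messages)
```

The real `Command` dialect adds more on top of this (tool calls, trimming to the context window, a tools preamble), but the core reshaping is the same.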

File tree

11 files changed

+484
-0
lines changed


app/models/ai_api_audit_log.rb

Lines changed: 1 addition & 0 deletions

@@ -7,6 +7,7 @@ module Provider
     HuggingFaceTextGeneration = 3
     Gemini = 4
     Vllm = 5
+    Cohere = 6
   end
 end

config/locales/client.en.yml

Lines changed: 1 addition & 0 deletions

@@ -268,6 +268,7 @@ en:
       claude-3-opus: "Claude 3 Opus"
       claude-3-sonnet: "Claude 3 Sonnet"
       claude-3-haiku: "Claude 3 Haiku"
+      cohere-command-r-plus: "Cohere Command R Plus"
       gpt-4: "GPT-4"
       gpt-4-turbo: "GPT-4 Turbo"
       gpt-3:

config/locales/server.en.yml

Lines changed: 1 addition & 0 deletions

@@ -50,6 +50,7 @@ en:
     ai_openai_embeddings_url: "Custom URL used for the OpenAI embeddings API. (in the case of Azure it can be: https://COMPANY.openai.azure.com/openai/deployments/DEPLOYMENT/embeddings?api-version=2023-05-15)"
     ai_openai_api_key: "API key for OpenAI API"
     ai_anthropic_api_key: "API key for Anthropic API"
+    ai_cohere_api_key: "API key for Cohere API"
     ai_hugging_face_api_url: "Custom URL used for OpenSource LLM inference. Compatible with https://github.com/huggingface/text-generation-inference"
     ai_hugging_face_api_key: API key for Hugging Face API
     ai_hugging_face_token_limit: Max tokens Hugging Face API can use per request

config/settings.yml

Lines changed: 4 additions & 0 deletions

@@ -110,6 +110,9 @@ discourse_ai:
   ai_anthropic_api_key:
     default: ""
     secret: true
+  ai_cohere_api_key:
+    default: ""
+    secret: true
   ai_stability_api_key:
     default: ""
     secret: true

@@ -336,6 +339,7 @@ discourse_ai:
       - claude-3-opus
       - claude-3-sonnet
       - claude-3-haiku
+      - cohere-command-r-plus
   ai_bot_add_to_header:
     default: true
     client: true

lib/ai_bot/bot.rb

Lines changed: 2 additions & 0 deletions

@@ -180,6 +180,8 @@ def self.guess_model(bot_user)
     when DiscourseAi::AiBot::EntryPoint::CLAUDE_3_OPUS_ID
       # no bedrock support yet 18-03
       "anthropic:claude-3-opus"
+    when DiscourseAi::AiBot::EntryPoint::COHERE_COMMAND_R_PLUS
+      "cohere:command-r-plus"
     when DiscourseAi::AiBot::EntryPoint::CLAUDE_3_SONNET_ID
       if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?(
         "claude-3-sonnet",

lib/ai_bot/entry_point.rb

Lines changed: 4 additions & 0 deletions

@@ -17,6 +17,7 @@ class EntryPoint
     CLAUDE_3_OPUS_ID = -117
     CLAUDE_3_SONNET_ID = -118
     CLAUDE_3_HAIKU_ID = -119
+    COHERE_COMMAND_R_PLUS = -120

     BOTS = [
       [GPT4_ID, "gpt4_bot", "gpt-4"],

@@ -29,6 +30,7 @@ class EntryPoint
       [CLAUDE_3_OPUS_ID, "claude_3_opus_bot", "claude-3-opus"],
       [CLAUDE_3_SONNET_ID, "claude_3_sonnet_bot", "claude-3-sonnet"],
       [CLAUDE_3_HAIKU_ID, "claude_3_haiku_bot", "claude-3-haiku"],
+      [COHERE_COMMAND_R_PLUS, "cohere_command_bot", "cohere-command-r-plus"],
     ]

     BOT_USER_IDS = BOTS.map(&:first)

@@ -67,6 +69,8 @@ def self.map_bot_model_to_user_id(model_name)
       CLAUDE_3_SONNET_ID
     in "claude-3-haiku"
       CLAUDE_3_HAIKU_ID
+    in "cohere-command-r-plus"
+      COHERE_COMMAND_R_PLUS
     else
       nil
     end
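`map_bot_model_to_user_id` uses Ruby's `case`/`in` pattern matching (available since Ruby 2.7) rather than the classic `case`/`when`. A minimal standalone sketch of the same mapping, with the constants inlined to match the IDs in the diff above:

```ruby
# Sketch of the case/in mapping from model name to bot user id.
# Constants are inlined here; the real code defines them on EntryPoint.
CLAUDE_3_HAIKU_ID = -119
COHERE_COMMAND_R_PLUS = -120

def map_bot_model_to_user_id(model_name)
  case model_name
  in "claude-3-haiku"
    CLAUDE_3_HAIKU_ID
  in "cohere-command-r-plus"
    COHERE_COMMAND_R_PLUS
  else
    nil
  end
end

p map_bot_model_to_user_id("cohere-command-r-plus") # -120
```

For literal string patterns `in` behaves like `when`; the pattern-matching form is what allows the surrounding method to also destructure richer inputs if needed.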

lib/completions/dialects/command.rb

Lines changed: 107 additions & 0 deletions (new file)

# frozen_string_literal: true

# see: https://docs.cohere.com/reference/chat
#
module DiscourseAi
  module Completions
    module Dialects
      class Command < Dialect
        class << self
          def can_translate?(model_name)
            %w[command-light command command-r command-r-plus].include?(model_name)
          end

          def tokenizer
            DiscourseAi::Tokenizer::OpenAiTokenizer
          end
        end

        VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/

        def translate
          messages = prompt.messages

          # drop a trailing model (assistant) message; Cohere's chat API
          # expects the final turn to come from the user
          if messages.last[:type] == :model
            messages = messages.dup
            messages.pop
          end

          trimmed_messages = trim_messages(messages)

          chat_history = []
          system_message = nil

          prompt = {}

          trimmed_messages.each do |msg|
            case msg[:type]
            when :system
              if system_message
                chat_history << { role: "SYSTEM", message: msg[:content] }
              else
                system_message = msg[:content]
              end
            when :model
              chat_history << { role: "CHATBOT", message: msg[:content] }
            when :tool_call
              chat_history << { role: "CHATBOT", message: tool_call_to_xml(msg) }
            when :tool
              chat_history << { role: "USER", message: tool_result_to_xml(msg) }
            when :user
              user_message = { role: "USER", message: msg[:content] }
              user_message[:message] = "#{msg[:id]}: #{msg[:content]}" if msg[:id]
              chat_history << user_message
            end
          end

          tools_prompt = build_tools_prompt
          prompt[:preamble] = +"#{system_message}"
          if tools_prompt.present?
            prompt[:preamble] << "\n#{tools_prompt}"
            prompt[:preamble] << "\nNEVER attempt to run tools using JSON, always use XML. Lives depend on it."
          end

          prompt[:chat_history] = chat_history if chat_history.present?

          chat_history.reverse_each do |msg|
            if msg[:role] == "USER"
              prompt[:message] = msg[:message]
              chat_history.delete(msg)
              break
            end
          end

          prompt
        end

        def max_prompt_tokens
          case model_name
          when "command-light"
            4096
          when "command"
            8192
          when "command-r"
            131_072
          when "command-r-plus"
            131_072
          else
            8192
          end
        end

        private

        def per_message_overhead
          0
        end

        def calculate_message_token(context)
          self.class.tokenizer.size(context[:content].to_s + context[:name].to_s)
        end
      end
    end
  end
end

lib/completions/dialects/dialect.rb

Lines changed: 1 addition & 0 deletions

@@ -17,6 +17,7 @@ def dialect_for(model_name)
       DiscourseAi::Completions::Dialects::Gemini,
       DiscourseAi::Completions::Dialects::Mixtral,
       DiscourseAi::Completions::Dialects::Claude,
+      DiscourseAi::Completions::Dialects::Command,
     ]

     if Rails.env.test? || Rails.env.development?
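Registering the new dialect is all that is needed for model routing, because `dialect_for` uses a first-match registry: each dialect class answers `can_translate?` for the model names it owns. A minimal sketch of that pattern, with stub classes standing in for the real dialects:

```ruby
# Sketch of the first-match dialect registry. Stub classes stand in
# for the real DiscourseAi::Completions::Dialects classes.
class Claude
  def self.can_translate?(model_name)
    model_name.start_with?("claude")
  end
end

class Command
  def self.can_translate?(model_name)
    %w[command-light command command-r command-r-plus].include?(model_name)
  end
end

DIALECTS = [Claude, Command]

def dialect_for(model_name)
  DIALECTS.find { |d| d.can_translate?(model_name) } ||
    raise("no dialect found for #{model_name}")
end

puts dialect_for("command-r-plus") # Command
```

Because the first match wins, a new dialect only needs to claim model names no earlier entry in the list already claims.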

lib/completions/endpoints/base.rb

Lines changed: 1 addition & 0 deletions

@@ -16,6 +16,7 @@ def endpoint_for(provider_name, model_name)
       DiscourseAi::Completions::Endpoints::Gemini,
       DiscourseAi::Completions::Endpoints::Vllm,
       DiscourseAi::Completions::Endpoints::Anthropic,
+      DiscourseAi::Completions::Endpoints::Cohere,
     ]

     if Rails.env.test? || Rails.env.development?

lib/completions/endpoints/cohere.rb

Lines changed: 114 additions & 0 deletions (new file)

# frozen_string_literal: true

module DiscourseAi
  module Completions
    module Endpoints
      class Cohere < Base
        class << self
          def can_contact?(endpoint_name, model_name)
            return false unless endpoint_name == "cohere"

            %w[command-light command command-r command-r-plus].include?(model_name)
          end

          def dependant_setting_names
            %w[ai_cohere_api_key]
          end

          def correctly_configured?(model_name)
            SiteSetting.ai_cohere_api_key.present?
          end

          def endpoint_name(model_name)
            "Cohere - #{model_name}"
          end
        end

        def normalize_model_params(model_params)
          model_params = model_params.dup

          model_params[:p] = model_params.delete(:top_p) if model_params[:top_p]

          model_params
        end

        def default_options(dialect)
          options = { model: "command-r-plus" }

          options[:stop_sequences] = ["</function_calls>"] if dialect.prompt.has_tools?
          options
        end

        def provider_id
          AiApiAuditLog::Provider::Cohere
        end

        private

        def model_uri
          URI("https://api.cohere.ai/v1/chat")
        end

        def prepare_payload(prompt, model_params, dialect)
          payload = default_options(dialect).merge(model_params).merge(prompt)

          payload[:stream] = true if @streaming_mode

          payload
        end

        def prepare_request(payload)
          headers = {
            "Content-Type" => "application/json",
            "Authorization" => "Bearer #{SiteSetting.ai_cohere_api_key}",
          }

          Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
        end

        def extract_completion_from(response_raw)
          parsed = JSON.parse(response_raw, symbolize_names: true)

          if @streaming_mode
            if parsed[:event_type] == "text-generation"
              parsed[:text]
            else
              if parsed[:event_type] == "stream-end"
                @input_tokens = parsed.dig(:response, :meta, :billed_units, :input_tokens)
                @output_tokens = parsed.dig(:response, :meta, :billed_units, :output_tokens)
              end
              nil
            end
          else
            @input_tokens = parsed.dig(:meta, :billed_units, :input_tokens)
            @output_tokens = parsed.dig(:meta, :billed_units, :output_tokens)
            parsed[:text].to_s
          end
        end

        def final_log_update(log)
          log.request_tokens = @input_tokens if @input_tokens
          log.response_tokens = @output_tokens if @output_tokens
        end

        def partials_from(decoded_chunk)
          decoded_chunk.split("\n").compact
        end

        def extract_prompt_for_tokenizer(prompt)
          text = +""
          if prompt[:chat_history]
            text << prompt[:chat_history]
              .map { |message| message[:content] || message["content"] || "" }
              .join("\n")
          end

          text << prompt[:message] if prompt[:message]
          text << prompt[:preamble] if prompt[:preamble]

          text
        end
      end
    end
  end
end
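The streaming path in `extract_completion_from` can be exercised in isolation: Cohere streams newline-delimited JSON events, where `text-generation` events carry a text fragment and the final `stream-end` event carries the billed token counts. A minimal sketch of that reduction; the sample events below are hand-written to match the field names in the diff, not a captured API response:

```ruby
require "json"

# Sketch: reduce a stream of newline-delimited Cohere chat events to
# the accumulated text plus the billed_units reported at stream end.
def consume_stream(raw_chunks)
  text = +""
  usage = nil

  raw_chunks.each do |chunk|
    chunk.split("\n").each do |line|
      event = JSON.parse(line, symbolize_names: true)
      case event[:event_type]
      when "text-generation"
        text << event[:text]
      when "stream-end"
        usage = event.dig(:response, :meta, :billed_units)
      end
    end
  end

  [text, usage]
end

# Illustrative events; a network chunk may contain more than one line.
chunks = [
  %({"event_type":"text-generation","text":"Hello"}\n),
  %({"event_type":"text-generation","text":" world"}\n) +
    %({"event_type":"stream-end","response":{"meta":{"billed_units":{"input_tokens":3,"output_tokens":2}}}}),
]

text, usage = consume_stream(chunks)
puts text
```

The real endpoint does the same split in `partials_from` and the same per-event dispatch in `extract_completion_from`, but yields each text fragment to the caller as it arrives instead of accumulating it.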
