Skip to content

Commit 05858ad

Browse files
committed
DRAFT: Create AI Bot users dynamically and support custom LlmModels
1 parent f642a27 commit 05858ad

36 files changed

+282
-233
lines changed

app/controllers/discourse_ai/ai_bot/bot_controller.rb

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,10 @@ def stop_streaming_response
3131
end
3232

3333
# Renders the lowercased username of the bot user tied to the model name
# supplied in `params[:username]`.
#
# Raises Discourse::InvalidParameters (for :username) when
# EntryPoint.find_user_from_model returns nothing for that model name.
def show_bot_username
  model_name = params[:username]
  bot_user = DiscourseAi::AiBot::EntryPoint.find_user_from_model(model_name)

  raise Discourse::InvalidParameters.new(:username) unless bot_user

  render json: { bot_username: bot_user.username_lower }, status: 200
end
4139
end
4240
end

app/models/shared_ai_conversation.rb

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -133,12 +133,10 @@ def formatted_excerpt
133133
end
134134

135135
def self.build_conversation_data(topic, max_posts: DEFAULT_MAX_POSTS, include_usernames: false)
136-
llm_name = nil
137-
topic.topic_allowed_users.each do |tu|
138-
if DiscourseAi::AiBot::EntryPoint::BOT_USER_IDS.include?(tu.user_id)
139-
llm_name = DiscourseAi::AiBot::EntryPoint.find_bot_by_id(tu.user_id)&.llm
140-
end
141-
end
136+
allowed_user_ids = topic.topic_allowed_users.pluck(:user_id)
137+
ai_bot_participant = DiscourseAi::AiBot::EntryPoint.find_participant_in(allowed_user_ids)
138+
139+
llm_name = ai_bot_participant&.llm
142140

143141
llm_name = ActiveSupport::Inflector.humanize(llm_name) if llm_name
144142
llm_name ||= I18n.t("discourse_ai.unknown_model")
@@ -170,9 +168,7 @@ def self.build_conversation_data(topic, max_posts: DEFAULT_MAX_POSTS, include_us
170168
cooked: post.cooked,
171169
}
172170

173-
mapped[:persona] = persona if ::DiscourseAi::AiBot::EntryPoint::BOT_USER_IDS.include?(
174-
post.user_id,
175-
)
171+
mapped[:persona] = persona if ai_bot_participant&.id == post.user_id
176172
mapped[:username] = post.user&.username if include_usernames
177173
mapped
178174
end,

config/settings.yml

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -354,18 +354,7 @@ discourse_ai:
354354
type: list
355355
default: "gpt-3.5-turbo"
356356
client: true
357-
choices:
358-
- gpt-3.5-turbo
359-
- gpt-4
360-
- gpt-4-turbo
361-
- gpt-4o
362-
- claude-2
363-
- gemini-1.5-pro
364-
- mixtral-8x7B-Instruct-V0.1
365-
- claude-3-opus
366-
- claude-3-sonnet
367-
- claude-3-haiku
368-
- cohere-command-r-plus
357+
choices: "DiscourseAi::Configuration::LlmEnumerator.ai_bot_models"
369358
ai_bot_add_to_header:
370359
default: true
371360
client: true
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# frozen_string_literal: true
2+
3+
# Backfills a "bot_model_name" user custom field for every pre-existing AI bot
# user (the accounts seeded with ids -110 through -121), recording which model
# each of those accounts represents.
class CustomFieldsToTrackAiBotUsers < ActiveRecord::Migration[7.0]
  # Seeded bot user id => model name. -116 is the "fake" model; it is mapped
  # like every other seeded bot.
  MODEL_NAMES_BY_USER_ID = {
    -110 => "gpt-4",
    -111 => "gpt-3.5-turbo",
    -112 => "claude-2",
    -113 => "gpt-4-turbo",
    -114 => "mixtral-8x7B-Instruct-V0.1",
    -115 => "gemini-1.5-pro",
    -116 => "fake",
    -117 => "claude-3-opus",
    -118 => "claude-3-sonnet",
    -119 => "claude-3-haiku",
    -120 => "cohere-command-r-plus",
    -121 => "gpt-4o",
  }.freeze

  def up
    existing_bot_user_ids = DB.query_single("SELECT id FROM users WHERE id <= -110 AND id >= -121")

    # Each tuple is a complete SQL row. Interpolation is safe here: the model
    # names are the hard-coded literals above and the ids are coerced to integers.
    custom_field_rows =
      existing_bot_user_ids
        .map { |id| "('bot_model_name', '#{id_to_model_name(id)}', #{id.to_i}, NOW(), NOW())" }
        .join(",")

    # The row list is embedded directly in the statement: passing it as a
    # bound parameter would escape it as a single value instead of expanding
    # it into rows.
    DB.exec(<<~SQL) if custom_field_rows.present?
      INSERT INTO user_custom_fields (name, value, user_id, created_at, updated_at)
      VALUES #{custom_field_rows};
    SQL
  end

  # Returns the model name for a seeded bot user id. Ids outside the table
  # fall back to "gpt-4o", as the original catch-all branch did.
  def id_to_model_name(id)
    MODEL_NAMES_BY_USER_ID.fetch(id, "gpt-4o")
  end

  def down
    raise ActiveRecord::IrreversibleMigration
  end
end

lib/ai_bot/bot.rb

Lines changed: 42 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -163,76 +163,55 @@ def invoke_tool(tool, llm, cancel, context, &update_blk)
163163
end
164164

165165
# Resolves the LLM identifier to use for the given bot user.
#
# Returns "custom:<id>" when an LlmModel record matches the guessed
# provider/name pair, the plain "provider:model" guess when no record
# matches, or nil when no model could be guessed for the user.
def self.guess_model(bot_user)
  base_choice = guess_base_model(bot_user)

  if base_choice
    provider, model_name = base_choice.split(":")
    llm_model = LlmModel.find_by(provider: provider, name: model_name)

    return "custom:#{llm_model.id}" if llm_model
  end

  base_choice
end

# HACK(roman): We'll do this until we define how we represent different providers in the bot settings
#
# Maps the model name stored in the bot user's custom field to a
# "provider:model" string. Returns nil when the user has no such field or the
# stored name is not recognised.
def self.guess_base_model(bot_user)
  associated_llm =
    bot_user.custom_fields[DiscourseAi::AiBot::EntryPoint::BOT_MODEL_CUSTOM_FIELD]

  return if associated_llm.nil? # Might be a persona user. Handled by constructor.

  return "open_ai:gpt-3.5-turbo-16k" if associated_llm == "gpt-3.5-turbo"
  return "open_ai:#{associated_llm}" if associated_llm.start_with?("gpt-")

  return "google:gemini-1.5-pro" if associated_llm == "gemini-1.5-pro"

  return "fake:fake" if associated_llm == "fake"

  # Compare (==) against the stored bot name; the provider-side model name
  # drops the "cohere-" prefix.
  return "cohere:command-r-plus" if associated_llm == "cohere-command-r-plus"

  # Claude models prefer Bedrock when it is configured for that model.
  if %w[claude-2 claude-3-opus claude-3-sonnet claude-3-haiku].include?(associated_llm)
    if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?(associated_llm)
      return "aws_bedrock:#{associated_llm}"
    else
      return "anthropic:#{associated_llm}"
    end
  end

  # The stored bot name differs from the name the providers are asked about.
  if associated_llm == "mixtral-8x7B-Instruct-V0.1"
    mixtral_model = "mistralai/Mixtral-8x7B-Instruct-v0.1"

    if DiscourseAi::Completions::Endpoints::Vllm.correctly_configured?(mixtral_model)
      return "vllm:#{mixtral_model}"
    elsif DiscourseAi::Completions::Endpoints::HuggingFace.correctly_configured?(mixtral_model)
      return "hugging_face:#{mixtral_model}"
    else
      return "ollama:mistral"
    end
  end

  nil
end
237216

238217
def build_placeholder(summary, details, custom_raw: nil)

lib/ai_bot/entry_point.rb

Lines changed: 46 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -6,82 +6,58 @@ module AiBot
66

77
class EntryPoint
88
REQUIRE_TITLE_UPDATE = "discourse-ai-title-update"
9-
10-
GPT4_ID = -110
11-
GPT3_5_TURBO_ID = -111
12-
CLAUDE_V2_ID = -112
13-
GPT4_TURBO_ID = -113
14-
MIXTRAL_ID = -114
15-
GEMINI_ID = -115
16-
FAKE_ID = -116 # only used for dev and test
17-
CLAUDE_3_OPUS_ID = -117
18-
CLAUDE_3_SONNET_ID = -118
19-
CLAUDE_3_HAIKU_ID = -119
20-
COHERE_COMMAND_R_PLUS = -120
21-
GPT4O_ID = -121
22-
23-
BOTS = [
24-
[GPT4_ID, "gpt4_bot", "gpt-4"],
25-
[GPT3_5_TURBO_ID, "gpt3.5_bot", "gpt-3.5-turbo"],
26-
[CLAUDE_V2_ID, "claude_bot", "claude-2"],
27-
[GPT4_TURBO_ID, "gpt4t_bot", "gpt-4-turbo"],
28-
[MIXTRAL_ID, "mixtral_bot", "mixtral-8x7B-Instruct-V0.1"],
29-
[GEMINI_ID, "gemini_bot", "gemini-1.5-pro"],
30-
[FAKE_ID, "fake_bot", "fake"],
31-
[CLAUDE_3_OPUS_ID, "claude_3_opus_bot", "claude-3-opus"],
32-
[CLAUDE_3_SONNET_ID, "claude_3_sonnet_bot", "claude-3-sonnet"],
33-
[CLAUDE_3_HAIKU_ID, "claude_3_haiku_bot", "claude-3-haiku"],
34-
[COHERE_COMMAND_R_PLUS, "cohere_command_bot", "cohere-command-r-plus"],
35-
[GPT4O_ID, "gpt4o_bot", "gpt-4o"],
36-
]
37-
38-
BOT_USER_IDS = BOTS.map(&:first)
39-
9+
BOT_MODEL_CUSTOM_FIELD = "bot_model_name"
4010
Bot = Struct.new(:id, :name, :llm)
4111

4212
# Returns the ids of active users that carry the bot-model custom field,
# followed by the user ids reported by AiPersona.mentionables.
def self.all_bot_ids
  bot_user_ids =
    User
      .joins(:_custom_fields)
      .where(active: true, user_custom_fields: { name: BOT_MODEL_CUSTOM_FIELD })
      .pluck(:user_id)

  persona_user_ids = AiPersona.mentionables.map { |persona| persona[:user_id] }

  bot_user_ids + persona_user_ids
end
4523

46-
def self.find_bot_by_id(id)
47-
found = DiscourseAi::AiBot::EntryPoint::BOTS.find { |bot| bot[0] == id }
48-
return if !found
49-
Bot.new(found[0], found[1], found[2])
24+
# Looks among the given user ids for an active user carrying the bot-model
# custom field and wraps it in a Bot struct (id, lowercased username, stored
# model name). Returns nil when none of the ids belongs to such a user.
def self.find_participant_in(participant_ids)
  bot_field =
    UserCustomField
      .includes(:user)
      .where(users: { active: true }, name: BOT_MODEL_CUSTOM_FIELD, user_id: participant_ids)
      .last

  return unless bot_field

  bot_user = bot_field.user
  Bot.new(bot_user.id, bot_user.username_lower, bot_field.value)
end
5139

52-
def self.map_bot_model_to_user_id(model_name)
53-
case model_name
54-
in "gpt-4o"
55-
GPT4O_ID
56-
in "gpt-4-turbo"
57-
GPT4_TURBO_ID
58-
in "gpt-3.5-turbo"
59-
GPT3_5_TURBO_ID
60-
in "gpt-4"
61-
GPT4_ID
62-
in "claude-2"
63-
CLAUDE_V2_ID
64-
in "mixtral-8x7B-Instruct-V0.1"
65-
MIXTRAL_ID
66-
in "gemini-1.5-pro"
67-
GEMINI_ID
68-
in "fake"
69-
FAKE_ID
70-
in "claude-3-opus"
71-
CLAUDE_3_OPUS_ID
72-
in "claude-3-sonnet"
73-
CLAUDE_3_SONNET_ID
74-
in "claude-3-haiku"
75-
CLAUDE_3_HAIKU_ID
76-
in "cohere-command-r-plus"
77-
COHERE_COMMAND_R_PLUS
78-
else
79-
nil
80-
end
40+
# Returns the user whose bot-model custom field holds `model_name`, or nil
# when no such custom field exists.
def self.find_user_from_model(model_name)
  bot_field =
    UserCustomField.includes(:user).find_by(name: BOT_MODEL_CUSTOM_FIELD, value: model_name)

  bot_field&.user
end
46+
47+
# Returns one hash per bot user whose stored model name is listed in the
# `ai_bot_enabled_chat_bots` site setting, with the keys "username", "id"
# and "model_name".
def self.enabled_user_ids_and_models_map
  sql = <<~SQL
    SELECT users.username AS username, users.id AS id, ucf.value AS model_name
    FROM user_custom_fields ucf
    INNER JOIN users ON ucf.user_id = users.id
    WHERE ucf.name = :cf_name
    AND ucf.value IN (:model_names)
  SQL

  DB.query_hash(
    sql,
    model_names: SiteSetting.ai_bot_enabled_chat_bots.split("|"),
    cf_name: BOT_MODEL_CUSTOM_FIELD,
  )
end
8258

8359
# Most errors are simply "not_allowed"
84-
# we do not want to reveal information about this sytem
60+
# we do not want to reveal information about this system
8561
# the 2 exceptions are "other_people_in_pm" and "other_content_in_pm"
8662
# in both cases you have access to the PM so we are not revealing anything
8763
def self.ai_share_error(topic, guardian)
@@ -170,25 +146,11 @@ def inject_into(plugin)
170146
scope.user.in_any_groups?(SiteSetting.ai_bot_allowed_groups_map)
171147
end,
172148
) do
173-
model_map = {}
174-
SiteSetting
175-
.ai_bot_enabled_chat_bots
176-
.split("|")
177-
.each do |bot_name|
178-
model_map[
179-
::DiscourseAi::AiBot::EntryPoint.map_bot_model_to_user_id(bot_name)
180-
] = bot_name
181-
end
182-
183-
# not 100% ideal, cause it is one extra query, but we need it
184-
bots = DB.query_hash(<<~SQL, user_ids: model_map.keys)
185-
SELECT username, id FROM users WHERE id IN (:user_ids)
186-
SQL
149+
bots_map = ::DiscourseAi::AiBot::EntryPoint.enabled_user_ids_and_models_map
187150

188-
bots.each { |hash| hash["model_name"] = model_map[hash["id"]] }
189151
persona_users = AiPersona.persona_users(user: scope.user)
190152
if persona_users.present?
191-
bots.concat(
153+
bots_map.concat(
192154
persona_users.map do |persona_user|
193155
{
194156
"id" => persona_user[:user_id],
@@ -198,7 +160,8 @@ def inject_into(plugin)
198160
end,
199161
)
200162
end
201-
bots
163+
164+
bots_map
202165
end
203166

204167
plugin.add_to_serializer(:current_user, :can_use_assistant) do

0 commit comments

Comments
 (0)