From 992ae0448f8a028086d95d0beeaeeb9775987401 Mon Sep 17 00:00:00 2001 From: Roman Rizzi Date: Mon, 12 May 2025 16:36:39 -0300 Subject: [PATCH] FEATURE: Examples support for personas. Examples simulate previous interactions with an LLM and come right after the system prompt. This helps grounding the model and producing better responses. --- .../admin/ai_personas_controller.rb | 10 ++ app/models/ai_persona.rb | 19 +++ .../localized_ai_persona_serializer.rb | 3 +- .../discourse/admin/models/ai-persona.js | 3 +- .../components/ai-persona-editor.gjs | 33 ++++ .../components/ai-persona-example.gjs | 67 ++++++++ config/locales/client.en.yml | 8 + config/locales/server.en.yml | 3 + db/fixtures/personas/603_ai_personas.rb | 2 + ...20250508154953_add_examples_to_personas.rb | 7 + lib/personas/persona.rb | 30 +++- lib/personas/summarizer.rb | 9 ++ lib/summarization/strategies/topic_summary.rb | 17 +- spec/lib/personas/persona_spec.rb | 24 +++ spec/models/ai_persona_spec.rb | 148 ++++++++++-------- .../admin/ai_personas_controller_spec.rb | 2 + 16 files changed, 295 insertions(+), 90 deletions(-) create mode 100644 assets/javascripts/discourse/components/ai-persona-example.gjs create mode 100644 db/migrate/20250508154953_add_examples_to_personas.rb diff --git a/app/controllers/discourse_ai/admin/ai_personas_controller.rb b/app/controllers/discourse_ai/admin/ai_personas_controller.rb index 58a61e0e0..2b57dda9a 100644 --- a/app/controllers/discourse_ai/admin/ai_personas_controller.rb +++ b/app/controllers/discourse_ai/admin/ai_personas_controller.rb @@ -225,6 +225,10 @@ def ai_persona_params permitted[:response_format] = permit_response_format(response_format) end + if examples = params.dig(:ai_persona, :examples) + permitted[:examples] = permit_examples(examples) + end + permitted end @@ -251,6 +255,12 @@ def permit_response_format(response_format) end end end + + def permit_examples(examples) + return [] if !examples.is_a?(Array) + + examples.map { |example_arr| 
example_arr.take(2).map(&:to_s) } + end end end end diff --git a/app/models/ai_persona.rb b/app/models/ai_persona.rb index 8347a4140..8793b398f 100644 --- a/app/models/ai_persona.rb +++ b/app/models/ai_persona.rb @@ -13,6 +13,7 @@ class AiPersona < ActiveRecord::Base validate :system_persona_unchangeable, on: :update, if: :system validate :chat_preconditions validate :allowed_seeded_model, if: :default_llm_id + validate :well_formated_examples validates :max_context_posts, numericality: { greater_than: 0 }, allow_nil: true # leaves some room for growth but sets a maximum to avoid memory issues # we may want to revisit this in the future @@ -265,6 +266,7 @@ def class_instance define_method(:top_p) { @ai_persona&.top_p } define_method(:system_prompt) { @ai_persona&.system_prompt || "You are a helpful bot." } define_method(:uploads) { @ai_persona&.uploads } + define_method(:examples) { @ai_persona&.examples } end end @@ -343,6 +345,11 @@ def system_persona_unchangeable new_format = response_format_change[1].map { |f| f["key"] }.to_set errors.add(:base, error_msg) if old_format != new_format + elsif examples_changed? + old_examples = examples_change[0].flatten.to_set + new_examples = examples_change[1].flatten.to_set + + errors.add(:base, error_msg) if old_examples != new_examples end end @@ -363,6 +370,17 @@ def allowed_seeded_model errors.add(:default_llm, I18n.t("discourse_ai.llm.configuration.invalid_seeded_model")) end + + def well_formated_examples + return if examples.blank? + + if examples.is_a?(Array) && + examples.all? { |e| e.is_a?(Array) && e.length == 2 && e.all?(&:present?) 
} + return + end + + errors.add(:examples, I18n.t("discourse_ai.personas.malformed_examples")) + end end # == Schema Information @@ -401,6 +419,7 @@ def allowed_seeded_model # default_llm_id :bigint # question_consolidator_llm_id :bigint # response_format :jsonb +# examples :jsonb # # Indexes # diff --git a/app/serializers/localized_ai_persona_serializer.rb b/app/serializers/localized_ai_persona_serializer.rb index 57945e291..11b9d8158 100644 --- a/app/serializers/localized_ai_persona_serializer.rb +++ b/app/serializers/localized_ai_persona_serializer.rb @@ -31,7 +31,8 @@ class LocalizedAiPersonaSerializer < ApplicationSerializer :allow_topic_mentions, :allow_personal_messages, :force_default_llm, - :response_format + :response_format, + :examples has_one :user, serializer: BasicUserSerializer, embed: :object has_many :rag_uploads, serializer: UploadSerializer, embed: :object diff --git a/assets/javascripts/discourse/admin/models/ai-persona.js b/assets/javascripts/discourse/admin/models/ai-persona.js index f313dee5d..042bc786d 100644 --- a/assets/javascripts/discourse/admin/models/ai-persona.js +++ b/assets/javascripts/discourse/admin/models/ai-persona.js @@ -34,6 +34,7 @@ const CREATE_ATTRIBUTES = [ "allow_chat_channel_mentions", "allow_chat_direct_messages", "response_format", + "examples", ]; const SYSTEM_ATTRIBUTES = [ @@ -61,7 +62,6 @@ const SYSTEM_ATTRIBUTES = [ "allow_topic_mentions", "allow_chat_channel_mentions", "allow_chat_direct_messages", - "response_format", ]; export default class AiPersona extends RestModel { @@ -154,6 +154,7 @@ export default class AiPersona extends RestModel { this.populateTools(attrs); attrs.forced_tool_count = this.forced_tool_count || -1; attrs.response_format = attrs.response_format || []; + attrs.examples = attrs.examples || []; return attrs; } diff --git a/assets/javascripts/discourse/components/ai-persona-editor.gjs b/assets/javascripts/discourse/components/ai-persona-editor.gjs index f0b56d8a8..caf9c4031 100644 --- 
a/assets/javascripts/discourse/components/ai-persona-editor.gjs +++ b/assets/javascripts/discourse/components/ai-persona-editor.gjs @@ -17,6 +17,7 @@ import AdminUser from "admin/models/admin-user"; import GroupChooser from "select-kit/components/group-chooser"; import AiPersonaResponseFormatEditor from "../components/modal/ai-persona-response-format-editor"; import AiLlmSelector from "./ai-llm-selector"; +import AiPersonaCollapsableExample from "./ai-persona-example"; import AiPersonaToolOptions from "./ai-persona-tool-options"; import AiToolSelector from "./ai-tool-selector"; import RagOptionsFk from "./rag-options-fk"; @@ -230,6 +231,12 @@ export default class PersonaEditor extends Component { return this.allTools.filter((tool) => tools.includes(tool.id)); } + @action + addExamplesPair(form, data) { + const newExamples = [...data.examples, ["", ""]]; + form.set("examples", newExamples); + } + mapToolOptions(currentOptions, toolNames) { const updatedOptions = Object.assign({}, currentOptions); @@ -422,6 +429,32 @@ export default class PersonaEditor extends Component { {{/unless}} + + {{#unless data.system}} + + + + {{/unless}} + + {{#if (gt data.examples.length 0)}} + + + + {{/if}} + + +
+ {{icon this.caretIcon}} + {{this.exampleTitle}} +
+ {{#unless this.collapsed}} + <@examplesCollection.Collection as |exPair pairIdx|> + + + + + + {{#unless @system}} + <@form.Container> + <@form.Button + @action={{this.deletePair}} + @label="discourse_ai.ai_persona.examples.remove" + class="ai-persona-editor__delete_example btn-danger" + /> + + {{/unless}} + {{/unless}} + +} diff --git a/config/locales/client.en.yml b/config/locales/client.en.yml index 04901dc26..a043ae315 100644 --- a/config/locales/client.en.yml +++ b/config/locales/client.en.yml @@ -330,6 +330,14 @@ en: modal: root_title: "Response structure" key_title: "Key" + examples: + title: Examples + examples_help: Simulate previous interactions with the LLM and ground it to produce better result. + new: New example + remove: Delete example + collapsable_title: "Example #%{number}" + user: "User message" + model: "Model response" list: enabled: "AI Bot?" diff --git a/config/locales/server.en.yml b/config/locales/server.en.yml index 1c3351c9f..00a71cd31 100644 --- a/config/locales/server.en.yml +++ b/config/locales/server.en.yml @@ -495,6 +495,9 @@ en: other: "We couldn't delete this model because %{settings} are using it. Update the settings and try again." cannot_edit_builtin: "You can't edit a built-in model." + personas: + malformed_examples: "The given examples have the wrong format." + embeddings: delete_failed: "This model is currently in use. Update the `ai embeddings selected model` first." cannot_edit_builtin: "You can't edit a built-in model." 
diff --git a/db/fixtures/personas/603_ai_personas.rb b/db/fixtures/personas/603_ai_personas.rb index 7d52e8a9e..27e6d479f 100644 --- a/db/fixtures/personas/603_ai_personas.rb +++ b/db/fixtures/personas/603_ai_personas.rb @@ -74,6 +74,8 @@ def from_setting(setting_name) persona.response_format = instance.response_format + persona.examples = instance.examples + persona.system_prompt = instance.system_prompt persona.top_p = instance.top_p persona.temperature = instance.temperature diff --git a/db/migrate/20250508154953_add_examples_to_personas.rb b/db/migrate/20250508154953_add_examples_to_personas.rb new file mode 100644 index 000000000..2cf12912c --- /dev/null +++ b/db/migrate/20250508154953_add_examples_to_personas.rb @@ -0,0 +1,7 @@ +# frozen_string_literal: true + +class AddExamplesToPersonas < ActiveRecord::Migration[7.2] + def change + add_column :ai_personas, :examples, :jsonb + end +end diff --git a/lib/personas/persona.rb b/lib/personas/persona.rb index a8b087850..a4443dcc0 100644 --- a/lib/personas/persona.rb +++ b/lib/personas/persona.rb @@ -164,6 +164,10 @@ def response_format nil end + def examples + [] + end + def available_tools self .class @@ -173,11 +177,7 @@ def available_tools end def craft_prompt(context, llm: nil) - system_insts = - system_prompt.gsub(/\{(\w+)\}/) do |match| - found = context.lookup_template_param(match[1..-2]) - found.nil? ? match : found.to_s - end + system_insts = replace_placeholders(system_prompt, context) prompt_insts = <<~TEXT.strip #{system_insts} @@ -206,10 +206,21 @@ def craft_prompt(context, llm: nil) prompt_insts << fragments_guidance if fragments_guidance.present? + post_system_examples = [] + + if examples.present? + examples.flatten.each_with_index do |e, idx| + post_system_examples << { + content: replace_placeholders(e, context), + type: (idx + 1).odd? ? 
:user : :model, + } + end + end + prompt = DiscourseAi::Completions::Prompt.new( prompt_insts, - messages: context.messages, + messages: post_system_examples.concat(context.messages), topic_id: context.topic_id, post_id: context.post_id, ) @@ -239,6 +250,13 @@ def allow_partial_tool_calls? protected + def replace_placeholders(content, context) + content.gsub(/\{(\w+)\}/) do |match| + found = context.lookup_template_param(match[1..-2]) + found.nil? ? match : found.to_s + end + end + def tool_instance(tool_call, bot_user:, llm:, context:, existing_tools:) function_id = tool_call.id function_name = tool_call.name diff --git a/lib/personas/summarizer.rb b/lib/personas/summarizer.rb index 64540ff08..c1eefe89c 100644 --- a/lib/personas/summarizer.rb +++ b/lib/personas/summarizer.rb @@ -32,6 +32,15 @@ def system_prompt def response_format [{ key: "summary", type: "string" }] end + + def examples + [ + [ + "Here are the posts inside XML tags:\n\n1) user1 said: I love Mondays 2) user2 said: I hate Mondays\n\nGenerate a concise, coherent summary of the text above maintaining the original language.", + "Two users are sharing their feelings toward Mondays. 
[user1]({resource_url}/1) loves them, while [user2]({resource_url}/2) hates them.", ], ] end end end end diff --git a/lib/summarization/strategies/topic_summary.rb b/lib/summarization/strategies/topic_summary.rb index 9361e50ba..f06b67abd 100644 --- a/lib/summarization/strategies/topic_summary.rb +++ b/lib/summarization/strategies/topic_summary.rb @@ -44,20 +44,7 @@ def as_llm_messages(contents) input = contents.map { |item| "(#{item[:id]} #{item[:poster]} said: #{item[:text]} " }.join - messages = [] - messages << { - type: :user, - content: - "Here are the posts inside XML tags:\n\n1) user1 said: I love Mondays 2) user2 said: I hate Mondays\n\nGenerate a concise, coherent summary of the text above maintaining the original language.", - } - - messages << { - type: :model, - content: - "Two users are sharing their feelings toward Mondays. [user1](#{resource_path}/1) hates them, while [user2](#{resource_path}/2) loves them.", - } - - messages << { type: :user, content: <<~TEXT.strip } + [{ type: :user, content: <<~TEXT.strip }] #{content_title.present? ? "The discussion title is: " + content_title + ".\n" : ""} Here are the posts, inside XML tags: @@ -67,8 +54,6 @@ def as_llm_messages(contents) Generate a concise, coherent summary of the text above maintaining the original language. 
TEXT - - messages end private diff --git a/spec/lib/personas/persona_spec.rb b/spec/lib/personas/persona_spec.rb index 25f12914b..e3c96e348 100644 --- a/spec/lib/personas/persona_spec.rb +++ b/spec/lib/personas/persona_spec.rb @@ -8,6 +8,7 @@ def tools DiscourseAi::Personas::Tools::Image, ] end + def system_prompt <<~PROMPT {site_url} @@ -445,6 +446,29 @@ def stub_fragments(fragment_count, persona: ai_persona) expect(crafted_system_prompt).not_to include("fragment-n10") # Fragment #10 not included end end + + context "when the persona has examples" do + fab!(:examples_persona) do + Fabricate( + :ai_persona, + examples: [["User message", "assistant response"]], + allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_0]], + ) + end + + it "includes them before the context messages" do + custom_persona = + DiscourseAi::Personas::Persona.find_by(id: examples_persona.id, user: user).new + + post_system_prompt_msgs = custom_persona.craft_prompt(with_cc).messages.last(3) + + expect(post_system_prompt_msgs).to contain_exactly( + { content: "User message", type: :user }, + { content: "assistant response", type: :model }, + { content: "Tell me the time", type: :user }, + ) + end + end end end end diff --git a/spec/models/ai_persona_spec.rb b/spec/models/ai_persona_spec.rb index c776e699e..0e6b9d13d 100644 --- a/spec/models/ai_persona_spec.rb +++ b/spec/models/ai_persona_spec.rb @@ -1,90 +1,79 @@ # frozen_string_literal: true RSpec.describe AiPersona do + subject(:basic_persona) do + AiPersona.new( + name: "test", + description: "test", + system_prompt: "test", + tools: [], + allowed_group_ids: [], + ) + end + fab!(:llm_model) fab!(:seeded_llm_model) { Fabricate(:llm_model, id: -1) } it "validates context settings" do - persona = - AiPersona.new( - name: "test", - description: "test", - system_prompt: "test", - tools: [], - allowed_group_ids: [], - ) + expect(basic_persona.valid?).to eq(true) - expect(persona.valid?).to eq(true) - - persona.max_context_posts = 0 - 
expect(persona.valid?).to eq(false) - expect(persona.errors[:max_context_posts]).to eq(["must be greater than 0"]) + basic_persona.max_context_posts = 0 + expect(basic_persona.valid?).to eq(false) + expect(basic_persona.errors[:max_context_posts]).to eq(["must be greater than 0"]) - persona.max_context_posts = 1 - expect(persona.valid?).to eq(true) + basic_persona.max_context_posts = 1 + expect(basic_persona.valid?).to eq(true) - persona.max_context_posts = nil - expect(persona.valid?).to eq(true) + basic_persona.max_context_posts = nil + expect(basic_persona.valid?).to eq(true) end it "validates tools" do - persona = - AiPersona.new( - name: "test", - description: "test", - system_prompt: "test", - tools: [], - allowed_group_ids: [], - ) - Fabricate(:ai_tool, id: 1) Fabricate(:ai_tool, id: 2, name: "Archie search", tool_name: "search") - expect(persona.valid?).to eq(true) + expect(basic_persona.valid?).to eq(true) - persona.tools = %w[search image_generation] - expect(persona.valid?).to eq(true) + basic_persona.tools = %w[search image_generation] + expect(basic_persona.valid?).to eq(true) - persona.tools = %w[search image_generation search] - expect(persona.valid?).to eq(false) - expect(persona.errors[:tools]).to eq(["Can not have duplicate tools"]) + basic_persona.tools = %w[search image_generation search] + expect(basic_persona.valid?).to eq(false) + expect(basic_persona.errors[:tools]).to eq(["Can not have duplicate tools"]) - persona.tools = [["custom-1", { test: "test" }, false], ["custom-2", { test: "test" }, false]] - expect(persona.valid?).to eq(true) - expect(persona.errors[:tools]).to eq([]) + basic_persona.tools = [ + ["custom-1", { test: "test" }, false], + ["custom-2", { test: "test" }, false], + ] + expect(basic_persona.valid?).to eq(true) + expect(basic_persona.errors[:tools]).to eq([]) - persona.tools = [["custom-1", { test: "test" }, false], ["custom-1", { test: "test" }, false]] - expect(persona.valid?).to eq(false) - 
expect(persona.errors[:tools]).to eq(["Can not have duplicate tools"]) + basic_persona.tools = [ + ["custom-1", { test: "test" }, false], + ["custom-1", { test: "test" }, false], + ] + expect(basic_persona.valid?).to eq(false) + expect(basic_persona.errors[:tools]).to eq(["Can not have duplicate tools"]) - persona.tools = [ + basic_persona.tools = [ ["custom-1", { test: "test" }, false], ["custom-2", { test: "test" }, false], "image_generation", ] - expect(persona.valid?).to eq(true) - expect(persona.errors[:tools]).to eq([]) + expect(basic_persona.valid?).to eq(true) + expect(basic_persona.errors[:tools]).to eq([]) - persona.tools = [ + basic_persona.tools = [ ["custom-1", { test: "test" }, false], ["custom-2", { test: "test" }, false], "Search", ] - expect(persona.valid?).to eq(false) - expect(persona.errors[:tools]).to eq(["Can not have duplicate tools"]) + expect(basic_persona.valid?).to eq(false) + expect(basic_persona.errors[:tools]).to eq(["Can not have duplicate tools"]) end it "allows creation of user" do - persona = - AiPersona.create!( - name: "test", - description: "test", - system_prompt: "test", - tools: [], - allowed_group_ids: [], - ) - - user = persona.create_user! + user = basic_persona.create_user! 
expect(user.username).to eq("test_bot") expect(user.name).to eq("Test") expect(user.bot?).to be(true) @@ -223,25 +212,17 @@ end it "validates allowed seeded model" do - persona = - AiPersona.new( - name: "test", - description: "test", - system_prompt: "test", - tools: [], - allowed_group_ids: [], - default_llm_id: seeded_llm_model.id, - ) + basic_persona.default_llm_id = seeded_llm_model.id SiteSetting.ai_bot_allowed_seeded_models = "" - expect(persona.valid?).to eq(false) - expect(persona.errors[:default_llm]).to include( + expect(basic_persona.valid?).to eq(false) + expect(basic_persona.errors[:default_llm]).to include( I18n.t("discourse_ai.llm.configuration.invalid_seeded_model"), ) SiteSetting.ai_bot_allowed_seeded_models = "-1" - expect(persona.valid?).to eq(true) + expect(basic_persona.valid?).to eq(true) end it "does not leak caches between sites" do @@ -268,6 +249,7 @@ system_prompt: "system persona", tools: %w[Search Time], response_format: [{ key: "summary", type: "string" }], + examples: [%w[user_msg1 assistant_msg1], %w[user_msg2 assistant_msg2]], system: true, ) end @@ -302,6 +284,40 @@ ActiveRecord::RecordInvalid, ) end + + it "doesn't accept changes to examples" do + other_examples = [%w[user_msg1 assistant_msg1]] + + expect { system_persona.update!(examples: other_examples) }.to raise_error( + ActiveRecord::RecordInvalid, + ) + end + end + end + + describe "validates examples format" do + it "doesn't accept examples that are not arrays" do + basic_persona.examples = [1] + + expect(basic_persona.valid?).to eq(false) + expect(basic_persona.errors[:examples].first).to eq( + I18n.t("discourse_ai.personas.malformed_examples"), + ) + end + + it "doesn't accept examples that don't come in pairs" do + basic_persona.examples = [%w[user_msg1]] + + expect(basic_persona.valid?).to eq(false) + expect(basic_persona.errors[:examples].first).to eq( + I18n.t("discourse_ai.personas.malformed_examples"), + ) + end + + it "works when example is well formatted" do + 
basic_persona.examples = [%w[user_msg1 assistant1]] + + expect(basic_persona.valid?).to eq(true) end end end diff --git a/spec/requests/admin/ai_personas_controller_spec.rb b/spec/requests/admin/ai_personas_controller_spec.rb index 0e5662f5e..17b7fe34e 100644 --- a/spec/requests/admin/ai_personas_controller_spec.rb +++ b/spec/requests/admin/ai_personas_controller_spec.rb @@ -186,6 +186,7 @@ question_consolidator_llm_id: llm_model.id, forced_tool_count: 2, response_format: [{ key: "summary", type: "string" }], + examples: [%w[user_msg1 assistant_msg1], %w[user_msg2 assistant_msg2]], } end @@ -213,6 +214,7 @@ expect(persona_json["response_format"].map { |rf| rf["key"] }).to contain_exactly( "summary", ) + expect(persona_json["examples"]).to eq(valid_attributes[:examples]) persona = AiPersona.find(persona_json["id"])