Skip to content

Commit 31da869

Browse files
committed
Feat: model selection based on API result

- add API calls to query the available models
- fix new state: chat and command agents are now stored separately
1 parent 974afa0 commit 31da869

File tree

10 files changed

+69
-169
lines changed

10 files changed

+69
-169
lines changed

lua/parrot/chat_handler.lua

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -656,6 +656,7 @@ function ChatHandler:_chat_respond(params)
656656
end
657657

658658
local query_prov = self:get_provider(true)
659+
query_prov:set_model(model_obj.name)
659660

660661
local llm_prefix = self.options.llm_prefix
661662
local llm_suffix = "[{{llm}}]"
@@ -761,8 +762,7 @@ function ChatHandler:_chat_respond(params)
761762
-- prepare invisible buffer for the model to write to
762763
local topic_buf = vim.api.nvim_create_buf(false, true)
763764
local topic_handler = chatutils.create_handler(self.queries, topic_buf, nil, 0, false, "", false)
764-
765-
topic_prov:check({ model = self.providers[topic_prov.name].topic_model })
765+
topic_prov:set_model(self.providers[topic_prov.name].topic.model)
766766

767767
local topic_spinner = nil
768768
if self.options.enable_spinner then
@@ -927,7 +927,7 @@ function ChatHandler:switch_provider(selected_prov, is_chat)
927927
end
928928

929929
if self.providers[selected_prov] then
930-
self.set_provider(selected_prov, is_chat)
930+
self:set_provider(selected_prov, is_chat)
931931
logger.info("Switched to provider: " .. selected_prov)
932932
return
933933
else
@@ -949,7 +949,7 @@ function ChatHandler:provider(params)
949949
prompt = "Provider selection ❯",
950950
fzf_opts = self.options.fzf_lua_opts,
951951
complete = function(selection)
952-
self:switch_provider(selection[1])
952+
self:switch_provider(selection[1], is_chat)
953953
end,
954954
})
955955
else
@@ -966,7 +966,6 @@ function ChatHandler:switch_model(is_chat, selected_model, prov)
966966
logger.warning("Empty model selection")
967967
return
968968
end
969-
prov:check(selected_model)
970969
if is_chat then
971970
self.state:set_model(prov.name, selected_model, "chat")
972971
logger.info("Chat model: " .. selected_model)
@@ -1281,6 +1280,7 @@ function ChatHandler:prompt(params, target, model_obj, prompt, template)
12811280
end
12821281

12831282
-- call the model and write the response
1283+
prov:set_model(model_obj.name)
12841284

12851285
local spinner = nil
12861286
if self.options.enable_spinner then

lua/parrot/config.lua

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -326,9 +326,9 @@ function M.setup(opts)
326326

327327
M.available_providers = vim.tbl_keys(M.providers)
328328

329-
available_models = {}
329+
local available_models = {}
330330
for _, prov_name in ipairs(M.available_providers) do
331-
local _prov = init_provider(prov_name, "", "")
331+
local _prov = init_provider(prov_name, M.providers[prov_name].endpoint, M.providers[prov_name].api_key)
332332
available_models[prov_name] = _prov:get_available_models()
333333
end
334334
M.available_models = available_models

lua/parrot/provider/anthropic.lua

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,6 @@ local utils = require("parrot.utils")
44
local Anthropic = {}
55
Anthropic.__index = Anthropic
66

7-
local available_model_set = {
8-
["claude-3-5-sonnet-20240620"] = true,
9-
["claude-3-opus-20240229"] = true,
10-
["claude-3-sonnet-20240229"] = true,
11-
["claude-3-haiku-20240307"] = true,
12-
}
13-
147
-- https://docs.anthropic.com/en/api/messages
158
local available_api_parameters = {
169
-- required
@@ -100,10 +93,6 @@ function Anthropic:process_onexit(res)
10093
end
10194
end
10295

103-
function Anthropic:check(model)
104-
return available_model_set[model]
105-
end
106-
10796
function Anthropic:get_available_models()
10897
return {
10998
"claude-3-5-sonnet-20240620",

lua/parrot/provider/gemini.lua

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,6 @@ local utils = require("parrot.utils")
44
local Gemini = {}
55
Gemini.__index = Gemini
66

7-
local available_model_set = {
8-
["gemini-1.5-flash"] = true,
9-
["gemini-1.5-pro"] = true,
10-
}
11-
127
-- https://ai.google.dev/gemini-api/docs/models/generative-models#model_parameters
138
local available_api_parameters = {
149
["contents"] = true,
@@ -120,14 +115,11 @@ function Gemini:process_onexit(res)
120115
end
121116
end
122117

123-
function Gemini:check(model)
124-
return available_model_set[model]
125-
end
126-
127118
function Gemini:get_available_models()
128119
return {
129120
"gemini-1.5-flash",
130121
"gemini-1.5-pro",
122+
"gemini-1.0-pro",
131123
}
132124
end
133125

lua/parrot/provider/groq.lua

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,10 @@
11
local logger = require("parrot.logger")
22
local utils = require("parrot.utils")
3+
local Job = require("plenary.job")
34

45
local Groq = {}
56
Groq.__index = Groq
67

7-
local available_model_set = {
8-
["llama-3.1-405b-reasoning"] = true,
9-
["llama-3.1-70b-versatile"] = true,
10-
["llama-3.1-8b-instant"] = true,
11-
["llama3-groq-70b-8192-tool-use-preview"] = true,
12-
["llama3-groq-8b-8192-tool-use-preview"] = true,
13-
["llama-guard-3-8b"] = true,
14-
["llama3-70b-8192"] = true,
15-
["llama3-8b-8192"] = true,
16-
["mixtral-8x7b-32768"] = true,
17-
["gemma-7b-it"] = true,
18-
["gemma2-9b-it"] = true,
19-
}
20-
218
-- https://console.groq.com/docs/api-reference#chat-create
229
local available_api_parameters = {
2310
-- required
@@ -100,18 +87,33 @@ function Groq:process_stdout(response)
10087
end
10188

10289
function Groq:process_onexit(res)
103-
-- '{"error":{"message":"Invalid API Key","type":"invalid_request_error","code":"invalid_api_key"}}'
10490
local success, parsed = pcall(vim.json.decode, res)
10591
if success and parsed.error then
10692
logger.error("Groq - message: " .. parsed.error.message)
10793
end
10894
end
10995

110-
function Groq:check(model)
111-
return available_model_set[model]
112-
end
113-
11496
function Groq:get_available_models()
97+
if self:verify() then
98+
Job:new({
99+
command = "curl",
100+
args = {
101+
"https://api.groq.com/openai/v1/models",
102+
"-H",
103+
"Authorization: Bearer " .. self.api_key,
104+
"-H",
105+
"Content-Type: application/json",
106+
},
107+
on_exit = function(job)
108+
local parsed_response = utils.parse_raw_response(job:result())
109+
local ids = {}
110+
for _, item in ipairs(vim.json.decode(parsed_response).data) do
111+
table.insert(ids, item.id)
112+
end
113+
return ids
114+
end,
115+
}):start()
116+
end
115117
return {
116118
"llama-3.1-405b-reasoning",
117119
"llama-3.1-70b-versatile",

lua/parrot/provider/mistral.lua

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,6 @@ local utils = require("parrot.utils")
44
local Mistral = {}
55
Mistral.__index = Mistral
66

7-
local available_model_set = {
8-
["codestral-latest"] = true,
9-
["mistral-tiny"] = true,
10-
["mistral-small-latest"] = true,
11-
["mistral-medium-latest"] = true,
12-
["mistral-large-latest"] = true,
13-
["open-mistral-7b"] = true,
14-
["open-mixtral-8x7b"] = true,
15-
["open-mixtral-8x22b"] = true,
16-
}
17-
187
-- https://docs.mistral.ai/api/#operation/createChatCompletion
198
local available_api_parameters = {
209
-- required
@@ -91,10 +80,6 @@ function Mistral:process_onexit(res)
9180
end
9281
end
9382

94-
function Mistral:check(model)
95-
return available_model_set[model]
96-
end
97-
9883
function Mistral:get_available_models()
9984
return {
10085
"codestral-latest",

lua/parrot/provider/ollama.lua

Lines changed: 17 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -77,57 +77,28 @@ function Ollama:process_onexit(res)
7777
end
7878
end
7979

80-
function Ollama:check(model)
81-
if not self.ollama_installed then
82-
logger.warning("ollama not found.")
83-
return false
84-
end
85-
86-
local handle = io.popen("ollama list")
87-
local result = handle:read("*a")
88-
handle:close()
80+
function Ollama:get_available_models()
81+
-- curl -H "Content-Type: application/json" http://localhost:11434/api/tags
82+
local job = Job:new({
83+
command = "curl",
84+
args = { "-H", "Content-Type: application/json", "http://localhost:11434/api/tags" },
85+
}):sync()
8986

90-
local found_match = false
91-
for line in result:gmatch("[^\r\n]+") do
92-
if string.match(line, model) then
93-
found_match = true
94-
break
95-
end
87+
local parsed_response = utils.parse_raw_response(job)
88+
local success, parsed_data = pcall(vim.json.decode, parsed_response)
89+
if not success then
90+
logger.error("Error parsing JSON:" .. vim.inspect(parsed_data))
91+
return {}
9692
end
97-
98-
if not found_match then
99-
if not pcall(require, "plenary") then
100-
logger.error("Plenary not installed. Please install nvim-lua/plenary.nvim to use this plugin.")
101-
return false
102-
end
103-
local confirm = vim.fn.confirm("ollama model " .. model .. " not found. Download now?", "&Yes\n&No", 1)
104-
if confirm == 1 then
105-
local job = Job:new({
106-
command = "ollama",
107-
args = { "pull", model },
108-
on_exit = function(_, return_val)
109-
logger.info("Download finished with exit code: " .. return_val)
110-
end,
111-
on_stderr = function(j, data)
112-
print("Downloading, please wait: " .. data)
113-
if j ~= nil then
114-
logger.error(vim.inspect(j:result()))
115-
end
116-
end,
117-
})
118-
job:start()
119-
return true
93+
local names = {}
94+
if parsed_data.models then
95+
for _, model in ipairs(parsed_data.models) do
96+
table.insert(names, model.name)
12097
end
12198
else
122-
return true
99+
return { "No model found, please download" }
123100
end
124-
return false
125-
end
126-
127-
function Ollama:get_available_models()
128-
return {
129-
"llama3",
130-
}
101+
return names
131102
end
132103

133104
return Ollama

lua/parrot/provider/openai.lua

Lines changed: 19 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,10 @@
11
local logger = require("parrot.logger")
22
local utils = require("parrot.utils")
3+
local Job = require("plenary.job")
34

45
local OpenAI = {}
56
OpenAI.__index = OpenAI
67

7-
local available_model_set = {
8-
["gpt-3.5-turbo"] = true,
9-
["gpt-3.5-turbo-0125"] = true,
10-
["gpt-3.5-turbo-1106"] = true,
11-
["gpt-3.5-turbo-16k"] = true,
12-
["gpt-3.5-turbo-instruct"] = true,
13-
["gpt-3.5-turbo-instruct-0914"] = true,
14-
["gpt-4"] = true,
15-
["gpt-4-0125-preview"] = true,
16-
["gpt-4-0613"] = true,
17-
["gpt-4-1106-preview"] = true,
18-
["gpt-4-turbo"] = true,
19-
["gpt-4-turbo-2024-04-09"] = true,
20-
["gpt-4-turbo-preview"] = true,
21-
["gpt-4o"] = true,
22-
["gpt-4o-2024-05-13"] = true,
23-
["gpt-4o-mini"] = true,
24-
["gpt-4o-mini-2024-07-18"] = true,
25-
}
26-
278
-- https://platform.openai.com/docs/api-reference/chat/create
289
local available_api_parameters = {
2910
-- required
@@ -110,32 +91,27 @@ function OpenAI:process_onexit(res)
11091
end
11192
end
11293

113-
function OpenAI:check(model)
114-
return available_model_set[model]
115-
end
116-
11794
function OpenAI:get_available_models()
11895
-- curl https://api.openai.com/v1/models -H "Authorization: Bearer $OPENAI_API_KEY"
119-
-- local Job = require("plenary.job")
120-
-- self:verify()
121-
-- Job:new({
122-
-- command = "curl",
123-
-- args = {
124-
-- "https://api.openai.com/v1/models",
125-
-- "-H", "Authorization: Bearer " .. self.api_key,
126-
-- },
127-
-- on_exit = function(job)
128-
-- local parsed_response = utils.parse_raw_response(job:result())
129-
-- print("JSON DEOCDE", vim.inspect(vim.json.decode(parsed_response)))
130-
-- local ids = {}
131-
-- for _, item in ipairs(vim.json.decode(parsed_response).data) do
132-
-- table.insert(ids, item.id)
133-
-- end
134-
-- print("IDS", vim.inspect(ids))
135-
-- return ids
136-
-- end,
137-
-- }):start()
13896

97+
if self:verify() then
98+
Job:new({
99+
command = "curl",
100+
args = {
101+
"https://api.openai.com/v1/models",
102+
"-H",
103+
"Authorization: Bearer " .. self.api_key,
104+
},
105+
on_exit = function(job)
106+
local parsed_response = utils.parse_raw_response(job:result())
107+
local ids = {}
108+
for _, item in ipairs(vim.json.decode(parsed_response).data) do
109+
table.insert(ids, item.id)
110+
end
111+
return ids
112+
end,
113+
}):start()
114+
end
139115
return {
140116
"gpt-3.5-turbo",
141117
"gpt-3.5-turbo-0125",

lua/parrot/provider/perplexity.lua

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,6 @@ local utils = require("parrot.utils")
44
local Perplexity = {}
55
Perplexity.__index = Perplexity
66

7-
local available_model_set = {
8-
-- adjust the new data
9-
["llama-3.1-sonar-small-128k-chat"] = true,
10-
["llama-3.1-sonar-small-128k-online"] = true,
11-
["llama-3.1-sonar-large-128k-chat"] = true,
12-
["llama-3.1-sonar-large-128k-online"] = true,
13-
["llama-3.1-8b-instruct"] = true,
14-
["llama-3.1-70b-instruct"] = true,
15-
["mixtral-8x7b-instruct"] = true,
16-
}
17-
187
-- https://docs.perplexity.ai/reference/post_chat_completions
198
local available_api_parameters = {
209
-- required
@@ -95,10 +84,6 @@ function Perplexity:process_onexit(res)
9584
end
9685
end
9786

98-
function Perplexity:check(model)
99-
return available_model_set[model]
100-
end
101-
10287
function Perplexity:get_available_models()
10388
return {
10489
"llama-3.1-sonar-small-128k-chat",

0 commit comments

Comments (0)