Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions shinkai-bin/shinkai-node/src/llm_provider/llm_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,22 @@ impl LLMProvider {
)
.await
}
LLMProviderInterface::LMStudio(lmstudio) => {
lmstudio
.call_api(
&self.client,
self.external_url.as_ref(),
self.api_key.as_ref(),
prompt.clone(),
self.model.clone(),
inbox_name,
ws_manager_trait,
merged_config,
llm_stopper,
self.db.clone(),
)
.await
}
LLMProviderInterface::Claude(claude) => {
claude
.call_api(
Expand Down
135 changes: 135 additions & 0 deletions shinkai-bin/shinkai-node/src/llm_provider/providers/lm_studio.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
use std::sync::Arc;

use super::super::error::LLMProviderError;
use super::LLMService;
use super::openai::{
add_options_to_payload, handle_non_streaming_response, handle_streaming_response,
truncate_image_url_in_payload,
};
use crate::llm_provider::execution::chains::inference_chain_trait::LLMInferenceResponse;
use crate::llm_provider::llm_stopper::LLMStopper;
use crate::managers::model_capabilities_manager::{ModelCapabilitiesManager, PromptResultEnum};
use async_trait::async_trait;
use reqwest::Client;
use serde_json::json;
use serde_json::{self};
use shinkai_message_primitives::schemas::inbox_name::InboxName;
use shinkai_message_primitives::schemas::job_config::JobConfig;
use shinkai_message_primitives::schemas::llm_providers::serialized_llm_provider::{
LLMProviderInterface, LMStudio,
};
use shinkai_message_primitives::schemas::prompts::Prompt;
use shinkai_message_primitives::schemas::ws_types::WSUpdateHandler;
use shinkai_message_primitives::shinkai_utils::shinkai_logging::{
shinkai_log, ShinkaiLogLevel, ShinkaiLogOption,
};
use shinkai_sqlite::SqliteManager;
use tokio::sync::Mutex;
use uuid::Uuid;

use super::shared::openai_api::openai_prepare_messages;

#[async_trait]
impl LLMService for LMStudio {
    /// Calls the LM Studio OpenAI-compatible chat-completions endpoint
    /// (`/api/v0/chat/completions` under the configured base URL).
    ///
    /// Builds an OpenAI-style JSON payload from `prompt` via
    /// `openai_prepare_messages`, then dispatches to the shared streaming or
    /// non-streaming response handler depending on the job config
    /// (`stream` defaults to `true` when unset).
    ///
    /// # Errors
    /// - `LLMProviderError::UrlNotSet` when no base URL is configured.
    /// - `LLMProviderError::UnexpectedPromptResultVariant` when message
    ///   preparation does not yield a `PromptResultEnum::Value`.
    async fn call_api(
        &self,
        client: &Client,
        url: Option<&String>,
        api_key: Option<&String>,
        prompt: Prompt,
        model: LLMProviderInterface,
        inbox_name: Option<InboxName>,
        ws_manager_trait: Option<Arc<Mutex<dyn WSUpdateHandler + Send>>>,
        config: Option<JobConfig>,
        llm_stopper: Arc<LLMStopper>,
        _db: Arc<SqliteManager>,
    ) -> Result<LLMInferenceResponse, LLMProviderError> {
        let session_id = Uuid::new_v4().to_string();

        // Fail fast when no endpoint is configured instead of nesting the
        // whole body under an `if let`.
        let base_url = url.ok_or(LLMProviderError::UrlNotSet)?;
        let url = format!("{}{}", base_url, "/api/v0/chat/completions");

        // Streaming is opt-out: enabled unless the job config disables it.
        let is_stream = config.as_ref().and_then(|c| c.stream).unwrap_or(true);

        let result = openai_prepare_messages(&model, prompt)?;
        let messages_json = match result.messages {
            PromptResultEnum::Value(v) => v,
            _ => {
                return Err(LLMProviderError::UnexpectedPromptResultVariant(
                    "Expected Value variant in PromptResultEnum".to_string(),
                ))
            }
        };

        let tools_json = result.functions.unwrap_or_default();

        // Reasoning-capable models take the newer `max_completion_tokens`
        // field; everything else uses the legacy `max_tokens`.
        let has_reasoning = ModelCapabilitiesManager::has_reasoning_capabilities(&model);
        let mut payload = if has_reasoning {
            json!({
                "model": self.model_type,
                "messages": messages_json,
                "max_completion_tokens": result.remaining_output_tokens,
                "stream": is_stream,
            })
        } else {
            json!({
                "model": self.model_type,
                "messages": messages_json,
                "max_tokens": result.remaining_output_tokens,
                "stream": is_stream,
            })
        };

        if !tools_json.is_empty() {
            // Cloned here because the handlers below also take ownership of
            // the tool definitions.
            payload["tools"] = serde_json::Value::Array(tools_json.clone());
        }

        // NOTE(review): sampling options are only merged for non-reasoning
        // models — presumably reasoning models reject them; confirm intended.
        if !has_reasoning {
            add_options_to_payload(&mut payload, config.as_ref());
        }

        // Log the payload with image URLs truncated so base64 blobs don't
        // flood the debug log. (A leftover `eprintln!` payload dump was
        // removed here — this debug log already covers it.)
        let mut payload_log = payload.clone();
        truncate_image_url_in_payload(&mut payload_log);
        shinkai_log(
            ShinkaiLogOption::JobExecution,
            ShinkaiLogLevel::Debug,
            format!("Call API Body: {:?}", payload_log).as_str(),
        );

        // LM Studio runs locally and typically needs no API key; fall back
        // to an empty string without allocating a throwaway temporary.
        let api_key = api_key.cloned().unwrap_or_default();

        if is_stream {
            handle_streaming_response(
                client,
                url,
                payload,
                api_key,
                inbox_name,
                ws_manager_trait,
                llm_stopper,
                session_id,
                Some(tools_json),
                None,
            )
            .await
        } else {
            handle_non_streaming_response(
                client,
                url,
                payload,
                api_key,
                inbox_name,
                llm_stopper,
                ws_manager_trait,
                Some(tools_json),
                None,
            )
            .await
        }
    }
}

1 change: 1 addition & 0 deletions shinkai-bin/shinkai-node/src/llm_provider/providers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ pub mod ollama;
pub mod openai;
pub mod openai_tests;
pub mod openrouter;
pub mod lm_studio;
pub mod shared;
pub mod shinkai_backend;
pub mod togetherai;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ impl ModelCapabilitiesManager {
LLMProviderInterface::Groq(model) => Self::get_shared_capabilities(model.model_type().as_str()),
LLMProviderInterface::Gemini(_) => vec![ModelCapability::TextInference, ModelCapability::ImageAnalysis],
LLMProviderInterface::OpenRouter(model) => Self::get_shared_capabilities(model.model_type().as_str()),
LLMProviderInterface::LMStudio(model) => Self::get_shared_capabilities(model.model_type().as_str()),
LLMProviderInterface::Claude(_) => vec![ModelCapability::ImageAnalysis, ModelCapability::TextInference],
LLMProviderInterface::DeepSeek(_) => vec![ModelCapability::TextInference],
LLMProviderInterface::LocalRegex(_) => vec![ModelCapability::ImageAnalysis, ModelCapability::TextInference],
Expand Down Expand Up @@ -228,6 +229,7 @@ impl ModelCapabilitiesManager {
LLMProviderInterface::Gemini(_) => ModelCost::Cheap,
LLMProviderInterface::Exo(_) => ModelCost::Cheap,
LLMProviderInterface::OpenRouter(_) => ModelCost::Free,
LLMProviderInterface::LMStudio(_) => ModelCost::Free,
LLMProviderInterface::Claude(claude) => match claude.model_type.as_str() {
"claude-3-5-sonnet-20241022" | "claude-3-5-sonnet-latest" => ModelCost::Cheap,
"claude-sonnet-4-20250514" | "claude-sonnet-4-latest" => ModelCost::Cheap,
Expand Down Expand Up @@ -264,6 +266,7 @@ impl ModelCapabilitiesManager {
LLMProviderInterface::Gemini(_) => ModelPrivacy::RemoteGreedy,
LLMProviderInterface::Exo(_) => ModelPrivacy::Local,
LLMProviderInterface::OpenRouter(_) => ModelPrivacy::Local,
LLMProviderInterface::LMStudio(_) => ModelPrivacy::Local,
LLMProviderInterface::Claude(_) => ModelPrivacy::RemoteGreedy,
LLMProviderInterface::DeepSeek(_) => ModelPrivacy::RemoteGreedy,
LLMProviderInterface::LocalRegex(_) => ModelPrivacy::Local,
Expand Down Expand Up @@ -349,6 +352,18 @@ impl ModelCapabilitiesManager {
Ok(messages_string)
}
}
LLMProviderInterface::LMStudio(lmstudio) => {
if Self::get_shared_capabilities(lmstudio.model_type.as_str()).is_empty() {
Err(ModelCapabilitiesManagerError::NotImplemented(
lmstudio.model_type.clone(),
))
} else {
let total_tokens = Self::get_max_tokens(model);
let messages_string =
llama_prepare_messages(model, lmstudio.clone().model_type, prompt, total_tokens)?;
Ok(messages_string)
}
}
LLMProviderInterface::Groq(groq) => {
let total_tokens = Self::get_max_tokens(model);
let messages_string = llama_prepare_messages(model, groq.clone().model_type, prompt, total_tokens)?;
Expand Down Expand Up @@ -435,6 +450,7 @@ impl ModelCapabilitiesManager {
std::cmp::min(Self::get_max_tokens_for_model_type(&groq.model_type), 7000)
}
LLMProviderInterface::OpenRouter(openrouter) => Self::get_max_tokens_for_model_type(&openrouter.model_type),
LLMProviderInterface::LMStudio(lmstudio) => Self::get_max_tokens_for_model_type(&lmstudio.model_type),
LLMProviderInterface::Claude(_) => 200_000,
LLMProviderInterface::DeepSeek(_) => 64_000,
LLMProviderInterface::LocalRegex(_) => 128_000,
Expand Down Expand Up @@ -590,6 +606,13 @@ impl ModelCapabilitiesManager {
4096
}
}
LLMProviderInterface::LMStudio(_) => {
if Self::get_max_tokens(model) <= 8000 {
2800
} else {
4096
}
}
LLMProviderInterface::Claude(claude) => {
if claude.model_type.starts_with("claude-opus-4") {
32_000
Expand Down Expand Up @@ -759,6 +782,16 @@ impl ModelCapabilitiesManager {
|| model.model_type.starts_with("mistral-large")
|| model.model_type.starts_with("mistral-pixtral")
}
LLMProviderInterface::LMStudio(model) => {
model.model_type.starts_with("llama-3.2")
|| model.model_type.starts_with("llama3.2")
|| model.model_type.starts_with("llama-3.1")
|| model.model_type.starts_with("llama3.1")
|| model.model_type.starts_with("mistral-nemo")
|| model.model_type.starts_with("mistral-small")
|| model.model_type.starts_with("mistral-large")
|| model.model_type.starts_with("mistral-pixtral")
}
LLMProviderInterface::Claude(claude) => {
claude.model_type.starts_with("claude-sonnet-4")
|| claude.model_type.starts_with("claude-opus-4")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ impl SerializedLLMProvider {
LLMProviderInterface::Gemini(_) => "gemini",
LLMProviderInterface::Exo(_) => "exo",
LLMProviderInterface::OpenRouter(_) => "openrouter",
LLMProviderInterface::LMStudio(_) => "lm_studio",
LLMProviderInterface::Claude(_) => "claude",
LLMProviderInterface::DeepSeek(_) => "deepseek",
LLMProviderInterface::LocalRegex(_) => "local-regex",
Expand All @@ -45,6 +46,7 @@ impl SerializedLLMProvider {
LLMProviderInterface::Gemini(_) => "google-ai".to_string(),
LLMProviderInterface::Exo(_) => "openai-generic".to_string(),
LLMProviderInterface::OpenRouter(_) => "openai-generic".to_string(),
LLMProviderInterface::LMStudio(_) => "openai-generic".to_string(),
LLMProviderInterface::Claude(_) => "claude".to_string(),
LLMProviderInterface::DeepSeek(_) => "openai-generic".to_string(),
LLMProviderInterface::LocalRegex(_) => "local-regex".to_string(),
Expand All @@ -61,6 +63,7 @@ impl SerializedLLMProvider {
LLMProviderInterface::Gemini(gemini) => gemini.model_type.clone(),
LLMProviderInterface::Exo(exo) => exo.model_type.clone(),
LLMProviderInterface::OpenRouter(openrouter) => openrouter.model_type.clone(),
LLMProviderInterface::LMStudio(lmstudio) => lmstudio.model_type.clone(),
LLMProviderInterface::Claude(claude) => claude.model_type.clone(),
LLMProviderInterface::DeepSeek(deepseek) => deepseek.model_type.clone(),
LLMProviderInterface::LocalRegex(local_regex) => local_regex.model_type.clone(),
Expand Down Expand Up @@ -122,6 +125,7 @@ pub enum LLMProviderInterface {
Gemini(Gemini),
Exo(Exo),
OpenRouter(OpenRouter),
LMStudio(LMStudio),
Claude(Claude),
DeepSeek(DeepSeek),
LocalRegex(LocalRegex),
Expand Down Expand Up @@ -208,6 +212,17 @@ impl OpenRouter {
}
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, ToSchema)]
pub struct LMStudio {
    /// Identifier of the model loaded in LM Studio (e.g. a "llama-3.2" or
    /// "mistral-nemo" variant).
    pub model_type: String,
}

impl LMStudio {
    /// Returns an owned copy of the configured model identifier, matching
    /// the accessor shape of the other provider structs.
    pub fn model_type(&self) -> String {
        self.model_type.clone()
    }
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct TogetherAI {
pub model_type: String,
Expand Down Expand Up @@ -263,6 +278,9 @@ impl FromStr for LLMProviderInterface {
} else if s.starts_with("openrouter:") {
let model_type = s.strip_prefix("openrouter:").unwrap_or("").to_string();
Ok(LLMProviderInterface::OpenRouter(OpenRouter { model_type }))
} else if s.starts_with("lm_studio:") {
let model_type = s.strip_prefix("lm_studio:").unwrap_or("").to_string();
Ok(LLMProviderInterface::LMStudio(LMStudio { model_type }))
} else if s.starts_with("claude:") {
let model_type = s.strip_prefix("claude:").unwrap_or("").to_string();
Ok(LLMProviderInterface::Claude(Claude { model_type }))
Expand Down Expand Up @@ -316,6 +334,10 @@ impl Serialize for LLMProviderInterface {
let model_type = format!("openrouter:{}", openrouter.model_type);
serializer.serialize_str(&model_type)
}
LLMProviderInterface::LMStudio(lmstudio) => {
let model_type = format!("lm_studio:{}", lmstudio.model_type);
serializer.serialize_str(&model_type)
}
LLMProviderInterface::Claude(claude) => {
let model_type = format!("claude:{}", claude.model_type);
serializer.serialize_str(&model_type)
Expand Down Expand Up @@ -371,6 +393,9 @@ impl<'de> Visitor<'de> for LLMProviderInterfaceVisitor {
"openrouter" => Ok(LLMProviderInterface::OpenRouter(OpenRouter {
model_type: parts.get(1).unwrap_or(&"").to_string(),
})),
"lm_studio" => Ok(LLMProviderInterface::LMStudio(LMStudio {
model_type: parts.get(1).unwrap_or(&"").to_string(),
})),
"claude" => Ok(LLMProviderInterface::Claude(Claude {
model_type: parts.get(1).unwrap_or(&"").to_string(),
})),
Expand All @@ -392,6 +417,7 @@ impl<'de> Visitor<'de> for LLMProviderInterfaceVisitor {
"exo",
"gemini",
"openrouter",
"lm_studio",
"claude",
"deepseek",
"local-regex",
Expand Down
Loading