Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions core/config/model_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ type TTSConfig struct {

// ModelConfig represents a model configuration
type ModelConfig struct {
modelConfigFile string `yaml:"-" json:"-"`
schema.PredictionOptions `yaml:"parameters" json:"parameters"`
Name string `yaml:"name" json:"name"`

Expand Down Expand Up @@ -492,6 +493,10 @@ func (c *ModelConfig) HasTemplate() bool {
return c.TemplateConfig.Completion != "" || c.TemplateConfig.Edit != "" || c.TemplateConfig.Chat != "" || c.TemplateConfig.ChatMessage != ""
}

func (c *ModelConfig) GetModelConfigFile() string {
return c.modelConfigFile
}

type ModelConfigUsecases int

const (
Expand Down
3 changes: 3 additions & 0 deletions core/config/model_config_loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ func readMultipleModelConfigsFromFile(file string, opts ...ConfigLoaderOption) (
}

for _, cc := range *c {
cc.modelConfigFile = file
cc.SetDefaults(opts...)
}

Expand All @@ -108,6 +109,8 @@ func readModelConfigFromFile(file string, opts ...ConfigLoaderOption) (*ModelCon
}

c.SetDefaults(opts...)

c.modelConfigFile = file
return c, nil
}

Expand Down
102 changes: 29 additions & 73 deletions core/http/endpoints/localai/edit_model.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
package localai

import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"

"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/config"
Expand Down Expand Up @@ -37,21 +34,19 @@
return c.Status(404).JSON(response)
}

configData, err := yaml.Marshal(modelConfig)
if err != nil {
modelConfigFile := modelConfig.GetModelConfigFile()
if modelConfigFile == "" {
response := ModelResponse{
Success: false,
Error: "Failed to marshal configuration: " + err.Error(),
Error: "Model configuration file not found",
}
return c.Status(500).JSON(response)
return c.Status(404).JSON(response)
}

// Marshal the config to JSON for the template
configJSON, err := json.Marshal(modelConfig)
configData, err := os.ReadFile(modelConfigFile)

Check failure

Code scanning / gosec

Potential file inclusion via variable Error

Potential file inclusion via variable
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

modelConfigFile is not coming from the request, but it's where the config loader module found out the config

if err != nil {
response := ModelResponse{
Success: false,
Error: "Failed to marshal configuration: " + err.Error(),
Error: "Failed to read configuration file: " + err.Error(),
}
return c.Status(500).JSON(response)
}
Expand All @@ -69,7 +64,6 @@
Title: "LocalAI - Edit Model " + modelName,
ModelName: modelName,
Config: &modelConfig,
ConfigJSON: string(configJSON),
ConfigYAML: string(configData),
BaseURL: httpUtils.BaseURL(c),
Version: internal.PrintableVersion(),
Expand All @@ -91,6 +85,15 @@
return c.Status(400).JSON(response)
}

modelConfig, exists := cl.GetModelConfig(modelName)
if !exists {
response := ModelResponse{
Success: false,
Error: "Existing model configuration not found",
}
return c.Status(404).JSON(response)
}

// Get the raw body
body := c.Body()
if len(body) == 0 {
Expand All @@ -101,50 +104,16 @@
return c.Status(400).JSON(response)
}

// Check content type to determine how to parse
contentType := string(c.Context().Request.Header.ContentType())
// Check content to see if it's a valid model config
var req config.ModelConfig
var err error

if strings.Contains(contentType, "application/json") {
// Parse JSON
if err := json.Unmarshal(body, &req); err != nil {
response := ModelResponse{
Success: false,
Error: "Failed to parse JSON: " + err.Error(),
}
return c.Status(400).JSON(response)
}
} else if strings.Contains(contentType, "application/x-yaml") || strings.Contains(contentType, "text/yaml") {
// Parse YAML
if err := yaml.Unmarshal(body, &req); err != nil {
response := ModelResponse{
Success: false,
Error: "Failed to parse YAML: " + err.Error(),
}
return c.Status(400).JSON(response)
}
} else {
// Try to auto-detect format
if strings.TrimSpace(string(body))[0] == '{' {
// Looks like JSON
if err := json.Unmarshal(body, &req); err != nil {
response := ModelResponse{
Success: false,
Error: "Failed to parse JSON: " + err.Error(),
}
return c.Status(400).JSON(response)
}
} else {
// Assume YAML
if err := yaml.Unmarshal(body, &req); err != nil {
response := ModelResponse{
Success: false,
Error: "Failed to parse YAML: " + err.Error(),
}
return c.Status(400).JSON(response)
}
// Parse YAML
if err := yaml.Unmarshal(body, &req); err != nil {
response := ModelResponse{
Success: false,
Error: "Failed to parse YAML: " + err.Error(),
}
return c.Status(400).JSON(response)
}

// Validate required fields
Expand All @@ -156,19 +125,6 @@
return c.Status(400).JSON(response)
}

// Load the existing configuration
configPath := filepath.Join(appConfig.SystemState.Model.ModelsPath, modelName+".yaml")
if err := utils.VerifyPath(modelName+".yaml", appConfig.SystemState.Model.ModelsPath); err != nil {
response := ModelResponse{
Success: false,
Error: "Model configuration not trusted: " + err.Error(),
}
return c.Status(404).JSON(response)
}

// Set defaults
req.SetDefaults()

// Validate the configuration
if !req.Validate() {
response := ModelResponse{
Expand All @@ -179,18 +135,18 @@
return c.Status(400).JSON(response)
}

// Create the YAML file
yamlData, err := yaml.Marshal(req)
if err != nil {
// Load the existing configuration
configPath := modelConfig.GetModelConfigFile()
if err := utils.VerifyPath(configPath, appConfig.SystemState.Model.ModelsPath); err != nil {
response := ModelResponse{
Success: false,
Error: "Failed to marshal configuration: " + err.Error(),
Error: "Model configuration not trusted: " + err.Error(),
}
return c.Status(500).JSON(response)
return c.Status(404).JSON(response)
}

// Write to file
if err := os.WriteFile(configPath, yamlData, 0644); err != nil {
// Write new content to file
if err := os.WriteFile(configPath, body, 0644); err != nil {

Check failure

Code scanning / gosec

Expect WriteFile permissions to be 0600 or less Error

Expect WriteFile permissions to be 0600 or less
response := ModelResponse{
Success: false,
Error: "Failed to write configuration file: " + err.Error(),
Expand Down
Loading
Loading