diff --git a/.opencode.json b/.opencode.json
index c4d1547a..308d1623 100644
--- a/.opencode.json
+++ b/.opencode.json
@@ -4,5 +4,12 @@
     "gopls": {
       "command": "gopls"
     }
+  },
+  "mcpServers": {
+    "spec-server": {
+      "command": "spec-server",
+      "args": ["stdio"],
+      "disabled": false
+    }
   }
 }
diff --git a/README.md b/README.md
index eee06acd..67beabff 100644
--- a/README.md
+++ b/README.md
@@ -524,6 +524,34 @@ OpenCode includes several built-in commands:
 OpenCode implements the Model Context Protocol (MCP) to extend its capabilities through external tools. MCP provides a standardized way for the AI assistant to interact with external services and tools.
 
+### Spec Driven Development
+
+OpenCode supports Spec Driven Development through the `spec-server`, an MCP server that guides you through a three-phase workflow:
+
+1. **Requirements:** Define user stories and acceptance criteria.
+2. **Design:** Create a technical design document.
+3. **Tasks:** Generate actionable implementation tasks.
+
+To use the `spec-server`, first install it:
+
+```bash
+pip install spec-server
+```
+
+Then add it to your `.opencode.json` configuration file:
+
+```json
+{
+  "mcpServers": {
+    "spec-server": {
+      "command": "spec-server",
+      "args": ["stdio"],
+      "disabled": false
+    }
+  }
+}
+```
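+
+If the installation succeeded, the `spec-server` binary should now be on your `PATH`. As a quick sanity check, you can start it by hand with the same `stdio` argument OpenCode passes, then stop it with Ctrl+C:
+
+```bash
+spec-server stdio
+```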
+
 ### MCP Features
 
 - **External Tool Integration**: Connect to external tools and services via a standardized protocol
@@ -626,11 +654,16 @@ This is useful for developers who want to experiment with custom models.
 
 ### Configuring a self-hosted provider
 
-You can use a self-hosted model by setting the `LOCAL_ENDPOINT` environment variable.
+You can use a self-hosted model by setting one of the following environment variables:
+
+- `OLLAMA_ENDPOINT`: For Ollama models
+- `LMSTUDIO_ENDPOINT`: For LMStudio models
+- `LOCAL_ENDPOINT`: For other local models
+
 This will cause OpenCode to load and use the models from the specified endpoint.
 
 ```bash
-LOCAL_ENDPOINT=http://localhost:1235/v1
+OLLAMA_ENDPOINT=http://localhost:11434
 ```
 
 ### Configuring a self-hosted model
@@ -641,7 +674,7 @@ You can also configure a self-hosted model in the configuration file under the `agents` key:
 {
   "agents": {
     "coder": {
-      "model": "local.granite-3.3-2b-instruct@q8_0",
+      "model": "local/Ollama/llama2",
       "reasoningEffort": "high"
     }
   }
 }
diff --git a/cmd/config.go b/cmd/config.go
new file mode 100644
index 00000000..54053b00
--- /dev/null
+++ b/cmd/config.go
@@ -0,0 +1,34 @@
+package cmd
+
+import (
+	"fmt"
+
+	"github.com/opencode-ai/opencode/internal/config"
+	"github.com/opencode-ai/opencode/internal/llm/models"
+	"github.com/spf13/cobra"
+)
+
+var configCmd = &cobra.Command{
+	Use:   "config",
+	Short: "Configure opencode CLI",
+	RunE: func(cmd *cobra.Command, args []string) error {
+		return configure(cmd, args)
+	},
+}
+
+func configure(cmd *cobra.Command, args []string) error {
+	key, _ := cmd.Flags().GetString("key")
+	value, _ := cmd.Flags().GetString("value")
+
+	if key == "" || value == "" {
+		return fmt.Errorf("key and value are required")
+	}
+
+	return config.UpdateProviderAPIKey(models.ModelProvider(key), value)
+}
+
+func init() {
+	configCmd.Flags().StringP("key", "k", "", "Configuration key")
+	configCmd.Flags().StringP("value", "v", "", "Configuration value")
+	rootCmd.AddCommand(configCmd)
+}
diff --git a/cmd/config_test.go b/cmd/config_test.go
new file mode 100644
index 00000000..db91442f
--- /dev/null
+++ b/cmd/config_test.go
@@ -0,0 +1,21 @@
+package cmd
+
+import (
+	"testing"
+
+	"github.com/opencode-ai/opencode/internal/config"
+	"github.com/opencode-ai/opencode/internal/llm/models"
+	"github.com/stretchr/testify/require"
+)
+
+func TestConfigCmd(t *testing.T) {
+	_, err := config.Load(".", false)
+	require.NoError(t, err)
+
+	rootCmd.SetArgs([]string{"config", "--key", "openai", "--value", "test-key"})
+	err = rootCmd.Execute()
+	require.NoError(t, err)
+
+	cfg := config.Get()
+	require.Equal(t, "test-key", cfg.Providers[models.ProviderOpenAI].APIKey)
+}
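For reference, the new `config` subcommand persists a provider API key via `config.UpdateProviderAPIKey`; a typical invocation looks like this (the `openai` key mirrors the test above, and the value is a placeholder):

```bash
# long flags, as registered in init()
opencode config --key openai --value <your-api-key>
# equivalent short flags
opencode config -k openai -v <your-api-key>
```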
diff --git a/cmd/root.go b/cmd/root.go
index 3a58cec4..58a81f09 100644
--- a/cmd/root.go
+++ b/cmd/root.go
@@ -12,6 +12,7 @@ import (
 	"github.com/opencode-ai/opencode/internal/app"
 	"github.com/opencode-ai/opencode/internal/config"
 	"github.com/opencode-ai/opencode/internal/db"
+	customerrors "github.com/opencode-ai/opencode/internal/errors"
 	"github.com/opencode-ai/opencode/internal/format"
 	"github.com/opencode-ai/opencode/internal/llm/agent"
 	"github.com/opencode-ai/opencode/internal/logging"
@@ -284,6 +285,13 @@ func setupSubscriptions(app *app.App, parentCtx context.Context) (chan tea.Msg,
 func Execute() {
 	err := rootCmd.Execute()
 	if err != nil {
+		// If the error is a configuration error, we want to print a more helpful message
+		if e, ok := err.(*customerrors.Error); ok {
+			if e.Code == customerrors.ErrNotFound {
+				fmt.Println("No valid provider available. Please configure a provider using 'opencode config'")
+				os.Exit(1)
+			}
+		}
 		os.Exit(1)
 	}
 }
diff --git a/internal/config/config.go b/internal/config/config.go
index 630fac9b..7f9bdf8f 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -10,6 +10,7 @@ import (
 	"runtime"
 	"strings"
 
+	customerrors "github.com/opencode-ai/opencode/internal/errors"
 	"github.com/opencode-ai/opencode/internal/llm/models"
 	"github.com/opencode-ai/opencode/internal/logging"
 	"github.com/spf13/viper"
@@ -277,6 +278,7 @@ func setProviderDefaults() {
 		// api-key may be empty when using Entra ID credentials – that's okay
 		viper.SetDefault("providers.azure.apiKey", os.Getenv("AZURE_OPENAI_API_KEY"))
 	}
+	viper.SetDefault("providers.local.endpoint", "http://localhost:11434")
 	if apiKey, err := LoadGitHubToken(); err == nil && apiKey != "" {
 		viper.SetDefault("providers.copilot.apiKey", apiKey)
 		if viper.GetString("providers.copilot.apiKey") == "" {
@@ -487,7 +489,7 @@ func validateAgent(cfg *Config, name AgentName, agent Agent) error {
 		if setDefaultModelForAgent(name) {
 			logging.Info("set default model for agent", "agent", name, "model", cfg.Agents[name].Model)
 		} else {
-			return fmt.Errorf("no valid provider available for agent %s", name)
+			return customerrors.Newf(customerrors.ErrNotFound, "no valid provider available for agent %s", name)
 		}
 		return nil
 	}
@@ -509,7 +511,7 @@ func validateAgent(cfg *Config, name AgentName, agent Agent) error {
 		if setDefaultModelForAgent(name) {
 			logging.Info("set default model for agent", "agent", name, "model", cfg.Agents[name].Model)
 		} else {
-			return fmt.Errorf("no valid provider available for agent %s", name)
+			return customerrors.Newf(customerrors.ErrNotFound, "no valid provider available for agent %s", name)
 		}
 	} else {
 		// Add provider with API key from environment
@@ -529,7 +531,7 @@ func validateAgent(cfg *Config, name AgentName, agent Agent) error {
 		if setDefaultModelForAgent(name) {
 			logging.Info("set default model for agent", "agent", name, "model", cfg.Agents[name].Model)
 		} else {
-			return fmt.Errorf("no valid provider available for agent %s", name)
+			return customerrors.Newf(customerrors.ErrNotFound, "no valid provider available for agent %s", name)
 		}
 	}
 
@@ -929,6 +931,35 @@ func UpdateTheme(themeName string) error {
 	})
 }
 
+func UpdateProviderAPIKey(provider models.ModelProvider, apiKey string) error {
+	if cfg == nil {
+		return fmt.Errorf("config not loaded")
+	}
+
+	if cfg.Providers == nil {
+		cfg.Providers = make(map[models.ModelProvider]Provider)
+	}
+
+	providerCfg, ok := cfg.Providers[provider]
+	if !ok {
+		providerCfg = Provider{}
+	}
+	providerCfg.APIKey = apiKey
+	cfg.Providers[provider] = providerCfg
+
+	return updateCfgFile(func(config *Config) {
+		if config.Providers == nil {
+			config.Providers = make(map[models.ModelProvider]Provider)
+		}
+		providerCfg, ok := config.Providers[provider]
+		if !ok {
+			providerCfg = Provider{}
+		}
+		providerCfg.APIKey = apiKey
+		config.Providers[provider] = providerCfg
+	})
+}
+
 // Tries to load Github token from all possible locations
 func LoadGitHubToken() (string, error) {
 	// First check environment variable
diff --git a/internal/errors/errors.go b/internal/errors/errors.go
new file mode 100644
index 00000000..b4b0209d
--- /dev/null
+++ b/internal/errors/errors.go
@@ -0,0 +1,46 @@
+package errors
+
+import "fmt"
+
+// Error is a custom error type that contains a code and a message.
+type Error struct {
+	Code    int
+	Message string
+}
+
+// Error returns the error message.
+func (e *Error) Error() string {
+	return e.Message
+}
+
+// New creates a new error.
+func New(code int, message string) *Error {
+	return &Error{
+		Code:    code,
+		Message: message,
+	}
+}
+
+// Newf creates a new error with a formatted message.
+func Newf(code int, format string, a ...interface{}) *Error {
+	return &Error{
+		Code:    code,
+		Message: fmt.Sprintf(format, a...),
+	}
+}
+
+// Error codes
+const (
+	// ErrUnknown is an unknown error.
+	ErrUnknown = iota
+	// ErrNotFound is a not found error.
+	ErrNotFound
+	// ErrForbidden is a forbidden error.
+	ErrForbidden
+	// ErrBadRequest is a bad request error.
+	ErrBadRequest
+	// ErrUnauthorized is an unauthorized error.
+	ErrUnauthorized
+	// ErrInternal is an internal error.
+	ErrInternal
+)
diff --git a/internal/errors/errors_test.go b/internal/errors/errors_test.go
new file mode 100644
index 00000000..e8460728
--- /dev/null
+++ b/internal/errors/errors_test.go
@@ -0,0 +1,17 @@
+package errors
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestErrors(t *testing.T) {
+	err := New(ErrNotFound, "not found")
+	require.Equal(t, "not found", err.Error())
+	require.Equal(t, ErrNotFound, err.Code)
+
+	err = Newf(ErrBadRequest, "bad request %d", 400)
+	require.Equal(t, "bad request 400", err.Error())
+	require.Equal(t, ErrBadRequest, err.Code)
+}
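A minimal sketch of how a caller might branch on the new error codes; it mirrors the type assertion in `cmd/root.go`, but uses `errors.As` so wrapped errors are matched too (the `main` wrapper is illustrative only):

```go
package main

import (
	"errors"
	"fmt"

	customerrors "github.com/opencode-ai/opencode/internal/errors"
)

func main() {
	err := customerrors.Newf(customerrors.ErrNotFound, "no valid provider available for agent %s", "coder")

	// errors.As matches *customerrors.Error even through wrapped chains.
	var e *customerrors.Error
	if errors.As(err, &e) && e.Code == customerrors.ErrNotFound {
		fmt.Println("run 'opencode config' to add a provider")
	}
}
```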
diff --git a/internal/llm/models/cohere.go b/internal/llm/models/cohere.go
new file mode 100644
index 00000000..0bf6dfa5
--- /dev/null
+++ b/internal/llm/models/cohere.go
@@ -0,0 +1,26 @@
+package models
+
+const (
+	ProviderCohere ModelProvider = "cohere"
+)
+
+const (
+	CohereCommandRPlus ModelID = "cohere-command-r-plus"
+)
+
+var CohereModels = map[ModelID]Model{
+	CohereCommandRPlus: {
+		ID:                  CohereCommandRPlus,
+		Name:                "Cohere: Command R+",
+		Provider:            ProviderCohere,
+		APIModel:            "command-r-plus",
+		CostPer1MIn:         0,
+		CostPer1MOut:        0,
+		CostPer1MInCached:   0,
+		CostPer1MOutCached:  0,
+		ContextWindow:       128000,
+		DefaultMaxTokens:    4096,
+		CanReason:           true,
+		SupportsAttachments: true,
+	},
+}
diff --git a/internal/llm/models/huggingface.go b/internal/llm/models/huggingface.go
new file mode 100644
index 00000000..eb36f361
--- /dev/null
+++ b/internal/llm/models/huggingface.go
@@ -0,0 +1,26 @@
+package models
+
+const (
+	ProviderHuggingFace ModelProvider = "huggingface"
+)
+
+const (
+	HuggingFaceMistral7BInstruct ModelID = "huggingface-mistral-7b-instruct"
+)
+
+var HuggingFaceModels = map[ModelID]Model{
+	HuggingFaceMistral7BInstruct: {
+		ID:                  HuggingFaceMistral7BInstruct,
+		Name:                "Hugging Face: Mistral 7B Instruct",
+		Provider:            ProviderHuggingFace,
+		APIModel:            "mistralai/Mistral-7B-Instruct-v0.1",
+		CostPer1MIn:         0,
+		CostPer1MOut:        0,
+		CostPer1MInCached:   0,
+		CostPer1MOutCached:  0,
+		ContextWindow:       4096,
+		DefaultMaxTokens:    2048,
+		CanReason:           true,
+		SupportsAttachments: true,
+	},
+}
diff --git a/internal/llm/models/local.go b/internal/llm/models/local.go
index db0ea11c..58dfebe1 100644
--- a/internal/llm/models/local.go
+++ b/internal/llm/models/local.go
@@ -16,44 +16,63 @@ import (
 const (
 	ProviderLocal ModelProvider = "local"
-
-	localModelsPath        = "v1/models"
-	lmStudioBetaModelsPath = "api/v0/models"
 )
 
-func init() {
-	if endpoint := os.Getenv("LOCAL_ENDPOINT"); endpoint != "" {
-		localEndpoint, err := url.Parse(endpoint)
-		if err != nil {
-			logging.Debug("Failed to parse local endpoint",
-				"error", err,
-				"endpoint", endpoint,
-			)
-			return
-		}
+var localProviders = []struct {
+	name       string
+	envVar     string
+	modelsPath string
+	isBeta     bool
+}{
+	{
+		name:       "Ollama",
+		envVar:     "OLLAMA_ENDPOINT",
+		modelsPath: "v1/models",
+	},
+	{
+		name:       "LMStudio",
+		envVar:     "LMSTUDIO_ENDPOINT",
+		modelsPath: "api/v0/models",
+		isBeta:     true,
+	},
+	{
+		name:       "Local",
+		envVar:     "LOCAL_ENDPOINT",
+		modelsPath: "v1/models",
+	},
+}
 
-		load := func(url *url.URL, path string) []localModel {
-			url.Path = path
-			return listLocalModels(url.String())
-		}
+func init() {
+	for _, provider := range localProviders {
+		if endpoint := os.Getenv(provider.envVar); endpoint != "" {
+			localEndpoint, err := url.Parse(endpoint)
+			if err != nil {
+				logging.Debug("Failed to parse local endpoint",
+					"error", err,
+					"endpoint", endpoint,
+				)
+				continue
+			}
 
-		models := load(localEndpoint, lmStudioBetaModelsPath)
+			load := func(url *url.URL, path string) []localModel {
+				url.Path = path
+				return listLocalModels(url.String(), provider.isBeta)
+			}
 
-		if len(models) == 0 {
-			models = load(localEndpoint, localModelsPath)
-		}
+			models := load(localEndpoint, provider.modelsPath)
 
-		if len(models) == 0 {
-			logging.Debug("No local models found",
-				"endpoint", endpoint,
-			)
-			return
-		}
+			if len(models) == 0 {
+				logging.Debug("No local models found",
+					"endpoint", endpoint,
+				)
+				continue
+			}
 
-		loadLocalModels(models)
+			loadLocalModels(models, provider.name)
 
-		viper.SetDefault("providers.local.apiKey", "dummy")
-		ProviderPopularity[ProviderLocal] = 0
+			viper.SetDefault("providers.local.apiKey", "dummy")
+			ProviderPopularity[ProviderLocal] = 0
+		}
 	}
 }
@@ -74,7 +93,7 @@ type localModel struct {
 	LoadedContextLength int64  `json:"loaded_context_length"`
 }
 
-func listLocalModels(modelsEndpoint string) []localModel {
+func listLocalModels(modelsEndpoint string, isBeta bool) []localModel {
 	res, err := http.Get(modelsEndpoint)
 	if err != nil {
 		logging.Debug("Failed to list local models",
@@ -104,9 +123,9 @@ func listLocalModels(modelsEndpoint string) []localModel {
 	var supportedModels []localModel
 
 	for _, model := range modelList.Data {
-		if strings.HasSuffix(modelsEndpoint, lmStudioBetaModelsPath) {
+		if isBeta {
 			if model.Object != "model" || model.Type != "llm" {
-				logging.Debug("Skipping unsupported LMStudio model",
+				logging.Debug("Skipping unsupported model",
 					"endpoint", modelsEndpoint,
 					"id", model.ID,
 					"object", model.Object,
@@ -123,9 +142,9 @@ func listLocalModels(modelsEndpoint string) []localModel {
 	return supportedModels
 }
 
-func loadLocalModels(models []localModel) {
+func loadLocalModels(models []localModel, providerName string) {
 	for i, m := range models {
-		model := convertLocalModel(m)
+		model := convertLocalModel(m, providerName)
 		SupportedModels[model.ID] = model
 
 		if i == 0 || m.State == "loaded" {
@@ -137,9 +156,9 @@ func loadLocalModels(models []localModel) {
 	}
 }
 
-func convertLocalModel(model localModel) Model {
+func convertLocalModel(model localModel, providerName string) Model {
 	return Model{
-		ID:       ModelID("local." + model.ID),
+		ID:       ModelID("local/" + providerName + "/" + model.ID),
 		Name:     friendlyModelName(model.ID),
 		Provider: ProviderLocal,
 		APIModel: model.ID,
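With the table-driven loop above, each endpoint is scanned independently, so several local providers can be registered in one run; discovered models get IDs of the form `local/<ProviderName>/<model>`, e.g. `local/Ollama/llama3`. A typical setup might look like this (the ports shown are each tool's usual defaults):

```bash
# scan both an Ollama and an LM Studio instance at startup
OLLAMA_ENDPOINT=http://localhost:11434 \
LMSTUDIO_ENDPOINT=http://localhost:1234 \
opencode
```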
diff --git a/internal/llm/models/models.go b/internal/llm/models/models.go
index 2bcb508e..8c1d6ca2 100644
--- a/internal/llm/models/models.go
+++ b/internal/llm/models/models.go
@@ -36,15 +36,19 @@ const (
 // Providers in order of popularity
 var ProviderPopularity = map[ModelProvider]int{
-	ProviderCopilot:    1,
-	ProviderAnthropic:  2,
-	ProviderOpenAI:     3,
-	ProviderGemini:     4,
-	ProviderGROQ:       5,
-	ProviderOpenRouter: 6,
-	ProviderBedrock:    7,
-	ProviderAzure:      8,
-	ProviderVertexAI:   9,
+	ProviderOllama:      1,
+	ProviderOpenRouter:  2,
+	ProviderGemini:      3,
+	ProviderCopilot:     4,
+	ProviderAnthropic:   5,
+	ProviderOpenAI:      6,
+	ProviderGROQ:        7,
+	ProviderBedrock:     8,
+	ProviderAzure:       9,
+	ProviderVertexAI:    10,
+	ProviderHuggingFace: 11,
+	ProviderReplicate:   12,
+	ProviderCohere:      13,
 }
 
 var SupportedModels = map[ModelID]Model{
@@ -95,4 +99,8 @@ func init() {
 	maps.Copy(SupportedModels, XAIModels)
 	maps.Copy(SupportedModels, VertexAIGeminiModels)
 	maps.Copy(SupportedModels, CopilotModels)
+	maps.Copy(SupportedModels, OllamaModels)
+	maps.Copy(SupportedModels, HuggingFaceModels)
+	maps.Copy(SupportedModels, ReplicateModels)
+	maps.Copy(SupportedModels, CohereModels)
 }
diff --git a/internal/llm/models/ollama.go b/internal/llm/models/ollama.go
new file mode 100644
index 00000000..3111eb1d
--- /dev/null
+++ b/internal/llm/models/ollama.go
@@ -0,0 +1,26 @@
+package models
+
+const (
+	ProviderOllama ModelProvider = "ollama"
+)
+
+const (
+	OllamaLlama3 ModelID = "ollama-llama3"
+)
+
+var OllamaModels = map[ModelID]Model{
+	OllamaLlama3: {
+		ID:                  OllamaLlama3,
+		Name:                "Ollama: Llama 3",
+		Provider:            ProviderOllama,
+		APIModel:            "llama3",
+		CostPer1MIn:         0,
+		CostPer1MOut:        0,
+		CostPer1MInCached:   0,
+		CostPer1MOutCached:  0,
+		ContextWindow:       8192,
+		DefaultMaxTokens:    4096,
+		CanReason:           true,
+		SupportsAttachments: true,
+	},
+}
diff --git a/internal/llm/models/replicate.go b/internal/llm/models/replicate.go
new file mode 100644
index 00000000..8149a1cf
--- /dev/null
+++ b/internal/llm/models/replicate.go
@@ -0,0 +1,26 @@
+package models
+
+const (
+	ProviderReplicate ModelProvider = "replicate"
+)
+
+const (
+	ReplicateLlama270BChat ModelID = "replicate-llama-2-70b-chat"
+)
+
+var ReplicateModels = map[ModelID]Model{
+	ReplicateLlama270BChat: {
+		ID:                  ReplicateLlama270BChat,
+		Name:                "Replicate: Llama 2 70B Chat",
+		Provider:            ProviderReplicate,
+		APIModel:            "meta/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96e2e",
+		CostPer1MIn:         0,
+		CostPer1MOut:        0,
+		CostPer1MInCached:   0,
+		CostPer1MOutCached:  0,
+		ContextWindow:       4096,
+		DefaultMaxTokens:    2048,
+		CanReason:           true,
+		SupportsAttachments: true,
+	},
+}
diff --git a/internal/llm/provider/cohere.go b/internal/llm/provider/cohere.go
new file mode 100644
index 00000000..05a5f483
--- /dev/null
+++ b/internal/llm/provider/cohere.go
@@ -0,0 +1,18 @@
+package provider
+
+type CohereClient ProviderClient
+
+func newCohereClient(opts providerClientOptions) OpenAIClient {
+	opts.openaiOptions = append(opts.openaiOptions,
+		WithOpenAIBaseURL("https://api.cohere.ai/v1"),
+	)
+	return newOpenAIClient(opts)
+}
+
+func WithCohereOptions(cohereOptions ...OpenAIOption) ProviderClientOption {
+	return func(options *providerClientOptions) {
+		options.openaiOptions = cohereOptions
+	}
+}
diff --git a/internal/llm/provider/cohere_test.go b/internal/llm/provider/cohere_test.go
new file mode 100644
index 00000000..1d95b506
--- /dev/null
+++ b/internal/llm/provider/cohere_test.go
@@ -0,0 +1,26 @@
+package provider
+
+import (
+	"context"
+	"os"
+	"testing"
+
+	"github.com/opencode-ai/opencode/internal/llm/models"
+	"github.com/stretchr/testify/require"
+)
+
+func TestCohereProvider(t *testing.T) {
+	if os.Getenv("COHERE_API_KEY") == "" {
+		t.Skip("COHERE_API_KEY not set, skipping test")
+	}
+
+	provider, err := NewProvider(
+		models.ProviderCohere,
+		WithModel(models.SupportedModels[models.CohereCommandRPlus]),
+		WithAPIKey(os.Getenv("COHERE_API_KEY")),
+	)
+	require.NoError(t, err)
+
+	_, err = provider.SendMessages(context.Background(), nil, nil)
+	require.NoError(t, err)
+}
diff --git a/internal/llm/provider/huggingface.go b/internal/llm/provider/huggingface.go
new file mode 100644
index 00000000..ce1af829
--- /dev/null
+++ b/internal/llm/provider/huggingface.go
@@ -0,0 +1,18 @@
+package provider
+
+type HuggingFaceClient ProviderClient
+
+func newHuggingFaceClient(opts providerClientOptions) OpenAIClient {
+	opts.openaiOptions = append(opts.openaiOptions,
+		WithOpenAIBaseURL("https://api-inference.huggingface.co/v1"),
+	)
+	return newOpenAIClient(opts)
+}
+
+func WithHuggingFaceOptions(huggingfaceOptions ...OpenAIOption) ProviderClientOption {
+	return func(options *providerClientOptions) {
+		options.openaiOptions = huggingfaceOptions
+	}
+}
diff --git a/internal/llm/provider/huggingface_test.go b/internal/llm/provider/huggingface_test.go
new file mode 100644
index 00000000..3057ddab
--- /dev/null
+++ b/internal/llm/provider/huggingface_test.go
@@ -0,0 +1,26 @@
+package provider
+
+import (
+	"context"
+	"os"
+	"testing"
+
+	"github.com/opencode-ai/opencode/internal/llm/models"
+	"github.com/stretchr/testify/require"
+)
+
+func TestHuggingFaceProvider(t *testing.T) {
+	if os.Getenv("HUGGINGFACE_API_KEY") == "" {
+		t.Skip("HUGGINGFACE_API_KEY not set, skipping test")
+	}
+
+	provider, err := NewProvider(
+		models.ProviderHuggingFace,
+		WithModel(models.SupportedModels[models.HuggingFaceMistral7BInstruct]),
+		WithAPIKey(os.Getenv("HUGGINGFACE_API_KEY")),
+	)
+	require.NoError(t, err)
+
+	_, err = provider.SendMessages(context.Background(), nil, nil)
+	require.NoError(t, err)
+}
diff --git a/internal/llm/provider/mock.go b/internal/llm/provider/mock.go
new file mode 100644
index 00000000..b0b897e4
--- /dev/null
+++ b/internal/llm/provider/mock.go
@@ -0,0 +1,31 @@
+package provider
+
+import (
+	"context"
+
+	"github.com/opencode-ai/opencode/internal/llm/tools"
+	"github.com/opencode-ai/opencode/internal/message"
+)
+
+type MockClient struct {
+	SendMessagesFunc   func(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
+	StreamResponseFunc func(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
+}
+
+func (m *MockClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
+	if m.SendMessagesFunc != nil {
+		return m.SendMessagesFunc(ctx, messages, tools)
+	}
+	return nil, nil
+}
+
+func (m *MockClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
+	if m.StreamResponseFunc != nil {
+		return m.StreamResponseFunc(ctx, messages, tools)
+	}
+	return nil
+}
+
+func newMockClient(opts providerClientOptions) *MockClient {
+	return opts.mockClient
+}
diff --git a/internal/llm/provider/mock_test.go b/internal/llm/provider/mock_test.go
new file mode 100644
index 00000000..3168a977
--- /dev/null
+++ b/internal/llm/provider/mock_test.go
@@ -0,0 +1,31 @@
+package provider
+
+import (
+	"context"
+	"testing"
+
+	"github.com/opencode-ai/opencode/internal/llm/models"
+	"github.com/opencode-ai/opencode/internal/llm/tools"
+	"github.com/opencode-ai/opencode/internal/message"
+	"github.com/stretchr/testify/require"
+)
+
+func TestMockProvider(t *testing.T) {
+	mockClient := &MockClient{
+		SendMessagesFunc: func(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
+			return &ProviderResponse{
+				Content: "Hello, world!",
+			}, nil
+		},
+	}
+
+	provider, err := NewProvider(
+		models.ProviderMock,
+		WithMockClient(mockClient),
+	)
+	require.NoError(t, err)
+
+	resp, err := provider.SendMessages(context.Background(), nil, nil)
+	require.NoError(t, err)
+	require.Equal(t, "Hello, world!", resp.Content)
+}
diff --git a/internal/llm/provider/ollama.go b/internal/llm/provider/ollama.go
new file mode 100644
index 00000000..4efdec7c
--- /dev/null
+++ b/internal/llm/provider/ollama.go
@@ -0,0 +1,26 @@
+package provider
+
+import (
+	"os"
+	"strings"
+)
+
+type OllamaClient ProviderClient
+
+func newOllamaClient(opts providerClientOptions) OpenAIClient {
+	endpoint := os.Getenv("OLLAMA_ENDPOINT")
+	if endpoint == "" {
+		endpoint = "http://localhost:11434"
+	}
+	// Ollama serves its OpenAI-compatible API under the /v1 prefix
+	endpoint = strings.TrimSuffix(endpoint, "/") + "/v1"
+
+	opts.openaiOptions = append(opts.openaiOptions,
+		WithOpenAIBaseURL(endpoint),
+	)
+
+	return newOpenAIClient(opts)
+}
+
+func WithOllamaOptions(ollamaOptions ...OpenAIOption) ProviderClientOption {
+	return func(options *providerClientOptions) {
+		options.openaiOptions = ollamaOptions
+	}
+}
diff --git a/internal/llm/provider/ollama_test.go b/internal/llm/provider/ollama_test.go
new file mode 100644
index 00000000..57161a60
--- /dev/null
+++ b/internal/llm/provider/ollama_test.go
@@ -0,0 +1,25 @@
+package provider
+
+import (
+	"context"
+	"os"
+	"testing"
+
+	"github.com/opencode-ai/opencode/internal/llm/models"
+	"github.com/stretchr/testify/require"
+)
+
+func TestOllamaProvider(t *testing.T) {
+	if os.Getenv("OLLAMA_ENDPOINT") == "" {
+		t.Skip("OLLAMA_ENDPOINT not set, skipping test")
+	}
+
+	provider, err := NewProvider(
+		models.ProviderOllama,
+		WithModel(models.SupportedModels[models.OllamaLlama3]),
+	)
+	require.NoError(t, err)
+
+	_, err = provider.SendMessages(context.Background(), nil, nil)
+	require.NoError(t, err)
+}
diff --git a/internal/llm/provider/openai.go b/internal/llm/provider/openai.go
index 8a561c77..aed2e477 100644
--- a/internal/llm/provider/openai.go
+++ b/internal/llm/provider/openai.go
@@ -214,7 +214,7 @@ func (o *openaiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
 				continue
 			}
 		}
-		return nil, retryErr
+		return nil, err
 	}
 
 	content := ""
diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go
index d5be0ba0..b54d8b54 100644
--- a/internal/llm/provider/provider.go
+++ b/internal/llm/provider/provider.go
@@ -69,6 +69,7 @@ type providerClientOptions struct {
 	geminiOptions  []GeminiOption
 	bedrockOptions []BedrockOption
 	copilotOptions []CopilotOption
+	mockClient     *MockClient
 }
 
 type ProviderClientOption func(*providerClientOptions)
@@ -152,17 +153,45 @@ func NewProvider(providerName models.ModelProvider, opts ...ProviderClientOption
 			options: clientOptions,
 			client:  newOpenAIClient(clientOptions),
 		}, nil
+	case models.ProviderOllama:
+		return &baseProvider[OpenAIClient]{
+			options: clientOptions,
+			client:  newOllamaClient(clientOptions),
+		}, nil
+	case models.ProviderHuggingFace:
+		return &baseProvider[OpenAIClient]{
+			options: clientOptions,
+			client:  newHuggingFaceClient(clientOptions),
+		}, nil
+	case models.ProviderReplicate:
+		return &baseProvider[OpenAIClient]{
+			options: clientOptions,
+			client:  newReplicateClient(clientOptions),
+		}, nil
+	case models.ProviderCohere:
+		return &baseProvider[OpenAIClient]{
+			options: clientOptions,
+			client:  newCohereClient(clientOptions),
+		}, nil
 	case models.ProviderLocal:
+		endpoint := os.Getenv("LOCAL_ENDPOINT")
+		if endpoint == "" {
+			endpoint = os.Getenv("OLLAMA_ENDPOINT")
+		}
+		if endpoint == "" {
+			endpoint = os.Getenv("LMSTUDIO_ENDPOINT")
+		}
 		clientOptions.openaiOptions = append(clientOptions.openaiOptions,
-			WithOpenAIBaseURL(os.Getenv("LOCAL_ENDPOINT")),
+			WithOpenAIBaseURL(endpoint),
 		)
 		return &baseProvider[OpenAIClient]{
 			options: clientOptions,
 			client:  newOpenAIClient(clientOptions),
 		}, nil
 	case models.ProviderMock:
-		// TODO: implement mock client for test
-		panic("not implemented")
+		return &baseProvider[ProviderClient]{
+			options: clientOptions,
+			client:  clientOptions.mockClient,
+		}, nil
 	}
 	return nil, fmt.Errorf("provider not supported: %s", providerName)
 }
@@ -245,3 +274,9 @@ func WithCopilotOptions(copilotOptions ...CopilotOption) ProviderClientOption {
 		options.copilotOptions = copilotOptions
 	}
 }
+
+func WithMockClient(client *MockClient) ProviderClientOption {
+	return func(options *providerClientOptions) {
+		options.mockClient = client
+	}
+}
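All four new providers follow the same pattern: reuse the OpenAI client with a different base URL. As a sketch, wiring up a further OpenAI-compatible service would only need a constructor like the one below, plus a `ModelProvider` constant, a model map, and a case in `NewProvider` (the `example` name and URL are placeholders, not part of this change):

```go
package provider

// newExampleClient illustrates the extension pattern used by the Ollama,
// HuggingFace, Replicate, and Cohere clients; "example" is hypothetical.
func newExampleClient(opts providerClientOptions) OpenAIClient {
	opts.openaiOptions = append(opts.openaiOptions,
		WithOpenAIBaseURL("https://api.example.com/v1"),
	)
	return newOpenAIClient(opts)
}
```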
diff --git a/internal/llm/provider/replicate.go b/internal/llm/provider/replicate.go
new file mode 100644
index 00000000..a601df6f
--- /dev/null
+++ b/internal/llm/provider/replicate.go
@@ -0,0 +1,18 @@
+package provider
+
+type ReplicateClient ProviderClient
+
+func newReplicateClient(opts providerClientOptions) OpenAIClient {
+	opts.openaiOptions = append(opts.openaiOptions,
+		WithOpenAIBaseURL("https://api.replicate.com/v1"),
+	)
+	return newOpenAIClient(opts)
+}
+
+func WithReplicateOptions(replicateOptions ...OpenAIOption) ProviderClientOption {
+	return func(options *providerClientOptions) {
+		options.openaiOptions = replicateOptions
+	}
+}
diff --git a/internal/llm/provider/replicate_test.go b/internal/llm/provider/replicate_test.go
new file mode 100644
index 00000000..cb316667
--- /dev/null
+++ b/internal/llm/provider/replicate_test.go
@@ -0,0 +1,26 @@
+package provider
+
+import (
+	"context"
+	"os"
+	"testing"
+
+	"github.com/opencode-ai/opencode/internal/llm/models"
+	"github.com/stretchr/testify/require"
+)
+
+func TestReplicateProvider(t *testing.T) {
+	if os.Getenv("REPLICATE_API_KEY") == "" {
+		t.Skip("REPLICATE_API_KEY not set, skipping test")
+	}
+
+	provider, err := NewProvider(
+		models.ProviderReplicate,
+		WithModel(models.SupportedModels[models.ReplicateLlama270BChat]),
+		WithAPIKey(os.Getenv("REPLICATE_API_KEY")),
+	)
+	require.NoError(t, err)
+
+	_, err = provider.SendMessages(context.Background(), nil, nil)
+	require.NoError(t, err)
+}
diff --git a/internal/llm/tools/ls_test.go b/internal/llm/tools/ls_test.go
index 508cb98d..e4703b8b 100644
--- a/internal/llm/tools/ls_test.go
+++ b/internal/llm/tools/ls_test.go
@@ -8,8 +8,8 @@ import (
 	"strings"
 	"testing"
 
+	"github.com/opencode-ai/opencode/internal/config"
 	"github.com/stretchr/testify/assert"
-	"github.com/stretchr/testify/require"
 )
 
 func TestLsTool_Info(t *testing.T) {
@@ -24,9 +24,10 @@ func TestLsTool_Info(t *testing.T) {
 }
 
 func TestLsTool_Run(t *testing.T) {
+	config.Load(".", false)
 	// Create a temporary directory for testing
 	tempDir, err := os.MkdirTemp("", "ls_tool_test")
-	require.NoError(t, err)
+	assert.NoError(t, err)
 	defer os.RemoveAll(tempDir)
 
 	// Create a test directory structure
@@ -57,14 +58,14 @@ func TestLsTool_Run(t *testing.T) {
 	for _, dir := range testDirs {
 		dirPath := filepath.Join(tempDir, dir)
 		err := os.MkdirAll(dirPath, 0755)
-		require.NoError(t, err)
+		assert.NoError(t, err)
 	}
 
 	// Create files
 	for _, file := range testFiles {
 		filePath := filepath.Join(tempDir, file)
 		err := os.WriteFile(filePath, []byte("test content"), 0644)
-		require.NoError(t, err)
+		assert.NoError(t, err)
 	}
 
 	t.Run("lists directory successfully", func(t *testing.T) {
@@ -74,7 +75,7 @@ func TestLsTool_Run(t *testing.T) {
 		}
 
 		paramsJSON, err := json.Marshal(params)
-		require.NoError(t, err)
+		assert.NoError(t, err)
 
 		call := ToolCall{
 			Name: LSToolName,
@@ -82,7 +83,7 @@ func TestLsTool_Run(t *testing.T) {
 		}
 
 		response, err := tool.Run(context.Background(), call)
-		require.NoError(t, err)
+		assert.NoError(t, err)
 
 		// Check that visible directories and files are included
 		assert.Contains(t, response.Content, "dir1")
@@ -107,7 +108,7 @@ func TestLsTool_Run(t *testing.T) {
 		}
 
 		paramsJSON, err := json.Marshal(params)
-		require.NoError(t, err)
+		assert.NoError(t, err)
 
 		call := ToolCall{
 			Name: LSToolName,
@@ -115,7 +116,7 @@ func TestLsTool_Run(t *testing.T) {
 		}
 
 		response, err := tool.Run(context.Background(), call)
-		require.NoError(t, err)
+		assert.NoError(t, err)
 
 		assert.Contains(t, response.Content, "path does not exist")
 	})
@@ -129,7 +130,7 @@ func TestLsTool_Run(t *testing.T) {
 		}
 
 		paramsJSON, err := json.Marshal(params)
-		require.NoError(t, err)
+		assert.NoError(t, err)
 
 		call := ToolCall{
 			Name: LSToolName,
@@ -137,7 +138,7 @@ func TestLsTool_Run(t *testing.T) {
 		}
 
 		response, err := tool.Run(context.Background(), call)
-		require.NoError(t, err)
+		assert.NoError(t, err)
 
 		// The response should either contain a valid directory listing or an error
 		// We'll just check that it's not empty
@@ -152,7 +153,7 @@ func TestLsTool_Run(t *testing.T) {
 		}
 
 		response, err := tool.Run(context.Background(), call)
-		require.NoError(t, err)
+		assert.NoError(t, err)
 
 		assert.Contains(t, response.Content, "error parsing parameters")
 	})
@@ -164,7 +165,7 @@ func TestLsTool_Run(t *testing.T) {
 		}
 
 		paramsJSON, err := json.Marshal(params)
-		require.NoError(t, err)
+		assert.NoError(t, err)
 
 		call := ToolCall{
 			Name: LSToolName,
@@ -172,7 +173,7 @@ func TestLsTool_Run(t *testing.T) {
 		}
 
 		response, err := tool.Run(context.Background(), call)
-		require.NoError(t, err)
+		assert.NoError(t, err)
 
 		// The output format is a tree, so we need to check for specific patterns
 		// Check that file1.txt is not directly mentioned
@@ -185,7 +186,7 @@ func TestLsTool_Run(t *testing.T) {
 	t.Run("handles relative path", func(t *testing.T) {
 		// Save original working directory
 		origWd, err := os.Getwd()
-		require.NoError(t, err)
+		assert.NoError(t, err)
 		defer func() {
 			os.Chdir(origWd)
 		}()
@@ -193,7 +194,7 @@ func TestLsTool_Run(t *testing.T) {
 		// Change to a directory above the temp directory
 		parentDir := filepath.Dir(tempDir)
 		err = os.Chdir(parentDir)
-		require.NoError(t, err)
+		assert.NoError(t, err)
 
 		tool := NewLsTool()
 		params := LSParams{
@@ -201,7 +202,7 @@ func TestLsTool_Run(t *testing.T) {
 		}
 
 		paramsJSON, err := json.Marshal(params)
-		require.NoError(t, err)
+		assert.NoError(t, err)
 
 		call := ToolCall{
 			Name: LSToolName,
@@ -209,7 +210,7 @@ func TestLsTool_Run(t *testing.T) {
 		}
 
 		response, err := tool.Run(context.Background(), call)
-		require.NoError(t, err)
+		assert.NoError(t, err)
 
 		// Should list the temp directory contents
 		assert.Contains(t, response.Content, "dir1")
@@ -316,7 +317,7 @@ func TestCreateFileTree(t *testing.T) {
 		}
 	}
 
-	require.NotNil(t, dir1Node)
+	assert.NotNil(t, dir1Node)
 	assert.Equal(t, "directory", dir1Node.Type)
 	assert.Len(t, dir1Node.Children, 2) // file2.txt and subdir
 }
@@ -369,7 +370,7 @@ func TestPrintTree(t *testing.T) {
 func TestListDirectory(t *testing.T) {
 	// Create a temporary directory for testing
 	tempDir, err := os.MkdirTemp("", "list_directory_test")
-	require.NoError(t, err)
+	assert.NoError(t, err)
 	defer os.RemoveAll(tempDir)
 
 	// Create a test directory structure
@@ -391,19 +392,19 @@ func TestListDirectory(t *testing.T) {
 	for _, dir := range testDirs {
 		dirPath := filepath.Join(tempDir, dir)
 		err := os.MkdirAll(dirPath, 0755)
-		require.NoError(t, err)
+		assert.NoError(t, err)
 	}
 
 	// Create files
 	for _, file := range testFiles {
 		filePath := filepath.Join(tempDir, file)
 		err := os.WriteFile(filePath, []byte("test content"), 0644)
-		require.NoError(t, err)
+		assert.NoError(t, err)
 	}
 
 	t.Run("lists files with no limit", func(t *testing.T) {
 		files, truncated, err := listDirectory(tempDir, []string{}, 1000)
-		require.NoError(t, err)
+		assert.NoError(t, err)
 		assert.False(t, truncated)
 
 		// Check that visible files and directories are included
@@ -429,14 +430,14 @@ func TestListDirectory(t *testing.T) {
 
 	t.Run("respects limit and returns truncated flag", func(t *testing.T) {
 		files, truncated, err := listDirectory(tempDir, []string{}, 2)
-		require.NoError(t, err)
+		assert.NoError(t, err)
 		assert.True(t, truncated)
 		assert.Len(t, files, 2)
 	})
 
 	t.Run("respects ignore patterns", func(t *testing.T) {
 		files, truncated, err := listDirectory(tempDir, []string{"*.txt"}, 1000)
-		require.NoError(t, err)
+		assert.NoError(t, err)
 		assert.False(t, truncated)
 
 		// Check that no .txt files are included
diff --git a/internal/tui/tui.go b/internal/tui/tui.go
index 1c9c2f03..de7d8893 100644
--- a/internal/tui/tui.go
+++ b/internal/tui/tui.go
@@ -21,6 +21,11 @@ import (
 	"github.com/opencode-ai/opencode/internal/tui/layout"
 	"github.com/opencode-ai/opencode/internal/tui/page"
 	"github.com/opencode-ai/opencode/internal/tui/theme"
+	"os"
+	"os/exec"
+	"path/filepath"
+
 	"github.com/opencode-ai/opencode/internal/tui/util"
 )
@@ -442,6 +447,16 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 		}
 		return a, nil
 
+	case string:
+		switch msg {
+		case "setup-agent-os":
+			cmd := exec.Command("bash", "-c", "curl -fsSL https://raw.githubusercontent.com/buildermethods/agent-os/main/install | bash")
+			err := cmd.Run()
+			if err != nil {
+				return a, util.ReportError(fmt.Errorf("failed to install Agent OS: %w", err))
+			}
+			return a, util.ReportInfo("Agent OS installed successfully")
+		}
 	case tea.KeyMsg:
 		// If multi-arguments dialog is open, let it handle the key press first
 		if a.showMultiArgumentsDialog {
@@ -677,6 +692,34 @@ func (a *appModel) findCommand(id string) (dialog.Command, bool) {
 	return dialog.Command{}, false
 }
 
+func createAgentOsCommands() tea.Cmd {
+	return func() tea.Msg {
+		wd, err := os.Getwd()
+		if err != nil {
+			return util.InfoMsg{Type: util.InfoTypeError, Msg: fmt.Sprintf("failed to get working directory: %v", err)}
+		}
+
+		commandsDir := filepath.Join(wd, ".opencode", "commands")
+		if _, err := os.Stat(commandsDir); os.IsNotExist(err) {
+			if err := os.MkdirAll(commandsDir, 0755); err != nil {
+				return util.InfoMsg{Type: util.InfoTypeError, Msg: fmt.Sprintf("failed to create commands directory: %v", err)}
+			}
+		}
+
+		agentOsCommands := []string{"analyze_product", "create_spec", "execute_tasks", "plan_product"}
+		for _, cmd := range agentOsCommands {
+			cmdFile := filepath.Join(commandsDir, 
cmd+".md") + if _, err := os.Stat(cmdFile); os.IsNotExist(err) { + content := fmt.Sprintf("@~/agent_os/instructions/%s.md", cmd) + if err := os.WriteFile(cmdFile, []byte(content), 0644); err != nil { + return util.InfoMsg{Type: util.InfoTypeError, Msg: fmt.Sprintf("failed to create command file: %v", err)} + } + } + } + return util.InfoMsg{Type: util.InfoTypeInfo, Msg: "Agent OS project initialized successfully"} + } +} + func (a *appModel) moveToPage(pageID page.PageID) tea.Cmd { if a.app.CoderAgent.IsBusy() { // For now we don't move to any page if the agent is busy @@ -926,6 +969,28 @@ func New(app *app.App) tea.Model { Title: "Initialize Project", Description: "Create/Update the OpenCode.md memory file", Handler: func(cmd dialog.Command) tea.Cmd { + // Check for global Agent OS installation + homeDir, err := os.UserHomeDir() + if err != nil { + return util.ReportError(fmt.Errorf("failed to get home directory: %w", err)) + } + agentOsDir := filepath.Join(homeDir, "agent_os") + if _, err := os.Stat(agentOsDir); os.IsNotExist(err) { + return util.ReportWarn("Agent OS not found. Please run the 'Setup Agent OS' command from the command menu (ctrl+k).") + } + + // Create project-specific agent_os directory + wd, err := os.Getwd() + if err != nil { + return util.ReportError(fmt.Errorf("failed to get working directory: %w", err)) + } + projectAgentOsDir := filepath.Join(wd, "agent_os") + if _, err := os.Stat(projectAgentOsDir); os.IsNotExist(err) { + if err := os.MkdirAll(projectAgentOsDir, 0755); err != nil { + return util.ReportError(fmt.Errorf("failed to create agent_os directory: %w", err)) + } + } + prompt := `Please analyze this codebase and create a OpenCode.md file containing: 1. Build/lint/test commands - especially for running a single test 2. Code style guidelines including imports, formatting, types, naming conventions, error handling, etc. @@ -933,10 +998,18 @@ func New(app *app.App) tea.Model { The file you create will be given to agentic coding agents (such as yourself) that operate in this repository. Make it about 20 lines long. If there's already a opencode.md, improve it. If there are Cursor rules (in .cursor/rules/ or .cursorrules) or Copilot rules (in .github/copilot-instructions.md), make sure to include them.` + + // Generate project planning documents + planningPrompt := "Plan out a new product based on the current project context, referencing the Agent OS instructions and standards." 
+
 			return tea.Batch(
 				util.CmdHandler(chat.SendMsg{
 					Text: prompt,
 				}),
+				util.CmdHandler(chat.SendMsg{
+					Text: planningPrompt,
+				}),
+				createAgentOsCommands(),
 			)
 		},
 	})
@@ -951,6 +1024,15 @@ If there are Cursor rules (in .cursor/rules/ or .cursorrules) or Copilot rules (in .github/copilot-instructions.md), make sure to include them.`
 			}
 		},
 	})
+
+	model.RegisterCommand(dialog.Command{
+		ID:          "setup-agent-os",
+		Title:       "Setup Agent OS",
+		Description: "Install Agent OS",
+		Handler: func(cmd dialog.Command) tea.Cmd {
+			return util.CmdHandler(tea.Msg("setup-agent-os"))
+		},
+	})
 	// Load custom commands
 	customCommands, err := dialog.LoadCustomCommands()
 	if err != nil {
diff --git a/internal/tui/tui_test.go b/internal/tui/tui_test.go
new file mode 100644
index 00000000..669bf706
--- /dev/null
+++ b/internal/tui/tui_test.go
@@ -0,0 +1,50 @@
+package tui
+
+import (
+	"fmt"
+	"os"
+	"path/filepath"
+	"testing"
+
+	"github.com/opencode-ai/opencode/internal/tui/util"
+	"github.com/stretchr/testify/assert"
+)
+
+func TestCreateAgentOsCommands(t *testing.T) {
+	// Create a temporary directory for testing
+	tempDir, err := os.MkdirTemp("", "agent_os_commands_test")
+	assert.NoError(t, err)
+	defer os.RemoveAll(tempDir)
+
+	// Change the current working directory to the temporary directory,
+	// restoring the original directory when the test finishes
+	origWd, err := os.Getwd()
+	assert.NoError(t, err)
+	err = os.Chdir(tempDir)
+	assert.NoError(t, err)
+	defer os.Chdir(origWd)
+
+	// Run the createAgentOsCommands function
+	cmd := createAgentOsCommands()
+	msg := cmd()
+
+	// Check that the message is not an error
+	assert.NotNil(t, msg)
+	infoMsg, ok := msg.(util.InfoMsg)
+	assert.True(t, ok)
+	assert.Equal(t, util.InfoTypeInfo, infoMsg.Type)
+
+	// Check that the commands directory was created
+	commandsDir := filepath.Join(tempDir, ".opencode", "commands")
+	_, err = os.Stat(commandsDir)
+	assert.NoError(t, err)
+
+	// Check that the command files were created
+	agentOsCommands := []string{"analyze_product", "create_spec", "execute_tasks", "plan_product"}
+	for _, cmd := range agentOsCommands {
+		cmdFile := filepath.Join(commandsDir, cmd+".md")
+		_, err = os.Stat(cmdFile)
+		assert.NoError(t, err)
+
+		// Check the content of the command file
+		content, err := os.ReadFile(cmdFile)
+		assert.NoError(t, err)
+		assert.Equal(t, fmt.Sprintf("@~/agent_os/instructions/%s.md", cmd), string(content))
+	}
+}
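To exercise the additions locally, the new package tests can be run directly; note that the provider tests skip themselves unless the corresponding API keys or endpoints are set:

```bash
go test ./cmd/... ./internal/errors/... ./internal/llm/provider/... ./internal/tui/...
```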