Skip to content

Commit a9dafe6

Browse files
willvelidayaron2
andauthored
Adding support for custom endpoint for OpenAI Conversation Component (dapr#3834)
Signed-off-by: Will Velida <willvelida@hotmail.co.uk> Co-authored-by: Yaron Schneider <schneider.yaron@live.com>
1 parent 4a50840 commit a9dafe6

File tree

5 files changed

+168
-3
lines changed

5 files changed

+168
-3
lines changed

conversation/metadata.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,5 @@ type LangchainMetadata struct {
2626
Key string `json:"key"`
2727
Model string `json:"model"`
2828
CacheTTL string `json:"cacheTTL"`
29+
Endpoint string `json:"endpoint"`
2930
}

conversation/metadata_test.go

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/*
2+
Copyright 2024 The Dapr Authors
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
http://www.apache.org/licenses/LICENSE-2.0
7+
Unless required by applicable law or agreed to in writing, software
8+
distributed under the License is distributed on an "AS IS" BASIS,
9+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
See the License for the specific language governing permissions and
11+
limitations under the License.
12+
*/
13+
14+
package conversation
15+
16+
import (
17+
"encoding/json"
18+
"testing"
19+
20+
"github.com/stretchr/testify/assert"
21+
"github.com/stretchr/testify/require"
22+
)
23+
24+
func TestLangchainMetadata(t *testing.T) {
25+
t.Run("json marshaling with endpoint", func(t *testing.T) {
26+
metadata := LangchainMetadata{
27+
Key: "test-key",
28+
Model: "gpt-4",
29+
CacheTTL: "10m",
30+
Endpoint: "https://custom-endpoint.example.com",
31+
}
32+
33+
bytes, err := json.Marshal(metadata)
34+
require.NoError(t, err)
35+
36+
var unmarshaled LangchainMetadata
37+
err = json.Unmarshal(bytes, &unmarshaled)
38+
require.NoError(t, err)
39+
40+
assert.Equal(t, metadata.Key, unmarshaled.Key)
41+
assert.Equal(t, metadata.Model, unmarshaled.Model)
42+
assert.Equal(t, metadata.CacheTTL, unmarshaled.CacheTTL)
43+
assert.Equal(t, metadata.Endpoint, unmarshaled.Endpoint)
44+
})
45+
46+
t.Run("json unmarshaling with endpoint", func(t *testing.T) {
47+
jsonStr := `{"key": "test-key", "endpoint": "https://api.openai.com/v1"}`
48+
49+
var metadata LangchainMetadata
50+
err := json.Unmarshal([]byte(jsonStr), &metadata)
51+
require.NoError(t, err)
52+
53+
assert.Equal(t, "test-key", metadata.Key)
54+
assert.Equal(t, "https://api.openai.com/v1", metadata.Endpoint)
55+
})
56+
}

conversation/openai/metadata.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,12 @@ metadata:
2727
The OpenAI LLM to use. Defaults to gpt-4o
2828
type: string
2929
example: 'gpt-4-turbo'
30+
- name: endpoint
31+
required: false
32+
description: |
33+
Custom API endpoint URL for OpenAI API-compatible services. If not specified, the default OpenAI API endpoint will be used.
34+
type: string
35+
example: 'https://api.openai.com/v1'
3036
- name: cacheTTL
3137
required: false
3238
description: |

conversation/openai/openai.go

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,18 @@ func (o *OpenAI) Init(ctx context.Context, meta conversation.Metadata) error {
5454
if md.Model != "" {
5555
model = md.Model
5656
}
57-
58-
llm, err := openai.New(
57+
// Create options for OpenAI client
58+
options := []openai.Option{
5959
openai.WithModel(model),
6060
openai.WithToken(md.Key),
61-
)
61+
}
62+
63+
// Add custom endpoint if provided
64+
if md.Endpoint != "" {
65+
options = append(options, openai.WithBaseURL(md.Endpoint))
66+
}
67+
68+
llm, err := openai.New(options...)
6269
if err != nil {
6370
return err
6471
}

conversation/openai/openai_test.go

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
/*
2+
Copyright 2024 The Dapr Authors
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
http://www.apache.org/licenses/LICENSE-2.0
7+
Unless required by applicable law or agreed to in writing, software
8+
distributed under the License is distributed on an "AS IS" BASIS,
9+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
See the License for the specific language governing permissions and
11+
limitations under the License.
12+
*/
13+
14+
package openai
15+
16+
import (
17+
"testing"
18+
19+
"github.com/dapr/components-contrib/conversation"
20+
"github.com/dapr/components-contrib/metadata"
21+
"github.com/dapr/kit/logger"
22+
23+
"github.com/stretchr/testify/assert"
24+
"github.com/stretchr/testify/require"
25+
)
26+
27+
func TestInit(t *testing.T) {
28+
testCases := []struct {
29+
name string
30+
metadata map[string]string
31+
testFn func(*testing.T, *OpenAI, error)
32+
}{
33+
{
34+
name: "with default endpoint",
35+
metadata: map[string]string{
36+
"key": "test-key",
37+
"model": "gpt-4",
38+
},
39+
testFn: func(t *testing.T, o *OpenAI, err error) {
40+
require.NoError(t, err)
41+
assert.NotNil(t, o.llm)
42+
},
43+
},
44+
{
45+
name: "with custom endpoint",
46+
metadata: map[string]string{
47+
"key": "test-key",
48+
"model": "gpt-4",
49+
"endpoint": "https://api.openai.com/v1",
50+
},
51+
testFn: func(t *testing.T, o *OpenAI, err error) {
52+
require.NoError(t, err)
53+
assert.NotNil(t, o.llm)
54+
// Since we can't directly access the client's baseURL,
55+
// we're mainly testing that initialization succeeds
56+
},
57+
},
58+
}
59+
60+
for _, tc := range testCases {
61+
t.Run(tc.name, func(t *testing.T) {
62+
o := NewOpenAI(logger.NewLogger("openai test"))
63+
err := o.Init(t.Context(), conversation.Metadata{
64+
Base: metadata.Base{
65+
Properties: tc.metadata,
66+
},
67+
})
68+
tc.testFn(t, o.(*OpenAI), err)
69+
})
70+
}
71+
}
72+
73+
func TestEndpointInMetadata(t *testing.T) {
74+
// Create an instance of OpenAI component
75+
o := &OpenAI{}
76+
77+
// This test relies on the metadata tag
78+
md := o.GetComponentMetadata()
79+
if len(md) == 0 {
80+
t.Skip("Metadata is not enabled, skipping test")
81+
}
82+
83+
// Print all available metadata keys for debugging
84+
t.Logf("Available metadata keys: %v", func() []string {
85+
keys := make([]string, 0, len(md))
86+
for k := range md {
87+
keys = append(keys, k)
88+
}
89+
return keys
90+
}())
91+
92+
// Verify endpoint field exists (note: field names are capitalized in metadata)
93+
_, exists := md["Endpoint"]
94+
assert.True(t, exists, "Endpoint field should exist in metadata")
95+
}

0 commit comments

Comments
 (0)