Skip to content

Commit 0621a05

Browse files
committed
feat: support function call response
1 parent ea1b43d commit 0621a05

File tree

6 files changed

+87
-81
lines changed

6 files changed

+87
-81
lines changed

api/handler.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ func IndexHandler(c *gin.Context) {
2424

2525
func ModelListHandler(c *gin.Context) {
2626
owner := adapter.GetOwner()
27-
27+
2828
// Get authorization header to initialize models if needed
2929
authorizationHeader := c.GetHeader("Authorization")
3030
var apiKey string
@@ -44,7 +44,7 @@ func ModelListHandler(c *gin.Context) {
4444
// When model mapping is disabled, return the actual Gemini models
4545
models := adapter.GetAvailableGeminiModels()
4646
modelList := make([]any, 0, len(models))
47-
47+
4848
for _, modelName := range models {
4949
modelList = append(modelList, openai.Model{
5050
CreatedAt: 1686935002,
@@ -53,7 +53,7 @@ func ModelListHandler(c *gin.Context) {
5353
OwnedBy: owner,
5454
})
5555
}
56-
56+
5757
c.JSON(http.StatusOK, gin.H{
5858
"object": "list",
5959
"data": modelList,
@@ -127,7 +127,7 @@ func ChatProxyHandler(c *gin.Context) {
127127
handleGenerateContentError(c, err)
128128
return
129129
}
130-
130+
131131
// Initialize Gemini models if not already initialized
132132
if err := adapter.InitGeminiModels(openaiAPIKey); err != nil {
133133
log.Printf("Error initializing Gemini models: %v", err)

pkg/adapter/chat.go

+10-3
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,6 @@ func handleStreamIter(model string, iter *genai.GenerateContentResponseIterator,
241241
break
242242
}
243243
}
244-
245244
} else {
246245
// For subsequent chunks after sentenceLength, send the entire text at once
247246
sendFullText(text)
@@ -295,7 +294,12 @@ func handleStreamIter(model string, iter *genai.GenerateContentResponseIterator,
295294
}
296295
}
297296

298-
func genaiResponseToStreamCompletionResponse(model string, genaiResp *genai.GenerateContentResponse, respID string, created int64) *CompletionResponse {
297+
func genaiResponseToStreamCompletionResponse(
298+
model string,
299+
genaiResp *genai.GenerateContentResponse,
300+
respID string,
301+
created int64,
302+
) *CompletionResponse {
299303
resp := CompletionResponse{
300304
ID: fmt.Sprintf("chatcmpl-%s", respID),
301305
Object: "chat.completion.chunk",
@@ -356,7 +360,10 @@ func genaiResponseToStreamCompletionResponse(model string, genaiResp *genai.Gene
356360
return &resp
357361
}
358362

359-
func genaiResponseToOpenaiResponse(model string, genaiResp *genai.GenerateContentResponse) openai.ChatCompletionResponse {
363+
func genaiResponseToOpenaiResponse(
364+
model string,
365+
genaiResp *genai.GenerateContentResponse,
366+
) openai.ChatCompletionResponse {
360367
resp := openai.ChatCompletionResponse{
361368
ID: fmt.Sprintf("chatcmpl-%s", util.GetUUID()),
362369
Object: "chat.completion",

pkg/adapter/models.go

+24-16
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,11 @@ const (
2222
)
2323

2424
// GeminiModels stores the available models from Gemini API
25-
var GeminiModels []string
26-
var geminiModelsOnce sync.Once
27-
var geminiModelsLock sync.RWMutex
25+
var (
26+
GeminiModels []string
27+
geminiModelsOnce sync.Once
28+
geminiModelsLock sync.RWMutex
29+
)
2830

2931
var USE_MODEL_MAPPING bool = os.Getenv("DISABLE_MODEL_MAPPING") != "1"
3032

@@ -65,7 +67,13 @@ func InitGeminiModels(apiKey string) error {
6567
log.Printf("Failed to fetch Gemini models: %v\n", err)
6668
// Fallback to default models
6769
geminiModelsLock.Lock()
68-
GeminiModels = []string{Gemini1Dot5Pro, Gemini1Dot5Flash, Gemini1Dot5ProV, Gemini2FlashExp, TextEmbedding004}
70+
GeminiModels = []string{
71+
Gemini1Dot5Pro,
72+
Gemini1Dot5Flash,
73+
Gemini1Dot5ProV,
74+
Gemini2FlashExp,
75+
TextEmbedding004,
76+
}
6977
geminiModelsLock.Unlock()
7078
initErr = err
7179
return
@@ -82,11 +90,11 @@ func InitGeminiModels(apiKey string) error {
8290
func GetAvailableGeminiModels() []string {
8391
geminiModelsLock.RLock()
8492
defer geminiModelsLock.RUnlock()
85-
93+
8694
if len(GeminiModels) == 0 {
8795
return []string{Gemini1Dot5Pro, Gemini1Dot5Flash, Gemini1Dot5ProV, Gemini2FlashExp, TextEmbedding004}
8896
}
89-
97+
9098
return GeminiModels
9199
}
92100

@@ -110,22 +118,22 @@ func GetModel(openAiModelName string) string {
110118
func IsValidGeminiModel(modelName string) bool {
111119
if len(GeminiModels) == 0 {
112120
// If models haven't been fetched yet, use the default list
113-
return modelName == Gemini1Dot5Pro ||
114-
modelName == Gemini1Dot5Flash ||
115-
modelName == Gemini1Dot5ProV ||
116-
modelName == Gemini2FlashExp ||
117-
modelName == TextEmbedding004
121+
return modelName == Gemini1Dot5Pro ||
122+
modelName == Gemini1Dot5Flash ||
123+
modelName == Gemini1Dot5ProV ||
124+
modelName == Gemini2FlashExp ||
125+
modelName == TextEmbedding004
118126
}
119-
127+
120128
geminiModelsLock.RLock()
121129
defer geminiModelsLock.RUnlock()
122-
130+
123131
for _, model := range GeminiModels {
124132
if model == modelName {
125133
return true
126134
}
127135
}
128-
136+
129137
return false
130138
}
131139

@@ -185,7 +193,7 @@ func (req *ChatCompletionRequest) ParseModelWithoutMapping() string {
185193
if IsValidGeminiModel(req.Model) {
186194
return req.Model
187195
}
188-
196+
189197
// Fallback to default model if not valid
190198
log.Printf("Invalid model: %s, falling back to %s\n", req.Model, Gemini1Dot5Flash)
191199
return Gemini1Dot5Flash
@@ -213,7 +221,7 @@ func (req *EmbeddingRequest) ToGenaiModel() string {
213221
if IsValidGeminiModel(req.Model) {
214222
return req.Model
215223
}
216-
224+
217225
// Fallback to default embedding model if not valid
218226
log.Printf("Invalid embedding model: %s, falling back to %s\n", req.Model, TextEmbedding004)
219227
return TextEmbedding004

pkg/adapter/struct.go

+41-9
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,18 @@ package adapter
22

33
import (
44
"encoding/json"
5+
"strings"
56

67
"github.com/google/generative-ai-go/genai"
78
"github.com/pkg/errors"
89
openai "github.com/sashabaranov/go-openai"
910
)
1011

1112
type ChatCompletionMessage struct {
12-
Role string `json:"role"`
13-
Content json.RawMessage `json:"content"`
13+
Role string `json:"role"`
14+
Content json.RawMessage `json:"content"`
15+
ToolCalls []openai.ToolCall `json:"tool_calls,omitempty"`
16+
ToolCallID string `json:"tool_call_id,omitempty"`
1417
}
1518

1619
// ChatCompletionRequest represents a request structure for chat completion API.
@@ -53,17 +56,34 @@ func (req *ChatCompletionRequest) toVisionGenaiContent() ([]*genai.Content, erro
5356
if err := json.Unmarshal(message.Content, &singleString); err != nil {
5457
return nil, errors.Wrap(err, "failed to unmarshal message content")
5558
}
56-
// Convert single string to a part
57-
parts = []openai.ChatMessagePart{
58-
{Type: openai.ChatMessagePartTypeText, Text: singleString},
59+
60+
if len(message.ToolCalls) == 0 {
61+
// Convert single string to a part
62+
parts = []openai.ChatMessagePart{
63+
{Type: openai.ChatMessagePartTypeText, Text: singleString},
64+
}
5965
}
6066
}
6167

6268
prompt := make([]genai.Part, 0, len(parts))
6369
for _, part := range parts {
6470
switch part.Type {
6571
case openai.ChatMessagePartTypeText:
66-
prompt = append(prompt, genai.Text(part.Text))
72+
if message.Role == openai.ChatMessageRoleTool {
73+
functionName := message.ToolCallID
74+
lastDashIndex := strings.LastIndex(functionName, "-")
75+
if lastDashIndex != -1 {
76+
functionName = functionName[:lastDashIndex]
77+
}
78+
79+
prompt = append(prompt, genai.FunctionResponse{
80+
Name: functionName,
81+
Response: map[string]any{"result": part.Text},
82+
})
83+
} else {
84+
prompt = append(prompt, genai.Text(part.Text))
85+
}
86+
6787
case openai.ChatMessagePartTypeImageURL:
6888
data, format, err := parseImageURL(part.ImageURL.URL)
6989
if err != nil {
@@ -74,6 +94,18 @@ func (req *ChatCompletionRequest) toVisionGenaiContent() ([]*genai.Content, erro
7494
}
7595
}
7696

97+
for _, tool := range message.ToolCalls {
98+
args := map[string]any{}
99+
if err := json.Unmarshal([]byte(tool.Function.Arguments), &args); err != nil {
100+
return nil, errors.Wrap(err, "failed to unmarshal message function args")
101+
}
102+
103+
prompt = append(prompt, genai.FunctionCall{
104+
Name: tool.Function.Name,
105+
Args: args,
106+
})
107+
}
108+
77109
switch message.Role {
78110
case openai.ChatMessageRoleSystem:
79111
content = append(content, []*genai.Content{
@@ -93,7 +125,7 @@ func (req *ChatCompletionRequest) toVisionGenaiContent() ([]*genai.Content, erro
93125
Parts: prompt,
94126
Role: genaiRoleModel,
95127
})
96-
case openai.ChatMessageRoleUser:
128+
case openai.ChatMessageRoleUser, openai.ChatMessageRoleTool:
97129
content = append(content, &genai.Content{
98130
Parts: prompt,
99131
Role: genaiRoleUser,
@@ -106,8 +138,8 @@ func (req *ChatCompletionRequest) toVisionGenaiContent() ([]*genai.Content, erro
106138
type CompletionChoice struct {
107139
Index int `json:"index"`
108140
Delta struct {
109-
Content string `json:"content,omitempty"`
110-
Role string `json:"role,omitempty"`
141+
Content string `json:"content,omitempty"`
142+
Role string `json:"role,omitempty"`
111143
ToolCalls []openai.ToolCall `json:"tool_calls,omitempty"`
112144
} `json:"delta"`
113145
FinishReason *string `json:"finish_reason"`

pkg/adapter/tools.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,14 @@ func convertOpenAIToolsToGenAI(tools []openai.Tool) []*genai.Tool {
2525
if err != nil {
2626
continue // Skip this tool if we can't convert parameters
2727
}
28-
28+
2929
var convertedParams map[string]interface{}
3030
if err := json.Unmarshal(paramsBytes, &convertedParams); err != nil {
3131
continue // Skip this tool if we can't convert parameters
3232
}
3333
paramsMap = convertedParams
3434
}
35-
35+
3636
schema := convertJSONSchemaToGenAISchema(paramsMap)
3737

3838
item := &genai.Tool{
@@ -154,4 +154,4 @@ func convertJSONTypeToGenAIType(t string) genai.Type {
154154
default:
155155
return genai.TypeUnspecified
156156
}
157-
}
157+
}

0 commit comments

Comments
 (0)