
Commit b44eb0d

refactor: Simplify code, support custom VERSION mapping configuration
1 parent a20da1e commit b44eb0d

4 files changed: +163 / -153 lines changed


README.md

Lines changed: 2 additions & 0 deletions
```diff
@@ -104,6 +104,8 @@ Gemini-OpenAI-Proxy offers a straightforward way to integrate OpenAI functionality
 - gpt-4-turbo-preview -> gemini-1.5-pro-latest
 - gpt-4-vision-preview -> gemini-1.0-pro-vision-latest
 
+If you wish to map `gpt-4-vision-preview` to `gemini-1.5-pro-latest`, set the environment variable `GPT_4_VISION_PREVIEW=gemini-1.5-pro-latest`; this works because `gemini-1.5-pro-latest` now also supports multi-modal data.
+
 These are the corresponding model mappings for your reference. We've aligned the models from our project with the latest offerings from Gemini, ensuring compatibility and seamless integration.
 
 4. **Handle Responses:**
```
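The README documents the override, but the lookup itself lives in a changed file that is not rendered on this page. As a rough sketch of how such an environment-variable mapping can be resolved in Go — the helper name `resolveGenaiModel` and the default table are illustrative assumptions, not the project's actual code:

```go
package main

import (
	"fmt"
	"os"
	"strings"
)

// Illustrative defaults taken from the README's mapping list; the
// project's real table lives in a file not rendered in this diff.
var defaultModelMap = map[string]string{
	"gpt-4-turbo-preview":  "gemini-1.5-pro-latest",
	"gpt-4-vision-preview": "gemini-1.0-pro-vision-latest",
}

// resolveGenaiModel (hypothetical) derives the env key by upper-casing
// the OpenAI model name and swapping hyphens for underscores, so
// "gpt-4-vision-preview" becomes GPT_4_VISION_PREVIEW, matching the
// variable the README documents. An env override wins over the default.
func resolveGenaiModel(openaiModel string) string {
	envKey := strings.ToUpper(strings.ReplaceAll(openaiModel, "-", "_"))
	if override := os.Getenv(envKey); override != "" {
		return override
	}
	return defaultModelMap[openaiModel]
}

func main() {
	os.Setenv("GPT_4_VISION_PREVIEW", "gemini-1.5-pro-latest")
	fmt.Println(resolveGenaiModel("gpt-4-vision-preview")) // gemini-1.5-pro-latest
}
```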

api/handler.go

Lines changed: 13 additions & 14 deletions
```diff
@@ -5,7 +5,6 @@ import (
 	"io"
 	"log"
 	"net/http"
-	"strings"
 
 	"github.com/gin-gonic/gin"
 	"github.com/google/generative-ai-go/genai"
@@ -88,6 +87,15 @@ func ChatProxyHandler(c *gin.Context) {
 		return
 	}
 
+	content, err := req.ToGenaiContent()
+	if err != nil {
+		c.JSON(http.StatusBadRequest, openai.APIError{
+			Code:    http.StatusBadRequest,
+			Message: err.Error(),
+		})
+		return
+	}
+
 	ctx := c.Request.Context()
 	client, err := genai.NewClient(ctx, option.WithAPIKey(openaiAPIKey))
 	if err != nil {
@@ -100,20 +108,11 @@ func ChatProxyHandler(c *gin.Context) {
 	}
 	defer client.Close()
 
-	var gemini adapter.GenaiModelAdapter
-	switch {
-	case req.Model == openai.GPT4VisionPreview:
-		gemini = adapter.NewGeminiProVisionAdapter(client)
-	case req.Model == openai.GPT4TurboPreview || req.Model == openai.GPT4Turbo1106 || req.Model == openai.GPT4Turbo0125:
-		gemini = adapter.NewGeminiProAdapter(client, adapter.Gemini1Dot5Pro)
-	case strings.HasPrefix(req.Model, openai.GPT4):
-		gemini = adapter.NewGeminiProAdapter(client, adapter.Gemini1Ultra)
-	default:
-		gemini = adapter.NewGeminiProAdapter(client, adapter.Gemini1Pro)
-	}
+	model := req.ToGenaiModel()
+	gemini := adapter.NewGeminiAdapter(client, model)
 
 	if !req.Stream {
-		resp, err := gemini.GenerateContent(ctx, req)
+		resp, err := gemini.GenerateContent(ctx, req, content)
 		if err != nil {
 			log.Printf("genai generate content error %v\n", err)
 			c.JSON(http.StatusBadRequest, openai.APIError{
@@ -127,7 +126,7 @@ func ChatProxyHandler(c *gin.Context) {
 		return
 	}
 
-	dataChan, err := gemini.GenerateStreamContent(ctx, req)
+	dataChan, err := gemini.GenerateStreamContent(ctx, req, content)
 	if err != nil {
 		log.Printf("genai generate content error %v\n", err)
 		c.JSON(http.StatusBadRequest, openai.APIError{
```
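The handler now delegates request conversion to `req.ToGenaiContent()` and model selection to `req.ToGenaiModel()`, both defined in the fourth changed file, which is not rendered on this page. Below is a minimal sketch of the kind of conversion `ToGenaiContent` presumably performs, mapping OpenAI chat roles onto genai's user/model roles; the free-function form, field access, and role handling are assumptions, not the project's code:

```go
package adapter

import (
	"github.com/google/generative-ai-go/genai"
	openai "github.com/sashabaranov/go-openai"
)

// toGenaiContent sketches the request-to-genai conversion that
// req.ToGenaiContent() likely performs; the real implementation is in
// a file this page does not render.
func toGenaiContent(messages []openai.ChatCompletionMessage) []*genai.Content {
	content := make([]*genai.Content, 0, len(messages))
	for _, msg := range messages {
		role := "user" // genai's role for user (and, in this sketch, system) turns
		if msg.Role == openai.ChatMessageRoleAssistant {
			role = "model" // genai's role for assistant turns
		}
		content = append(content, &genai.Content{
			Parts: []genai.Part{genai.Text(msg.Content)},
			Role:  role,
		})
	}
	return content
}
```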

pkg/adapter/chat.go

Lines changed: 21 additions & 136 deletions
```diff
@@ -26,141 +26,58 @@ const (
 	genaiRoleModel = "model"
 )
 
-type GenaiModelAdapter interface {
-	GenerateContent(ctx context.Context, req *ChatCompletionRequest) (*openai.ChatCompletionResponse, error)
-	GenerateStreamContent(ctx context.Context, req *ChatCompletionRequest) (<-chan string, error)
-}
-
-type GeminiProAdapter struct {
+type GeminiAdapter struct {
 	client *genai.Client
 	model  string
 }
 
-func NewGeminiProAdapter(client *genai.Client, model string) GenaiModelAdapter {
-	return &GeminiProAdapter{
+func NewGeminiAdapter(client *genai.Client, model string) *GeminiAdapter {
+	return &GeminiAdapter{
 		client: client,
 		model:  model,
 	}
 }
 
-func (g *GeminiProAdapter) GenerateContent(
+func (g *GeminiAdapter) GenerateContent(
 	ctx context.Context,
 	req *ChatCompletionRequest,
+	content []*genai.Content,
 ) (*openai.ChatCompletionResponse, error) {
 	model := g.client.GenerativeModel(g.model)
 	setGenaiModelByOpenaiRequest(model, req)
 
 	cs := model.StartChat()
-	setGenaiChatByOpenaiRequest(cs, req)
+	setGenaiChatHistory(cs, content)
 
-	prompt := genai.Text(req.Messages[len(req.Messages)-1].StringContent())
-	genaiResp, err := cs.SendMessage(ctx, prompt)
+	genaiResp, err := cs.SendMessage(ctx, content[len(content)-1].Parts...)
 	if err != nil {
 		return nil, errors.Wrap(err, "genai send message error")
 	}
 
-	openaiResp := genaiResponseToOpenaiResponse(genaiResp, g.model)
+	openaiResp := genaiResponseToOpenaiResponse(g.model, genaiResp)
 	return &openaiResp, nil
 }
 
-func (g *GeminiProAdapter) GenerateStreamContent(
+func (g *GeminiAdapter) GenerateStreamContent(
 	ctx context.Context,
 	req *ChatCompletionRequest,
+	content []*genai.Content,
 ) (<-chan string, error) {
 	model := g.client.GenerativeModel(g.model)
 	setGenaiModelByOpenaiRequest(model, req)
 
 	cs := model.StartChat()
-	setGenaiChatByOpenaiRequest(cs, req)
+	setGenaiChatHistory(cs, content)
 
-	prompt := genai.Text(req.Messages[len(req.Messages)-1].StringContent())
-	iter := cs.SendMessageStream(ctx, prompt)
+	iter := cs.SendMessageStream(ctx, content[len(content)-1].Parts...)
 
 	dataChan := make(chan string)
-	go handleStreamIter(iter, dataChan, g.model)
+	go handleStreamIter(g.model, iter, dataChan)
 
 	return dataChan, nil
 }
 
-type GeminiProVisionAdapter struct {
-	client *genai.Client
-}
-
-func NewGeminiProVisionAdapter(client *genai.Client) GenaiModelAdapter {
-	return &GeminiProVisionAdapter{
-		client: client,
-	}
-}
-
-func (g *GeminiProVisionAdapter) GenerateContent(
-	ctx context.Context,
-	req *ChatCompletionRequest,
-) (*openai.ChatCompletionResponse, error) {
-	model := g.client.GenerativeModel(Gemini1ProVision)
-	setGenaiModelByOpenaiRequest(model, req)
-
-	// NOTE: use last message as prompt, gemini pro vision does not support context
-	// https://ai.google.dev/tutorials/go_quickstart#multi-turn-conversations-chat
-	prompt, err := g.openaiMessageToGenaiPrompt(req.Messages[len(req.Messages)-1])
-	if err != nil {
-		return nil, errors.Wrap(err, "genai generate prompt error")
-	}
-
-	genaiResp, err := model.GenerateContent(ctx, prompt...)
-	if err != nil {
-		return nil, errors.Wrap(err, "genai send message error")
-	}
-
-	openaiResp := genaiResponseToOpenaiResponse(genaiResp, Gemini1ProVision)
-	return &openaiResp, nil
-}
-
-func (*GeminiProVisionAdapter) openaiMessageToGenaiPrompt(msg ChatCompletionMessage) ([]genai.Part, error) {
-	parts, err := msg.MultiContent()
-	if err != nil {
-		return nil, err
-	}
-
-	prompt := make([]genai.Part, 0, len(parts))
-	for _, part := range parts {
-		switch part.Type {
-		case openai.ChatMessagePartTypeText:
-			prompt = append(prompt, genai.Text(part.Text))
-		case openai.ChatMessagePartTypeImageURL:
-			data, format, err := parseImageURL(part.ImageURL.URL)
-			if err != nil {
-				return nil, errors.Wrap(err, "parse image url error")
-			}
-
-			prompt = append(prompt, genai.ImageData(format, data))
-		}
-	}
-	return prompt, nil
-}
-
-func (g *GeminiProVisionAdapter) GenerateStreamContent(
-	ctx context.Context,
-	req *ChatCompletionRequest,
-) (<-chan string, error) {
-	model := g.client.GenerativeModel(Gemini1ProVision)
-	setGenaiModelByOpenaiRequest(model, req)
-
-	// NOTE: use last message as prompt, gemini pro vision does not support context
-	// https://ai.google.dev/tutorials/go_quickstart#multi-turn-conversations-chat
-	prompt, err := g.openaiMessageToGenaiPrompt(req.Messages[len(req.Messages)-1])
-	if err != nil {
-		return nil, errors.Wrap(err, "genai generate prompt error")
-	}
-
-	iter := model.GenerateContentStream(ctx, prompt...)
-
-	dataChan := make(chan string)
-	go handleStreamIter(iter, dataChan, Gemini1ProVision)
-
-	return dataChan, nil
-}
-
-func handleStreamIter(iter *genai.GenerateContentResponseIterator, dataChan chan string, model string) {
+func handleStreamIter(model string, iter *genai.GenerateContentResponseIterator, dataChan chan string) {
 	defer close(dataChan)
 
 	respID := util.GetUUID()
@@ -184,7 +101,7 @@ func handleStreamIter(iter *genai.GenerateContentResponseIterator, dataChan chan string, model string) {
 		break
 	}
 
-	openaiResp := genaiResponseToStreamCompletionResponse(genaiResp, respID, created, model)
+	openaiResp := genaiResponseToStreamCompletionResponse(model, genaiResp, respID, created)
 	resp, _ := json.Marshal(openaiResp)
 	dataChan <- string(resp)
 
@@ -195,10 +112,10 @@ func handleStreamIter(iter *genai.GenerateContentResponseIterator, dataChan chan string, model string) {
 }
 
 func genaiResponseToStreamCompletionResponse(
+	model string,
 	genaiResp *genai.GenerateContentResponse,
 	respID string,
 	created int64,
-	model string,
 ) *CompletionResponse {
 	resp := CompletionResponse{
 		ID: fmt.Sprintf("chatcmpl-%s", respID),
@@ -233,7 +150,7 @@ func genaiResponseToStreamCompletionResponse(
 }
 
 func genaiResponseToOpenaiResponse(
-	genaiResp *genai.GenerateContentResponse, model string,
+	model string, genaiResp *genai.GenerateContentResponse,
 ) openai.ChatCompletionResponse {
 	resp := openai.ChatCompletionResponse{
 		ID: fmt.Sprintf("chatcmpl-%s", util.GetUUID()),
@@ -275,42 +192,10 @@ func convertFinishReason(reason genai.FinishReason) openai.FinishReason {
 	return openaiFinishReason
 }
 
-func setGenaiChatByOpenaiRequest(cs *genai.ChatSession, req *ChatCompletionRequest) {
-	cs.History = make([]*genai.Content, 0, len(req.Messages))
-	if len(req.Messages) > 1 {
-		for _, message := range req.Messages[:len(req.Messages)-1] {
-			switch message.Role {
-			case openai.ChatMessageRoleSystem:
-				cs.History = append(cs.History, []*genai.Content{
-					{
-						Parts: []genai.Part{
-							genai.Text(message.StringContent()),
-						},
-						Role: genaiRoleUser,
-					},
-					{
-						Parts: []genai.Part{
-							genai.Text(""),
-						},
-						Role: genaiRoleModel,
-					},
-				}...)
-			case openai.ChatMessageRoleAssistant:
-				cs.History = append(cs.History, &genai.Content{
-					Parts: []genai.Part{
-						genai.Text(message.StringContent()),
-					},
-					Role: genaiRoleModel,
-				})
-			case openai.ChatMessageRoleUser:
-				cs.History = append(cs.History, &genai.Content{
-					Parts: []genai.Part{
-						genai.Text(message.StringContent()),
-					},
-					Role: genaiRoleUser,
-				})
-			}
-		}
+func setGenaiChatHistory(cs *genai.ChatSession, content []*genai.Content) {
+	cs.History = make([]*genai.Content, 0, len(content))
+	if len(content) > 1 {
+		cs.History = content[:len(content)-1]
 	}
 
 	if len(cs.History) != 0 && cs.History[len(cs.History)-1].Role != genaiRoleModel {
```
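The simplified `setGenaiChatHistory` encodes the contract this refactor leans on: the converted slice holds the whole conversation in order, everything before the final element becomes chat history, and the final element's `Parts` are sent as the new message (as the `SendMessage(ctx, content[len(content)-1].Parts...)` calls above show). A small standalone illustration of that split; the helper is hypothetical, written only to make the convention explicit:

```go
package adapter

import "github.com/google/generative-ai-go/genai"

// splitHistoryAndPrompt (hypothetical helper) shows the convention the
// refactored adapter relies on: history is every turn except the last,
// and the last turn's Parts become the outgoing prompt. It assumes a
// non-empty slice, as the adapter's call sites do.
func splitHistoryAndPrompt(content []*genai.Content) ([]*genai.Content, []genai.Part) {
	var history []*genai.Content
	if len(content) > 1 {
		history = content[:len(content)-1]
	}
	return history, content[len(content)-1].Parts
}
```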
