Skip to content

Commit fddc1e2

Browse files
authored
Add Config to Disable Model Mapping (#41)
Introduce a new environment variable, DISABLE_MODEL_MAPPING, which, when enabled, removes the OpenAI -> Gemini model mapping and provides access to the named Gemini models directly. Moved model mapping logic to its own `models.go` file inside the adapter package. Additionally, fixed a bug where responses would return the Gemini model name even though model mapping was enabled.
1 parent da78e75 commit fddc1e2

File tree

5 files changed

+164
-52
lines changed

5 files changed

+164
-52
lines changed

README.md

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,25 @@ go build -o gemini main.go
3030

3131
We recommend deploying Gemini-OpenAI-Proxy using Docker for a straightforward setup. Follow these steps to deploy with Docker:
3232

33-
```bash
34-
docker run --restart=always -it -d -p 8080:8080 --name gemini zhu327/gemini-openai-proxy:latest
35-
```
33+
You can either do this on the command line:
34+
```bash
35+
docker run --restart=unless-stopped -it -d -p 8080:8080 --name gemini zhu327/gemini-openai-proxy:latest
36+
```
37+
38+
Or with the following docker-compose config:
39+
```yaml
40+
version: '3'
41+
services:
42+
gemini:
43+
container_name: gemini
44+
environment: # Set Environment Variables here. Defaults listed below
45+
- GPT_4_VISION_PREVIEW=gemini-1.5-flash-latest
46+
- DISABLE_MODEL_MAPPING=0
47+
ports:
48+
- "8080:8080"
49+
image: zhu327/gemini-openai-proxy:latest
50+
restart: unless-stopped
51+
```
3652
3753
Adjust the port mapping (e.g., `-p 8080:8080`) as needed, and ensure that the Docker image version (`zhu327/gemini-openai-proxy:latest`) aligns with your requirements.
3854

@@ -83,6 +99,7 @@ Gemini-OpenAI-Proxy offers a straightforward way to integrate OpenAI functionali
8399
"temperature": 0.7
84100
}'
85101
```
102+
If you wish to map `gpt-4-vision-preview` to `gemini-1.5-pro-latest`, you can configure the environment variable `GPT_4_VISION_PREVIEW=gemini-1.5-pro-latest`. This is because `gemini-1.5-pro-latest` now also supports multi-modal data. Otherwise, the default uses the `gemini-1.5-flash-latest` model.
86103

87104
If you already have access to the Gemini 1.5 Pro api, you can use:
88105

@@ -104,7 +121,7 @@ Gemini-OpenAI-Proxy offers a straightforward way to integrate OpenAI functionali
104121
-H "Content-Type: application/json" \
105122
-H "Authorization: Bearer $YOUR_GOOGLE_AI_STUDIO_API_KEY" \
106123
-d '{
107-
"model": "ada-002",
124+
"model": "text-embedding-ada-002",
108125
"input": "This is a test sentence."
109126
}'
110127
```
@@ -116,7 +133,7 @@ Gemini-OpenAI-Proxy offers a straightforward way to integrate OpenAI functionali
116133
-H "Content-Type: application/json" \
117134
-H "Authorization: Bearer $YOUR_GOOGLE_AI_STUDIO_API_KEY" \
118135
-d '{
119-
"model": "ada-002",
136+
"model": "text-embedding-ada-002",
120137
"input": ["This is a test sentence.", "This is another test sentence"]
121138
}'
122139
```
@@ -129,9 +146,21 @@ Gemini-OpenAI-Proxy offers a straightforward way to integrate OpenAI functionali
129146
| gpt-4 | gemini-1.5-flash-latest |
130147
| gpt-4-turbo-preview | gemini-1.5-pro-latest |
131148
| gpt-4-vision-preview | gemini-1.0-pro-vision-latest |
132-
| ada-002 | text-embedding-004 |
149+
| text-embedding-ada-002 | text-embedding-004 |
133150

134-
If you wish to map `gpt-4-vision-preview` to `gemini-1.5-pro-latest`, you can configure the environment variable `GPT_4_VISION_PREVIEW = gemini-1.5-pro-latest`. This is because `gemini-1.5-pro-latest` now also supports multi-modal data.
151+
If you want to disable model mapping, configure the environment variable `DISABLE_MODEL_MAPPING=1`. This will allow you to refer to the Gemini models directly.
152+
153+
Here is an example API request with model mapping disabled:
154+
```bash
155+
curl http://localhost:8080/v1/chat/completions \
156+
-H "Content-Type: application/json" \
157+
-H "Authorization: Bearer $YOUR_GOOGLE_AI_STUDIO_API_KEY" \
158+
-d '{
159+
"model": "gemini-1.0-pro-latest",
160+
"messages": [{"role": "user", "content": "Say this is a test!"}],
161+
"temperature": 0.7
162+
}'
163+
```
135164

136165
4. **Handle Responses:**
137166
Process the responses from the Gemini-OpenAI-Proxy in the same way you would handle responses from OpenAI.

api/handler.go

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,50 +21,52 @@ func IndexHandler(c *gin.Context) {
2121
}
2222

2323
func ModelListHandler(c *gin.Context) {
24+
owner := adapter.GetOwner()
2425
c.JSON(http.StatusOK, gin.H{
2526
"object": "list",
2627
"data": []any{
2728
openai.Model{
2829
CreatedAt: 1686935002,
29-
ID: openai.GPT3Dot5Turbo,
30+
ID: adapter.GetModel(openai.GPT3Dot5Turbo),
3031
Object: "model",
31-
OwnedBy: "openai",
32+
OwnedBy: owner,
3233
},
3334
openai.Model{
3435
CreatedAt: 1686935002,
35-
ID: openai.GPT4,
36+
ID: adapter.GetModel(openai.GPT4),
3637
Object: "model",
37-
OwnedBy: "openai",
38+
OwnedBy: owner,
3839
},
3940
openai.Model{
4041
CreatedAt: 1686935002,
41-
ID: openai.GPT4TurboPreview,
42+
ID: adapter.GetModel(openai.GPT4TurboPreview),
4243
Object: "model",
43-
OwnedBy: "openai",
44+
OwnedBy: owner,
4445
},
4546
openai.Model{
4647
CreatedAt: 1686935002,
47-
ID: openai.GPT4VisionPreview,
48+
ID: adapter.GetModel(openai.GPT4VisionPreview),
4849
Object: "model",
49-
OwnedBy: "openai",
50+
OwnedBy: owner,
5051
},
5152
openai.Model{
5253
CreatedAt: 1686935002,
53-
ID: openai.GPT3Ada002,
54+
ID: adapter.GetModel(string(openai.AdaEmbeddingV2)),
5455
Object: "model",
55-
OwnedBy: "openai",
56+
OwnedBy: owner,
5657
},
5758
},
5859
})
5960
}
6061

6162
func ModelRetrieveHandler(c *gin.Context) {
6263
model := c.Param("model")
64+
owner := adapter.GetOwner()
6365
c.JSON(http.StatusOK, openai.Model{
6466
CreatedAt: 1686935002,
6567
ID: model,
6668
Object: "model",
67-
OwnedBy: "openai",
69+
OwnedBy: owner,
6870
})
6971
}
7072

pkg/adapter/chat.go

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,6 @@ import (
1717
)
1818

1919
const (
20-
Gemini1Pro = "gemini-1.0-pro-latest"
21-
Gemini1Dot5Pro = "gemini-1.5-pro-latest"
22-
Gemini1Dot5Flash = "gemini-1.5-flash-latest"
23-
TextEmbedding004 = "text-embedding-004"
24-
2520
genaiRoleUser = "user"
2621
genaiRoleModel = "model"
2722
)
@@ -121,7 +116,7 @@ func genaiResponseToStreamCompletionResponse(
121116
ID: fmt.Sprintf("chatcmpl-%s", respID),
122117
Object: "chat.completion.chunk",
123118
Created: created,
124-
Model: model,
119+
Model: GetMappedModel(model),
125120
Choices: make([]CompletionChoice, 0, len(genaiResp.Candidates)),
126121
}
127122

@@ -156,7 +151,7 @@ func genaiResponseToOpenaiResponse(
156151
ID: fmt.Sprintf("chatcmpl-%s", util.GetUUID()),
157152
Object: "chat.completion",
158153
Created: time.Now().Unix(),
159-
Model: model,
154+
Model: GetMappedModel(model),
160155
Choices: make([]openai.ChatCompletionChoice, 0, len(genaiResp.Candidates)),
161156
}
162157

@@ -260,7 +255,7 @@ func (g *GeminiAdapter) GenerateEmbedding(
260255
openaiResp := openai.EmbeddingResponse{
261256
Object: "list",
262257
Data: make([]openai.Embedding, 0, len(genaiResp.Embeddings)),
263-
Model: openai.EmbeddingModel(g.model),
258+
Model: openai.EmbeddingModel(GetMappedModel(g.model)),
264259
}
265260

266261
for i, genaiEmbedding := range genaiResp.Embeddings {

pkg/adapter/models.go

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
package adapter
2+
3+
import (
4+
"os"
5+
"strings"
6+
7+
openai "github.com/sashabaranov/go-openai"
8+
)
9+
10+
const (
	// Gemini model identifiers accepted by the Google Generative AI API.
	Gemini1Pro       = "gemini-1.0-pro-latest"
	Gemini1Dot5Pro   = "gemini-1.5-pro-latest"
	Gemini1Dot5Flash = "gemini-1.5-flash-latest"
	// NOTE(review): despite the "1Dot5" in its name, this constant holds the
	// 1.0 vision model id. It is rewritten to one of the models above in
	// ChatCompletionRequest.ToGenaiModel; renaming it would break callers
	// elsewhere in the package, so the misnomer is only flagged here.
	Gemini1Dot5ProV  = "gemini-1.0-pro-vision-latest"
	TextEmbedding004 = "text-embedding-004"
)

// USE_MODEL_MAPPING reports whether OpenAI -> Gemini model-name mapping is
// active. It is true unless the DISABLE_MODEL_MAPPING environment variable is
// set to "1"; the value is captured once at package initialization.
var USE_MODEL_MAPPING bool = os.Getenv("DISABLE_MODEL_MAPPING") != "1"
19+
20+
func GetOwner() string {
21+
if USE_MODEL_MAPPING {
22+
return "openai"
23+
} else {
24+
return "google"
25+
}
26+
}
27+
28+
func GetModel(openAiModelName string) string {
29+
if USE_MODEL_MAPPING {
30+
return openAiModelName
31+
} else {
32+
return ConvertModel(openAiModelName)
33+
}
34+
}
35+
36+
func GetMappedModel(geminiModelName string) string {
37+
if !USE_MODEL_MAPPING {
38+
return geminiModelName
39+
}
40+
switch {
41+
case geminiModelName == Gemini1Dot5ProV:
42+
return openai.GPT4VisionPreview
43+
case geminiModelName == Gemini1Dot5Pro:
44+
return openai.GPT4TurboPreview
45+
case geminiModelName == Gemini1Dot5Flash:
46+
return openai.GPT4
47+
case geminiModelName == TextEmbedding004:
48+
return string(openai.AdaEmbeddingV2)
49+
default:
50+
return openai.GPT3Dot5Turbo
51+
}
52+
}
53+
54+
func ConvertModel(openAiModelName string) string {
55+
switch {
56+
case openAiModelName == openai.GPT4VisionPreview:
57+
return Gemini1Dot5ProV
58+
case openAiModelName == openai.GPT4TurboPreview || openAiModelName == openai.GPT4Turbo1106 || openAiModelName == openai.GPT4Turbo0125:
59+
return Gemini1Dot5Pro
60+
case strings.HasPrefix(openAiModelName, openai.GPT4):
61+
return Gemini1Dot5Flash
62+
case openAiModelName == string(openai.AdaEmbeddingV2):
63+
return TextEmbedding004
64+
default:
65+
return Gemini1Pro
66+
}
67+
}
68+
69+
func (req *ChatCompletionRequest) ToGenaiModel() string {
70+
if USE_MODEL_MAPPING {
71+
return req.ParseModelWithMapping()
72+
} else {
73+
return req.ParseModelWithoutMapping()
74+
}
75+
}
76+
77+
func (req *ChatCompletionRequest) ParseModelWithoutMapping() string {
78+
switch {
79+
case req.Model == Gemini1Dot5ProV:
80+
if os.Getenv("GPT_4_VISION_PREVIEW") == Gemini1Dot5Pro {
81+
return Gemini1Dot5Pro
82+
}
83+
84+
return Gemini1Dot5Flash
85+
default:
86+
return req.Model
87+
}
88+
}
89+
90+
func (req *ChatCompletionRequest) ParseModelWithMapping() string {
91+
switch {
92+
case req.Model == openai.GPT4VisionPreview:
93+
if os.Getenv("GPT_4_VISION_PREVIEW") == Gemini1Dot5Pro {
94+
return Gemini1Dot5Pro
95+
}
96+
97+
return Gemini1Dot5Flash
98+
default:
99+
return ConvertModel(req.Model)
100+
}
101+
}
102+
103+
func (req *EmbeddingRequest) ToGenaiModel() string {
104+
if USE_MODEL_MAPPING {
105+
return ConvertModel(req.Model)
106+
} else {
107+
return req.Model
108+
}
109+
}

pkg/adapter/struct.go

Lines changed: 3 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@ package adapter
22

33
import (
44
"encoding/json"
5-
"os"
6-
"strings"
75

86
"github.com/google/generative-ai-go/genai"
97
"github.com/pkg/errors"
@@ -43,27 +41,10 @@ type ChatCompletionRequest struct {
4341
Stop []string `json:"stop,omitempty"`
4442
}
4543

46-
func (req *ChatCompletionRequest) ToGenaiModel() string {
47-
switch {
48-
case req.Model == openai.GPT4VisionPreview:
49-
if os.Getenv("GPT_4_VISION_PREVIEW") == Gemini1Dot5Pro {
50-
return Gemini1Dot5Pro
51-
}
52-
53-
return Gemini1Dot5Flash
54-
case req.Model == openai.GPT4TurboPreview || req.Model == openai.GPT4Turbo1106 || req.Model == openai.GPT4Turbo0125:
55-
return Gemini1Dot5Pro
56-
case strings.HasPrefix(req.Model, openai.GPT4):
57-
return Gemini1Dot5Flash
58-
default:
59-
return Gemini1Pro
60-
}
61-
}
62-
6344
func (req *ChatCompletionRequest) ToGenaiMessages() ([]*genai.Content, error) {
64-
if req.Model == openai.GPT4VisionPreview {
45+
if req.Model == Gemini1Dot5ProV || req.Model == openai.GPT4VisionPreview {
6546
return req.toVisionGenaiContent()
66-
} else if req.Model == openai.GPT3Ada002 {
47+
} else if req.Model == TextEmbedding004 || req.Model == string(openai.AdaEmbeddingV2) {
6748
return nil, errors.New("Chat Completion is not supported for embedding model")
6849
}
6950

@@ -209,7 +190,7 @@ type EmbeddingRequest struct {
209190
}
210191

211192
func (req *EmbeddingRequest) ToGenaiMessages() ([]*genai.Content, error) {
212-
if req.Model != openai.GPT3Ada002 {
193+
if req.Model != TextEmbedding004 && req.Model != string(openai.AdaEmbeddingV2) {
213194
return nil, errors.New("Embedding is not supported for chat model " + req.Model)
214195
}
215196

@@ -225,7 +206,3 @@ func (req *EmbeddingRequest) ToGenaiMessages() ([]*genai.Content, error) {
225206

226207
return content, nil
227208
}
228-
229-
func (req *EmbeddingRequest) ToGenaiModel() string {
230-
return TextEmbedding004
231-
}

0 commit comments

Comments
 (0)