Skip to content

Commit da78e75

Browse files
authored
Add Support for Embeddings Endpoint (#39)
* Add Support for Embeddings Endpoint * Update README with Embeddings Usage
1 parent 9d1fbf6 commit da78e75

File tree

5 files changed

+186
-2
lines changed

5 files changed

+186
-2
lines changed

README.md

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Gemini-OpenAI-Proxy
22

3-
Gemini-OpenAI-Proxy is a proxy designed to convert the OpenAI API protocol to the Google Gemini Pro protocol. This enables seamless integration of OpenAI-powered functionalities into applications using the Gemini Pro protocol.
3+
Gemini-OpenAI-Proxy is a proxy designed to convert the OpenAI API protocol to the Google Gemini protocol. This enables applications built for the OpenAI API to seamlessly communicate with the Gemini protocol, including support for Chat Completion, Embeddings, and Model(s) endpoints.
44

55
---
66

@@ -51,7 +51,7 @@ Gemini-OpenAI-Proxy offers a straightforward way to integrate OpenAI functionali
5151
3. **Integrate the Proxy into Your Application:**
5252
Modify your application's API requests to target the Gemini-OpenAI-Proxy, providing the acquired Google AI Studio API key as if it were your OpenAI API key.
5353

54-
Example API Request (Assuming the proxy is hosted at `http://localhost:8080`):
54+
Example Chat Completion API Request (Assuming the proxy is hosted at `http://localhost:8080`):
5555
```bash
5656
curl http://localhost:8080/v1/chat/completions \
5757
-H "Content-Type: application/json" \
@@ -97,6 +97,30 @@ Gemini-OpenAI-Proxy offers a straightforward way to integrate OpenAI functionali
9797
}'
9898
```
9999

100+
Example Embeddings API Request:
101+
102+
```bash
103+
curl http://localhost:8080/v1/embeddings \
104+
-H "Content-Type: application/json" \
105+
-H "Authorization: Bearer $YOUR_GOOGLE_AI_STUDIO_API_KEY" \
106+
-d '{
107+
"model": "ada-002",
108+
"input": "This is a test sentence."
109+
}'
110+
```
111+
112+
You can also pass in multiple input strings as a list:
113+
114+
```bash
115+
curl http://localhost:8080/v1/embeddings \
116+
-H "Content-Type: application/json" \
117+
-H "Authorization: Bearer $YOUR_GOOGLE_AI_STUDIO_API_KEY" \
118+
-d '{
119+
"model": "ada-002",
120+
    "input": ["This is a test sentence.", "This is another test sentence."]
121+
}'
122+
```
123+
100124
Model Mapping:
101125

102126
| GPT Model | Gemini Model |
@@ -105,6 +129,7 @@ Gemini-OpenAI-Proxy offers a straightforward way to integrate OpenAI functionali
105129
| gpt-4 | gemini-1.5-flash-latest |
106130
| gpt-4-turbo-preview | gemini-1.5-pro-latest |
107131
| gpt-4-vision-preview | gemini-1.0-pro-vision-latest |
132+
| ada-002 | text-embedding-004 |
108133

109134
If you wish to map `gpt-4-vision-preview` to `gemini-1.5-pro-latest`, you can configure the environment variable `GPT_4_VISION_PREVIEW = gemini-1.5-pro-latest`. This is because `gemini-1.5-pro-latest` now also supports multi-modal data.
110135

api/handler.go

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,12 @@ func ModelListHandler(c *gin.Context) {
4848
Object: "model",
4949
OwnedBy: "openai",
5050
},
51+
openai.Model{
52+
CreatedAt: 1686935002,
53+
ID: openai.GPT3Ada002,
54+
Object: "model",
55+
OwnedBy: "openai",
56+
},
5157
},
5258
})
5359
}
@@ -154,3 +160,65 @@ func setEventStreamHeaders(c *gin.Context) {
154160
c.Writer.Header().Set("Transfer-Encoding", "chunked")
155161
c.Writer.Header().Set("X-Accel-Buffering", "no")
156162
}
163+
164+
func EmbeddingProxyHandler(c *gin.Context) {
165+
// Retrieve the Authorization header value
166+
authorizationHeader := c.GetHeader("Authorization")
167+
// Declare a variable to store the OPENAI_API_KEY
168+
var openaiAPIKey string
169+
// Use fmt.Sscanf to extract the Bearer token
170+
_, err := fmt.Sscanf(authorizationHeader, "Bearer %s", &openaiAPIKey)
171+
if err != nil {
172+
c.JSON(http.StatusBadRequest, openai.APIError{
173+
Code: http.StatusBadRequest,
174+
Message: err.Error(),
175+
})
176+
return
177+
}
178+
179+
req := &adapter.EmbeddingRequest{}
180+
// Bind the JSON data from the request to the struct
181+
if err := c.ShouldBindJSON(req); err != nil {
182+
c.JSON(http.StatusBadRequest, openai.APIError{
183+
Code: http.StatusBadRequest,
184+
Message: err.Error(),
185+
})
186+
return
187+
}
188+
189+
messages, err := req.ToGenaiMessages()
190+
if err != nil {
191+
c.JSON(http.StatusBadRequest, openai.APIError{
192+
Code: http.StatusBadRequest,
193+
Message: err.Error(),
194+
})
195+
return
196+
}
197+
198+
ctx := c.Request.Context()
199+
client, err := genai.NewClient(ctx, option.WithAPIKey(openaiAPIKey))
200+
if err != nil {
201+
log.Printf("new genai client error %v\n", err)
202+
c.JSON(http.StatusBadRequest, openai.APIError{
203+
Code: http.StatusBadRequest,
204+
Message: err.Error(),
205+
})
206+
return
207+
}
208+
defer client.Close()
209+
210+
model := req.ToGenaiModel()
211+
gemini := adapter.NewGeminiAdapter(client, model)
212+
213+
resp, err := gemini.GenerateEmbedding(ctx, messages)
214+
if err != nil {
215+
log.Printf("genai generate content error %v\n", err)
216+
c.JSON(http.StatusBadRequest, openai.APIError{
217+
Code: http.StatusBadRequest,
218+
Message: err.Error(),
219+
})
220+
return
221+
}
222+
223+
c.JSON(http.StatusOK, resp)
224+
}

api/router.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,7 @@ func Register(router *gin.Engine) {
2424

2525
// openai chat
2626
router.POST("/v1/chat/completions", ChatProxyHandler)
27+
28+
// openai embeddings
29+
router.POST("/v1/embeddings", EmbeddingProxyHandler)
2730
}

pkg/adapter/chat.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ const (
2020
Gemini1Pro = "gemini-1.0-pro-latest"
2121
Gemini1Dot5Pro = "gemini-1.5-pro-latest"
2222
Gemini1Dot5Flash = "gemini-1.5-flash-latest"
23+
TextEmbedding004 = "text-embedding-004"
2324

2425
genaiRoleUser = "user"
2526
genaiRoleModel = "model"
@@ -239,3 +240,37 @@ func setGenaiModelByOpenaiRequest(model *genai.GenerativeModel, req *ChatComplet
239240
},
240241
}
241242
}
243+
244+
func (g *GeminiAdapter) GenerateEmbedding(
245+
ctx context.Context,
246+
messages []*genai.Content,
247+
) (*openai.EmbeddingResponse, error) {
248+
model := g.client.EmbeddingModel(g.model)
249+
250+
batchEmbeddings := model.NewBatch()
251+
for _, message := range messages {
252+
batchEmbeddings = batchEmbeddings.AddContent(message.Parts...)
253+
}
254+
255+
genaiResp, err := model.BatchEmbedContents(ctx, batchEmbeddings)
256+
if err != nil {
257+
return nil, errors.Wrap(err, "genai generate embeddings error")
258+
}
259+
260+
openaiResp := openai.EmbeddingResponse{
261+
Object: "list",
262+
Data: make([]openai.Embedding, 0, len(genaiResp.Embeddings)),
263+
Model: openai.EmbeddingModel(g.model),
264+
}
265+
266+
for i, genaiEmbedding := range genaiResp.Embeddings {
267+
embedding := openai.Embedding{
268+
Object: "embedding",
269+
Embedding: genaiEmbedding.Values,
270+
Index: i,
271+
}
272+
openaiResp.Data = append(openaiResp.Data, embedding)
273+
}
274+
275+
return &openaiResp, nil
276+
}

pkg/adapter/struct.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ func (req *ChatCompletionRequest) ToGenaiModel() string {
6363
func (req *ChatCompletionRequest) ToGenaiMessages() ([]*genai.Content, error) {
6464
if req.Model == openai.GPT4VisionPreview {
6565
return req.toVisionGenaiContent()
66+
} else if req.Model == openai.GPT3Ada002 {
67+
return nil, errors.New("Chat Completion is not supported for embedding model")
6668
}
6769

6870
return req.toStringGenaiContent()
@@ -176,3 +178,54 @@ type CompletionResponse struct {
176178
Model string `json:"model"`
177179
Choices []CompletionChoice `json:"choices"`
178180
}
181+
182+
// StringArray is a []string that also accepts a single JSON string,
// matching OpenAI's embeddings "input" field which may be either form.
type StringArray []string

// UnmarshalJSON implements the json.Unmarshaler interface for StringArray.
// A JSON array decodes element-wise; a bare JSON string decodes to a
// one-element slice; null or empty input decodes to an empty slice.
func (s *StringArray) UnmarshalJSON(data []byte) error {
	// Guard before indexing data[0]: empty input would panic, and JSON
	// null previously decoded to a spurious one-element {""} slice.
	if len(data) == 0 || string(data) == "null" {
		*s = nil
		return nil
	}

	// Check if the data is a JSON array.
	if data[0] == '[' {
		var arr []string
		if err := json.Unmarshal(data, &arr); err != nil {
			return err
		}
		*s = arr
		return nil
	}

	// Otherwise the data must be a JSON string.
	var str string
	if err := json.Unmarshal(data, &str); err != nil {
		return err
	}
	*s = StringArray{str} // Wrap the string in a slice
	return nil
}
204+
205+
// EmbeddingRequest represents a request structure for embeddings API.
// The "input" field accepts either a single JSON string or an array of
// strings; StringArray normalizes both forms to a slice.
type EmbeddingRequest struct {
	Model    string      `json:"model" binding:"required"`
	Messages StringArray `json:"input" binding:"required,min=1"`
}
211+
func (req *EmbeddingRequest) ToGenaiMessages() ([]*genai.Content, error) {
212+
if req.Model != openai.GPT3Ada002 {
213+
return nil, errors.New("Embedding is not supported for chat model " + req.Model)
214+
}
215+
216+
content := make([]*genai.Content, 0, len(req.Messages))
217+
for _, message := range req.Messages {
218+
embedString := []genai.Part{
219+
genai.Text(message),
220+
}
221+
content = append(content, &genai.Content{
222+
Parts: embedString,
223+
})
224+
}
225+
226+
return content, nil
227+
}
228+
229+
// ToGenaiModel returns the Gemini embedding model name. Every embedding
// request maps to text-embedding-004 regardless of the OpenAI model the
// client supplied (ToGenaiMessages has already validated it).
func (req *EmbeddingRequest) ToGenaiModel() string {
	return TextEmbedding004
}

0 commit comments

Comments
 (0)