Skip to content

Commit bd36c45

Browse files
Support for extra_body parameter for embeddings API (#906)
* support for extra_body parameter for embeddings API * done linting * added unit tests * improved code coverage and removed unnecessary checks * test cleanup * updated body map creation code * code coverage * minor change * updated testcase comment
1 parent 3bb1014 commit bd36c45

File tree

3 files changed

+90
-2
lines changed

3 files changed

+90
-2
lines changed

client.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,20 @@ func withBody(body any) requestOption {
8484
}
8585
}
8686

87+
func withExtraBody(extraBody map[string]any) requestOption {
88+
return func(args *requestOptions) {
89+
// Assert that args.body is a map[string]any.
90+
bodyMap, ok := args.body.(map[string]any)
91+
if ok {
92+
// If it's a map[string]any then only add extraBody
93+
// fields to args.body otherwise keep only fields in request struct.
94+
for key, value := range extraBody {
95+
bodyMap[key] = value
96+
}
97+
}
98+
}
99+
}
100+
87101
func withContentType(contentType string) requestOption {
88102
return func(args *requestOptions) {
89103
args.header.Set("Content-Type", contentType)

embeddings.go

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"encoding/base64"
66
"encoding/binary"
7+
"encoding/json"
78
"errors"
89
"math"
910
"net/http"
@@ -160,6 +161,9 @@ type EmbeddingRequest struct {
160161
// Dimensions The number of dimensions the resulting output embeddings should have.
161162
// Only supported in text-embedding-3 and later models.
162163
Dimensions int `json:"dimensions,omitempty"`
164+
// The ExtraBody field allows for the inclusion of arbitrary key-value pairs
165+
// in the request body that may not be explicitly defined in this struct.
166+
ExtraBody map[string]any `json:"extra_body,omitempty"`
163167
}
164168

165169
func (r EmbeddingRequest) Convert() EmbeddingRequest {
@@ -187,6 +191,9 @@ type EmbeddingRequestStrings struct {
187191
// Dimensions The number of dimensions the resulting output embeddings should have.
188192
// Only supported in text-embedding-3 and later models.
189193
Dimensions int `json:"dimensions,omitempty"`
194+
// The ExtraBody field allows for the inclusion of arbitrary key-value pairs
195+
// in the request body that may not be explicitly defined in this struct.
196+
ExtraBody map[string]any `json:"extra_body,omitempty"`
190197
}
191198

192199
func (r EmbeddingRequestStrings) Convert() EmbeddingRequest {
@@ -196,6 +203,7 @@ func (r EmbeddingRequestStrings) Convert() EmbeddingRequest {
196203
User: r.User,
197204
EncodingFormat: r.EncodingFormat,
198205
Dimensions: r.Dimensions,
206+
ExtraBody: r.ExtraBody,
199207
}
200208
}
201209

@@ -219,6 +227,9 @@ type EmbeddingRequestTokens struct {
219227
// Dimensions The number of dimensions the resulting output embeddings should have.
220228
// Only supported in text-embedding-3 and later models.
221229
Dimensions int `json:"dimensions,omitempty"`
230+
// The ExtraBody field allows for the inclusion of arbitrary key-value pairs
231+
// in the request body that may not be explicitly defined in this struct.
232+
ExtraBody map[string]any `json:"extra_body,omitempty"`
222233
}
223234

224235
func (r EmbeddingRequestTokens) Convert() EmbeddingRequest {
@@ -228,6 +239,7 @@ func (r EmbeddingRequestTokens) Convert() EmbeddingRequest {
228239
User: r.User,
229240
EncodingFormat: r.EncodingFormat,
230241
Dimensions: r.Dimensions,
242+
ExtraBody: r.ExtraBody,
231243
}
232244
}
233245

@@ -241,11 +253,29 @@ func (c *Client) CreateEmbeddings(
241253
conv EmbeddingRequestConverter,
242254
) (res EmbeddingResponse, err error) {
243255
baseReq := conv.Convert()
256+
257+
// The body map is used to dynamically construct the request payload for the embedding API.
258+
// Instead of relying on a fixed struct, the body map allows for flexible inclusion of fields
259+
// based on their presence, avoiding unnecessary or empty fields in the request.
260+
extraBody := baseReq.ExtraBody
261+
baseReq.ExtraBody = nil
262+
263+
// Serialize baseReq to JSON
264+
jsonData, err := json.Marshal(baseReq)
265+
if err != nil {
266+
return
267+
}
268+
269+
// Deserialize JSON to map[string]any
270+
var body map[string]any
271+
_ = json.Unmarshal(jsonData, &body)
272+
244273
req, err := c.newRequest(
245274
ctx,
246275
http.MethodPost,
247276
c.fullURL("/embeddings", withModel(string(baseReq.Model))),
248-
withBody(baseReq),
277+
withBody(body), // Main request body.
278+
withExtraBody(extraBody), // Merge ExtraBody fields.
249279
)
250280
if err != nil {
251281
return

embeddings_test.go

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,24 @@ func TestEmbedding(t *testing.T) {
5151
t.Fatalf("Expected embedding request to contain model field")
5252
}
5353

54+
// test embedding request with strings and extra_body param
55+
embeddingReqWithExtraBody := openai.EmbeddingRequest{
56+
Input: []string{
57+
"The food was delicious and the waiter",
58+
"Other examples of embedding request",
59+
},
60+
Model: model,
61+
ExtraBody: map[string]any{
62+
"input_type": "query",
63+
"truncate": "NONE",
64+
},
65+
}
66+
marshaled, err = json.Marshal(embeddingReqWithExtraBody)
67+
checks.NoError(t, err, "Could not marshal embedding request")
68+
if !bytes.Contains(marshaled, []byte(`"model":"`+model+`"`)) {
69+
t.Fatalf("Expected embedding request to contain model field")
70+
}
71+
5472
// test embedding request with strings
5573
embeddingReqStrings := openai.EmbeddingRequestStrings{
5674
Input: []string{
@@ -124,7 +142,33 @@ func TestEmbeddingEndpoint(t *testing.T) {
124142
t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data)
125143
}
126144

127-
// test create embeddings with strings (simple embedding request)
145+
// test create embeddings with strings (ExtraBody in request)
146+
res, err = client.CreateEmbeddings(
147+
context.Background(),
148+
openai.EmbeddingRequest{
149+
ExtraBody: map[string]any{
150+
"input_type": "query",
151+
"truncate": "NONE",
152+
},
153+
Dimensions: 1,
154+
},
155+
)
156+
checks.NoError(t, err, "CreateEmbeddings error")
157+
if !reflect.DeepEqual(res.Data, sampleEmbeddings) {
158+
t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data)
159+
}
160+
161+
// test create embeddings with strings (ExtraBody in request and )
162+
_, err = client.CreateEmbeddings(
163+
context.Background(),
164+
openai.EmbeddingRequest{
165+
Input: make(chan int), // Channels are not serializable
166+
Model: "example_model",
167+
},
168+
)
169+
checks.HasError(t, err, "CreateEmbeddings error")
170+
171+
// test failed (Serialize JSON error)
128172
res, err = client.CreateEmbeddings(
129173
context.Background(),
130174
openai.EmbeddingRequest{

0 commit comments

Comments
 (0)