Skip to content

Commit 959c0f7

Browse files
authored
Propagate request context to api handlers (#73)
The request context may contain important values, such as trace IDs inserted by e.g. DataDog or other tracing middleware. Ensure the context is pushed through the entire stack to maintain these values.
1 parent 56b1811 commit 959c0f7

File tree

2 files changed

+91
-7
lines changed

2 files changed

+91
-7
lines changed

api/api_gomux.go

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -227,14 +227,13 @@ func (api *API) requestWrapper(next http.Handler) http.Handler {
227227
// --------------------------------------------------------------
228228
var ctx context.Context
229229
var cancel context.CancelFunc
230-
queryTimeout := r.Header.Get(etre.QUERY_TIMEOUT_HEADER) // explicit
231-
if queryTimeout == "" {
232-
ctx, cancel = context.WithTimeout(context.Background(), api.queryTimeout)
233-
} else {
234-
d, err := time.ParseDuration(queryTimeout)
230+
231+
queryTimeout := api.queryTimeout
232+
if qth := r.Header.Get(etre.QUERY_TIMEOUT_HEADER); qth != "" {
233+
d, err := time.ParseDuration(qth)
235234
if err != nil {
236235
err := etre.Error{
237-
Message: fmt.Sprintf("invalid %s header: %s: %s", etre.QUERY_TIMEOUT_HEADER, queryTimeout, err),
236+
Message: fmt.Sprintf("invalid %s header: %s: %s", etre.QUERY_TIMEOUT_HEADER, qth, err),
238237
Type: "invalid-query-timeout",
239238
HTTPStatus: http.StatusBadRequest,
240239
}
@@ -245,8 +244,10 @@ func (api *API) requestWrapper(next http.Handler) http.Handler {
245244
}
246245
return
247246
}
248-
ctx, cancel = context.WithTimeout(context.Background(), d)
247+
queryTimeout = d
249248
}
249+
ctx, cancel = context.WithTimeout(r.Context(), queryTimeout)
250+
250251
defer cancel() // don't leak
251252
t0 := time.Now()
252253

api/api_test.go

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,11 @@
33
package api_test
44

55
import (
6+
"bytes"
67
"context"
8+
"encoding/json"
79
"fmt"
10+
"io"
811
"net/http"
912
"net/http/httptest"
1013
"net/url"
@@ -230,3 +233,83 @@ func TestClientQueryTimeout(t *testing.T) {
230233
assert.True(t, set, "query timeout deadline not set, expected it to be set")
231234
assert.True(t, -d >= 4.8 && -d <= 5.2, "deadline %f, expected between 4.8-5.2s (5s client)", d)
232235
}
236+
237+
func TestContextPropagation(t *testing.T) {
238+
// Make sure context values from the request are propagated all the way down to the entity.Store context
239+
var gotCtx context.Context
240+
store := mock.EntityStore{}
241+
store.WithContextFunc = func(ctx context.Context) entity.Store {
242+
gotCtx = ctx
243+
return store
244+
}
245+
// We're going to test all operations, so we need to set all of these funcs
246+
store.ReadEntitiesFunc = func(entityType string, q query.Query, f etre.QueryFilter) ([]etre.Entity, error) {
247+
return testEntitiesWithObjectIDs[0:1], nil
248+
}
249+
store.CreateEntitiesFunc = func(op entity.WriteOp, entities []etre.Entity) ([]string, error) {
250+
return []string{testEntityIds[0]}, nil
251+
}
252+
store.UpdateEntitiesFunc = func(op entity.WriteOp, q query.Query, e etre.Entity) ([]etre.Entity, error) {
253+
return testEntitiesWithObjectIDs[0:1], nil
254+
}
255+
store.DeleteEntitiesFunc = func(op entity.WriteOp, q query.Query) ([]etre.Entity, error) {
256+
return testEntitiesWithObjectIDs[0:1], nil
257+
}
258+
store.DeleteLabelFunc = func(op entity.WriteOp, label string) (etre.Entity, error) {
259+
return testEntitiesWithObjectIDs[0], nil
260+
}
261+
262+
server := setup(t, defaultConfig, store)
263+
defer server.ts.Close()
264+
265+
newEntity := etre.Entity{"host": "local"}
266+
payload, err := json.Marshal(newEntity)
267+
require.NoError(t, err)
268+
269+
multiPayload, err := json.Marshal([]etre.Entity{newEntity})
270+
require.NoError(t, err)
271+
272+
tc := []struct {
273+
Name string
274+
Method string
275+
URL string
276+
Payload []byte
277+
}{
278+
// mux.Handle("GET "+etre.API_ROOT+"/entities/{type}", api.requestWrapper(http.HandlerFunc(api.getEntitiesHandler)))
279+
{Name: "getEntitiesHandler", Method: "GET", URL: server.url + etre.API_ROOT + "/entities/" + entityType + "?query=" + url.QueryEscape("foo=bar")},
280+
//mux.Handle("POST "+etre.API_ROOT+"/entities/{type}", api.requestWrapper(http.HandlerFunc(api.postEntitiesHandler)))
281+
{Name: "postEntitiesHandler", Method: "POST", URL: server.url + etre.API_ROOT + "/entities/" + entityType, Payload: multiPayload},
282+
//mux.Handle("PUT "+etre.API_ROOT+"/entities/{type}", api.requestWrapper(http.HandlerFunc(api.putEntitiesHandler)))
283+
{Name: "putEntitiesHandler", Method: "PUT", URL: server.url + etre.API_ROOT + "/entities/" + entityType + "?query=" + url.QueryEscape("foo=bar"), Payload: payload},
284+
//mux.Handle("DELETE "+etre.API_ROOT+"/entities/{type}", api.requestWrapper(http.HandlerFunc(api.deleteEntitiesHandler)))
285+
{Name: "deleteEntitiesHandler", Method: "DELETE", URL: server.url + etre.API_ROOT + "/entities/" + entityType + "?query=" + url.QueryEscape("foo=bar")},
286+
//mux.Handle("POST "+etre.API_ROOT+"/entity/{type}", api.requestWrapper(http.HandlerFunc(api.postEntityHandler)))
287+
{Name: "postEntityHandler", Method: "POST", URL: server.url + etre.API_ROOT + "/entity/" + entityType, Payload: payload},
288+
//mux.Handle("GET "+etre.API_ROOT+"/entity/{type}/{id}", api.requestWrapper(api.id(http.HandlerFunc(api.getEntityHandler))))
289+
{Name: "getEntityHandler", Method: "GET", URL: server.url + etre.API_ROOT + "/entity/" + entityType + "/" + testEntityIds[0]},
290+
//mux.Handle("PUT "+etre.API_ROOT+"/entity/{type}/{id}", api.requestWrapper(api.id(http.HandlerFunc(api.putEntityHandler))))
291+
{Name: "putEntityHandler", Method: "PUT", URL: server.url + etre.API_ROOT + "/entity/" + entityType + "/" + testEntityIds[0], Payload: payload},
292+
//mux.Handle("GET "+etre.API_ROOT+"/entity/{type}/{id}/labels", api.requestWrapper(api.id(http.HandlerFunc(api.getLabelsHandler))))
293+
{Name: "getLabelsHandler", Method: "GET", URL: server.url + etre.API_ROOT + "/entity/" + entityType + "/" + testEntityIds[0] + "/labels"},
294+
//mux.Handle("DELETE "+etre.API_ROOT+"/entity/{type}/{id}", api.requestWrapper(api.id(http.HandlerFunc(api.deleteEntityHandler))))
295+
{Name: "deleteEntityHandler", Method: "DELETE", URL: server.url + etre.API_ROOT + "/entity/" + entityType + "/" + testEntityIds[0]},
296+
//mux.Handle("DELETE "+etre.API_ROOT+"/entity/{type}/{id}/labels/{label}", api.requestWrapper(api.id(http.HandlerFunc(api.deleteLabelHandler))))
297+
{Name: "deleteLabelHandler", Method: "DELETE", URL: server.url + etre.API_ROOT + "/entity/" + entityType + "/" + testEntityIds[0] + "/labels/foo"},
298+
}
299+
300+
for _, tt := range tc {
301+
t.Run(tt.Name, func(t *testing.T) {
302+
gotCtx = nil
303+
r := httptest.NewRequest(tt.Method, tt.URL, nil).WithContext(context.WithValue(context.Background(), "key", tt.Name))
304+
r.Body = io.NopCloser(bytes.NewReader(tt.Payload))
305+
r.ContentLength = int64(len(tt.Payload))
306+
307+
w := &httptest.ResponseRecorder{}
308+
server.api.ServeHTTP(w, r)
309+
require.True(t, w.Code >= 200 && w.Code < 300, "expected 2xx response code, got %d", w.Code)
310+
// make sure the context pushed to the store had the right key propagated from the original request
311+
require.NotNil(t, gotCtx)
312+
assert.Equal(t, tt.Name, gotCtx.Value("key"))
313+
})
314+
}
315+
}

0 commit comments

Comments
 (0)