|
3 | 3 | package api_test
|
4 | 4 |
|
5 | 5 | import (
|
| 6 | + "bytes" |
6 | 7 | "context"
|
| 8 | + "encoding/json" |
7 | 9 | "fmt"
|
| 10 | + "io" |
8 | 11 | "net/http"
|
9 | 12 | "net/http/httptest"
|
10 | 13 | "net/url"
|
@@ -230,3 +233,83 @@ func TestClientQueryTimeout(t *testing.T) {
|
230 | 233 | assert.True(t, set, "query timeout deadline not set, expected it to be set")
|
231 | 234 | assert.True(t, -d >= 4.8 && -d <= 5.2, "deadline %f, expected between 4.8-5.2s (5s client)", d)
|
232 | 235 | }
|
| 236 | + |
| 237 | +func TestContextPropagation(t *testing.T) { |
| 238 | + // Make sure context values from the request are propagated all the way down to the entity.Store context |
| 239 | + var gotCtx context.Context |
| 240 | + store := mock.EntityStore{} |
| 241 | + store.WithContextFunc = func(ctx context.Context) entity.Store { |
| 242 | + gotCtx = ctx |
| 243 | + return store |
| 244 | + } |
| 245 | + // We're going to test all operations, so we need to set all of these funcs |
| 246 | + store.ReadEntitiesFunc = func(entityType string, q query.Query, f etre.QueryFilter) ([]etre.Entity, error) { |
| 247 | + return testEntitiesWithObjectIDs[0:1], nil |
| 248 | + } |
| 249 | + store.CreateEntitiesFunc = func(op entity.WriteOp, entities []etre.Entity) ([]string, error) { |
| 250 | + return []string{testEntityIds[0]}, nil |
| 251 | + } |
| 252 | + store.UpdateEntitiesFunc = func(op entity.WriteOp, q query.Query, e etre.Entity) ([]etre.Entity, error) { |
| 253 | + return testEntitiesWithObjectIDs[0:1], nil |
| 254 | + } |
| 255 | + store.DeleteEntitiesFunc = func(op entity.WriteOp, q query.Query) ([]etre.Entity, error) { |
| 256 | + return testEntitiesWithObjectIDs[0:1], nil |
| 257 | + } |
| 258 | + store.DeleteLabelFunc = func(op entity.WriteOp, label string) (etre.Entity, error) { |
| 259 | + return testEntitiesWithObjectIDs[0], nil |
| 260 | + } |
| 261 | + |
| 262 | + server := setup(t, defaultConfig, store) |
| 263 | + defer server.ts.Close() |
| 264 | + |
| 265 | + newEntity := etre.Entity{"host": "local"} |
| 266 | + payload, err := json.Marshal(newEntity) |
| 267 | + require.NoError(t, err) |
| 268 | + |
| 269 | + multiPayload, err := json.Marshal([]etre.Entity{newEntity}) |
| 270 | + require.NoError(t, err) |
| 271 | + |
| 272 | + tc := []struct { |
| 273 | + Name string |
| 274 | + Method string |
| 275 | + URL string |
| 276 | + Payload []byte |
| 277 | + }{ |
| 278 | + // mux.Handle("GET "+etre.API_ROOT+"/entities/{type}", api.requestWrapper(http.HandlerFunc(api.getEntitiesHandler))) |
| 279 | + {Name: "getEntitiesHandler", Method: "GET", URL: server.url + etre.API_ROOT + "/entities/" + entityType + "?query=" + url.QueryEscape("foo=bar")}, |
| 280 | + //mux.Handle("POST "+etre.API_ROOT+"/entities/{type}", api.requestWrapper(http.HandlerFunc(api.postEntitiesHandler))) |
| 281 | + {Name: "postEntitiesHandler", Method: "POST", URL: server.url + etre.API_ROOT + "/entities/" + entityType, Payload: multiPayload}, |
| 282 | + //mux.Handle("PUT "+etre.API_ROOT+"/entities/{type}", api.requestWrapper(http.HandlerFunc(api.putEntitiesHandler))) |
| 283 | + {Name: "putEntitiesHandler", Method: "PUT", URL: server.url + etre.API_ROOT + "/entities/" + entityType + "?query=" + url.QueryEscape("foo=bar"), Payload: payload}, |
| 284 | + //mux.Handle("DELETE "+etre.API_ROOT+"/entities/{type}", api.requestWrapper(http.HandlerFunc(api.deleteEntitiesHandler))) |
| 285 | + {Name: "deleteEntitiesHandler", Method: "DELETE", URL: server.url + etre.API_ROOT + "/entities/" + entityType + "?query=" + url.QueryEscape("foo=bar")}, |
| 286 | + //mux.Handle("POST "+etre.API_ROOT+"/entity/{type}", api.requestWrapper(http.HandlerFunc(api.postEntityHandler))) |
| 287 | + {Name: "postEntityHandler", Method: "POST", URL: server.url + etre.API_ROOT + "/entity/" + entityType, Payload: payload}, |
| 288 | + //mux.Handle("GET "+etre.API_ROOT+"/entity/{type}/{id}", api.requestWrapper(api.id(http.HandlerFunc(api.getEntityHandler)))) |
| 289 | + {Name: "getEntityHandler", Method: "GET", URL: server.url + etre.API_ROOT + "/entity/" + entityType + "/" + testEntityIds[0]}, |
| 290 | + //mux.Handle("PUT "+etre.API_ROOT+"/entity/{type}/{id}", api.requestWrapper(api.id(http.HandlerFunc(api.putEntityHandler)))) |
| 291 | + {Name: "putEntityHandler", Method: "PUT", URL: server.url + etre.API_ROOT + "/entity/" + entityType + "/" + testEntityIds[0], Payload: payload}, |
| 292 | + //mux.Handle("GET "+etre.API_ROOT+"/entity/{type}/{id}/labels", api.requestWrapper(api.id(http.HandlerFunc(api.getLabelsHandler)))) |
| 293 | + {Name: "getLabelsHandler", Method: "GET", URL: server.url + etre.API_ROOT + "/entity/" + entityType + "/" + testEntityIds[0] + "/labels"}, |
| 294 | + //mux.Handle("DELETE "+etre.API_ROOT+"/entity/{type}/{id}", api.requestWrapper(api.id(http.HandlerFunc(api.deleteEntityHandler)))) |
| 295 | + {Name: "deleteEntityHandler", Method: "DELETE", URL: server.url + etre.API_ROOT + "/entity/" + entityType + "/" + testEntityIds[0]}, |
| 296 | + //mux.Handle("DELETE "+etre.API_ROOT+"/entity/{type}/{id}/labels/{label}", api.requestWrapper(api.id(http.HandlerFunc(api.deleteLabelHandler)))) |
| 297 | + {Name: "deleteLabelHandler", Method: "DELETE", URL: server.url + etre.API_ROOT + "/entity/" + entityType + "/" + testEntityIds[0] + "/labels/foo"}, |
| 298 | + } |
| 299 | + |
| 300 | + for _, tt := range tc { |
| 301 | + t.Run(tt.Name, func(t *testing.T) { |
| 302 | + gotCtx = nil |
| 303 | + r := httptest.NewRequest(tt.Method, tt.URL, nil).WithContext(context.WithValue(context.Background(), "key", tt.Name)) |
| 304 | + r.Body = io.NopCloser(bytes.NewReader(tt.Payload)) |
| 305 | + r.ContentLength = int64(len(tt.Payload)) |
| 306 | + |
| 307 | + w := &httptest.ResponseRecorder{} |
| 308 | + server.api.ServeHTTP(w, r) |
| 309 | + require.True(t, w.Code >= 200 && w.Code < 300, "expected 2xx response code, got %d", w.Code) |
| 310 | + // make sure the context pushed to the store had the right key propagated from the original request |
| 311 | + require.NotNil(t, gotCtx) |
| 312 | + assert.Equal(t, tt.Name, gotCtx.Value("key")) |
| 313 | + }) |
| 314 | + } |
| 315 | +} |
0 commit comments