Skip to content

Commit ee34a7d

Browse files
Merge pull request #370 from 3rabiii/feat/request-id-tracking
feat: implement Request ID middleware for log correlation
2 parents 09560b3 + 1d06b5a commit ee34a7d

File tree

6 files changed

+126
-5
lines changed

6 files changed

+126
-5
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@ config.docker.json
1414
tmp/
1515
*.log
1616
.DS_Store
17+
/api

cmd/api/app.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ import (
1313
"time"
1414

1515
"github.com/prometheus/client_golang/prometheus/promhttp"
16-
1716
"maglev.onebusaway.org/internal/app"
1817
"maglev.onebusaway.org/internal/appconf"
1918
"maglev.onebusaway.org/internal/clock"
@@ -118,7 +117,8 @@ func CreateServer(coreApp *app.Application, cfg appconf.Config) (*http.Server, *
118117
// Add request logging middleware (outermost)
119118
requestLogger := logging.NewStructuredLogger(os.Stdout, slog.LevelInfo)
120119
requestLogMiddleware := restapi.NewRequestLoggingMiddleware(requestLogger)
121-
handler := requestLogMiddleware(metricsHandler)
120+
121+
handler := restapi.RequestIDMiddleware(requestLogMiddleware(metricsHandler))
122122

123123
srv := &http.Server{
124124
Addr: fmt.Sprintf(":%d", cfg.Port),

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ go 1.24.2
55
require (
66
github.com/OneBusAway/go-gtfs v1.1.0
77
github.com/davecgh/go-spew v1.1.1
8+
github.com/google/uuid v1.6.0
89
github.com/klauspost/compress v1.18.0
910
github.com/mattn/go-sqlite3 v1.14.24
1011
github.com/prometheus/client_golang v1.23.2
@@ -25,7 +26,6 @@ require (
2526
github.com/fatih/structtag v1.2.0 // indirect
2627
github.com/go-sql-driver/mysql v1.9.3 // indirect
2728
github.com/google/cel-go v0.26.1 // indirect
28-
github.com/google/uuid v1.6.0 // indirect
2929
github.com/inconshreveable/mousetrap v1.1.0 // indirect
3030
github.com/jackc/pgpassfile v1.0.0 // indirect
3131
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package restapi
2+
3+
import (
4+
"context"
5+
"net/http"
6+
"regexp"
7+
8+
"github.com/google/uuid"
9+
)
10+
11+
type contextKey string
12+
13+
const RequestIDKey contextKey = "request_id"
14+
15+
var validRequestIDRegex = regexp.MustCompile(`^[a-zA-Z0-9-._:]+$`)
16+
17+
func RequestIDMiddleware(next http.Handler) http.Handler {
18+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
19+
reqID := r.Header.Get("X-Request-ID")
20+
21+
if reqID == "" || len(reqID) > 128 || !validRequestIDRegex.MatchString(reqID) {
22+
reqID = uuid.NewString()
23+
}
24+
25+
w.Header().Set("X-Request-ID", reqID)
26+
27+
ctx := context.WithValue(r.Context(), RequestIDKey, reqID)
28+
29+
next.ServeHTTP(w, r.WithContext(ctx))
30+
})
31+
}
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
package restapi
2+
3+
import (
4+
"net/http"
5+
"net/http/httptest"
6+
"strings"
7+
"testing"
8+
9+
"github.com/stretchr/testify/assert"
10+
)
11+
12+
func TestRequestIDMiddleware(t *testing.T) {
13+
t.Run("should generate request ID if missing", func(t *testing.T) {
14+
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
15+
reqID, ok := r.Context().Value(RequestIDKey).(string)
16+
assert.True(t, ok, "Context should contain request ID")
17+
assert.NotEmpty(t, reqID, "Request ID should not be empty")
18+
})
19+
20+
handlerToTest := RequestIDMiddleware(nextHandler)
21+
22+
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
23+
rec := httptest.NewRecorder()
24+
25+
handlerToTest.ServeHTTP(rec, req)
26+
27+
respID := rec.Header().Get("X-Request-ID")
28+
assert.NotEmpty(t, respID, "Response header should contain X-Request-ID")
29+
assert.Regexp(t, `^[0-9a-f-]{36}$`, respID)
30+
})
31+
32+
t.Run("should preserve existing valid request ID", func(t *testing.T) {
33+
existingID := "my-custom-trace-id-123"
34+
35+
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
36+
reqID, ok := r.Context().Value(RequestIDKey).(string)
37+
assert.True(t, ok)
38+
assert.Equal(t, existingID, reqID)
39+
})
40+
41+
handlerToTest := RequestIDMiddleware(nextHandler)
42+
43+
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
44+
req.Header.Set("X-Request-ID", existingID)
45+
rec := httptest.NewRecorder()
46+
47+
handlerToTest.ServeHTTP(rec, req)
48+
49+
assert.Equal(t, existingID, rec.Header().Get("X-Request-ID"))
50+
})
51+
52+
t.Run("should replace invalid request ID", func(t *testing.T) {
53+
testCases := []struct {
54+
name string
55+
invalidID string
56+
}{
57+
{
58+
name: "ID too long (>128 chars)",
59+
invalidID: strings.Repeat("a", 129),
60+
},
61+
{
62+
name: "ID contains invalid characters",
63+
invalidID: "bad-id-<script>",
64+
},
65+
}
66+
67+
for _, tc := range testCases {
68+
t.Run(tc.name, func(t *testing.T) {
69+
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
70+
reqID, ok := r.Context().Value(RequestIDKey).(string)
71+
assert.True(t, ok)
72+
assert.NotEqual(t, tc.invalidID, reqID)
73+
assert.Regexp(t, `^[0-9a-f-]{36}$`, reqID)
74+
})
75+
76+
handlerToTest := RequestIDMiddleware(nextHandler)
77+
78+
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
79+
req.Header.Set("X-Request-ID", tc.invalidID)
80+
rec := httptest.NewRecorder()
81+
82+
handlerToTest.ServeHTTP(rec, req)
83+
})
84+
}
85+
})
86+
}

internal/restapi/request_logging_middleware.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,14 @@ func NewRequestLoggingMiddleware(logger *slog.Logger) func(http.Handler) http.Ha
4141
// Log the request
4242
duration := time.Since(start)
4343

44+
reqID, _ := r.Context().Value(RequestIDKey).(string)
45+
4446
logging.LogHTTPRequest(logger,
4547
r.Method,
46-
r.URL.Path, // Path without query parameters
48+
r.URL.Path,
4749
wrapped.statusCode,
48-
float64(duration.Nanoseconds())/1e6, // Convert to milliseconds
50+
float64(duration.Nanoseconds())/1e6,
51+
slog.String("request_id", reqID),
4952
slog.String("user_agent", r.Header.Get("User-Agent")),
5053
slog.String("component", "http_server"))
5154
})

0 commit comments

Comments
 (0)