-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathpassthrough.go
More file actions
119 lines (105 loc) · 4.01 KB
/
passthrough.go
File metadata and controls
119 lines (105 loc) · 4.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
package aibridge
import (
"net"
"net/http"
"net/http/httputil"
"net/url"
"time"
"cdr.dev/slog/v3"
"github.com/coder/aibridge/intercept/apidump"
"github.com/coder/aibridge/metrics"
"github.com/coder/aibridge/provider"
"github.com/coder/aibridge/tracing"
"github.com/coder/quartz"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/trace"
)
// newPassthroughRouter returns a simple reverse-proxy implementation which will be used when a route is not handled specifically
// by a [intercept.Provider].
func newPassthroughRouter(provider provider.Provider, logger slog.Logger, m *metrics.Metrics, tracer trace.Tracer) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if m != nil {
m.PassthroughCount.WithLabelValues(provider.Name(), r.URL.Path, r.Method).Add(1)
}
ctx, span := tracer.Start(r.Context(), "Passthrough", trace.WithAttributes(
attribute.String(tracing.PassthroughURL, r.URL.String()),
attribute.String(tracing.PassthroughMethod, r.Method),
))
defer span.End()
upURL, err := url.Parse(provider.BaseURL())
if err != nil {
logger.Warn(ctx, "failed to parse provider base URL", slog.Error(err))
http.Error(w, "request error", http.StatusBadGateway)
span.SetStatus(codes.Error, "failed to parse provider base URL: "+err.Error())
return
}
// Append the request path to the upstream base path.
reqPath, err := url.JoinPath(upURL.Path, r.URL.Path)
if err != nil {
logger.Warn(ctx, "failed to join upstream path", slog.Error(err), slog.F("upstreamPath", upURL.Path), slog.F("requestPath", r.URL.Path))
http.Error(w, "failed to join upstream path", http.StatusInternalServerError)
span.SetStatus(codes.Error, "failed to join upstream path: "+err.Error())
return
}
// Ensure leading slash, proxied requests should have absolute paths.
// JoinPath can return relative paths, eg. when upURL path is empty.
if len(reqPath) == 0 || reqPath[0] != '/' {
reqPath = "/" + reqPath
}
// Build a reverse proxy to the upstream.
proxy := &httputil.ReverseProxy{
Director: func(req *http.Request) {
// Set scheme/host to upstream.
req.URL.Scheme = upURL.Scheme
req.URL.Host = upURL.Host
req.URL.Path = reqPath
req.URL.RawPath = ""
// Preserve query string.
req.URL.RawQuery = r.URL.RawQuery
// Set Host header for upstream.
req.Host = upURL.Host
span.SetAttributes(attribute.String(tracing.PassthroughUpstreamURL, req.URL.String()))
// Copy headers from client.
req.Header = r.Header.Clone()
// Standard proxy headers.
host, _, herr := net.SplitHostPort(r.RemoteAddr)
if herr != nil {
host = r.RemoteAddr
}
if prior := req.Header.Get("X-Forwarded-For"); prior != "" {
req.Header.Set("X-Forwarded-For", prior+", "+host)
} else {
req.Header.Set("X-Forwarded-For", host)
}
req.Header.Set("X-Forwarded-Host", r.Host)
if r.TLS != nil {
req.Header.Set("X-Forwarded-Proto", "https")
} else {
req.Header.Set("X-Forwarded-Proto", "http")
}
// Avoid default Go user-agent if none provided.
if _, ok := req.Header["User-Agent"]; !ok {
req.Header.Set("User-Agent", "aibridge") // TODO: use build tag.
}
// Inject provider auth.
provider.InjectAuthHeader(&req.Header)
},
ErrorHandler: func(rw http.ResponseWriter, req *http.Request, e error) {
logger.Warn(req.Context(), "reverse proxy error", slog.Error(e), slog.F("path", req.URL.Path))
http.Error(rw, "upstream proxy error", http.StatusBadGateway)
},
}
// Transport tuned for streaming (no response header timeout).
t := &http.Transport{
Proxy: http.ProxyFromEnvironment,
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
proxy.Transport = apidump.NewPassthroughMiddleware(t, provider.APIDumpDir(), provider.Name(), logger, quartz.NewReal())
proxy.ServeHTTP(w, r)
}
}