Skip to content

Commit 10ab726

Browse files
committed
add websocket connection cancellation via done channel
1 parent 0fa3f99 commit 10ab726

File tree

2 files changed

+69
-40
lines changed

2 files changed

+69
-40
lines changed

websocketproxy.go

Lines changed: 66 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -25,27 +25,39 @@ var (
2525
DefaultDialer = websocket.DefaultDialer
2626
)
2727

28-
// WebsocketProxy is an HTTP Handler that takes an incoming WebSocket
29-
// connection and proxies it to another server.
30-
type WebsocketProxy struct {
31-
// Director, if non-nil, is a function that may copy additional request
32-
// headers from the incoming WebSocket connection into the output headers
33-
// which will be forwarded to another server.
34-
Director func(incoming *http.Request, out http.Header)
35-
36-
// Backend returns the backend URL which the proxy uses to reverse proxy
37-
// the incoming WebSocket connection. Request is the initial incoming and
38-
// unmodified request.
39-
Backend func(*http.Request) *url.URL
40-
41-
// Upgrader specifies the parameters for upgrading a incoming HTTP
42-
// connection to a WebSocket connection. If nil, DefaultUpgrader is used.
43-
Upgrader *websocket.Upgrader
44-
45-
// Dialer contains options for connecting to the backend WebSocket server.
46-
// If nil, DefaultDialer is used.
47-
Dialer *websocket.Dialer
48-
}
28+
type (
29+
// WebsocketProxy is an HTTP Handler that takes an incoming WebSocket
30+
// connection and proxies it to another server.
31+
WebsocketProxy struct {
32+
// Director, if non-nil, is a function that may copy additional request
33+
// headers from the incoming WebSocket connection into the output headers
34+
// which will be forwarded to another server.
35+
Director func(incoming *http.Request, out http.Header)
36+
37+
// Backend returns the backend URL which the proxy uses to reverse proxy
38+
// the incoming WebSocket connection. Request is the initial incoming and
39+
// unmodified request.
40+
Backend func(*http.Request) *url.URL
41+
42+
// Upgrader specifies the parameters for upgrading a incoming HTTP
43+
// connection to a WebSocket connection. If nil, DefaultUpgrader is used.
44+
Upgrader *websocket.Upgrader
45+
46+
// Dialer contains options for connecting to the backend WebSocket server.
47+
// If nil, DefaultDialer is used.
48+
Dialer *websocket.Dialer
49+
50+
// Done specifies a channel for which all proxied websocket connections
51+
// can be closed on demand by closing the channel.
52+
Done chan struct{}
53+
}
54+
55+
websocketMsg struct {
56+
msgType int
57+
msg []byte
58+
err error
59+
}
60+
)
4961

5062
// ProxyHandler returns a new http.Handler interface that reverse proxies the
5163
// request to the given target.
@@ -174,41 +186,55 @@ func (w *WebsocketProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
174186

175187
errClient := make(chan error, 1)
176188
errBackend := make(chan error, 1)
189+
177190
replicateWebsocketConn := func(dst, src *websocket.Conn, errc chan error) {
178-
for {
191+
websocketMsgRcverC := make(chan websocketMsg, 1)
192+
websocketMsgRcver := func() <-chan websocketMsg {
179193
msgType, msg, err := src.ReadMessage()
180-
if err != nil {
181-
m := websocket.FormatCloseMessage(websocket.CloseNormalClosure, fmt.Sprintf("%v", err))
182-
if e, ok := err.(*websocket.CloseError); ok {
183-
if e.Code != websocket.CloseNoStatusReceived {
184-
m = websocket.FormatCloseMessage(e.Code, e.Text)
194+
websocketMsgRcverC <- websocketMsg{msgType, msg, err}
195+
return websocketMsgRcverC
196+
}
197+
198+
for {
199+
select {
200+
case websocketMsgRcv := <-websocketMsgRcver():
201+
if websocketMsgRcv.err != nil {
202+
m := websocket.FormatCloseMessage(websocket.CloseNormalClosure, fmt.Sprintf("%v", websocketMsgRcv.err))
203+
if e, ok := websocketMsgRcv.err.(*websocket.CloseError); ok {
204+
if e.Code != websocket.CloseNoStatusReceived {
205+
m = websocket.FormatCloseMessage(e.Code, e.Text)
206+
}
185207
}
208+
errc <- websocketMsgRcv.err
209+
dst.WriteMessage(websocket.CloseMessage, m)
210+
break
186211
}
187-
errc <- err
212+
err = dst.WriteMessage(websocketMsgRcv.msgType, websocketMsgRcv.msg)
213+
if err != nil {
214+
errc <- err
215+
break
216+
}
217+
case <-w.Done:
218+
m := websocket.FormatCloseMessage(websocket.CloseGoingAway, "websocketproxy: closing connection")
188219
dst.WriteMessage(websocket.CloseMessage, m)
189220
break
190221
}
191-
err = dst.WriteMessage(msgType, msg)
192-
if err != nil {
193-
errc <- err
194-
break
195-
}
196222
}
197223
}
198224

199225
go replicateWebsocketConn(connPub, connBackend, errClient)
200226
go replicateWebsocketConn(connBackend, connPub, errBackend)
201227

202-
var message string
203228
select {
204229
case err = <-errClient:
205-
message = "websocketproxy: Error when copying from backend to client: %v"
230+
if e, ok := err.(*websocket.CloseError); !ok || e.Code == websocket.CloseAbnormalClosure {
231+
log.Printf("websocketproxy: Error when copying from backend to client: %v", err)
232+
}
206233
case err = <-errBackend:
207-
message = "websocketproxy: Error when copying from client to backend: %v"
208-
209-
}
210-
if e, ok := err.(*websocket.CloseError); !ok || e.Code == websocket.CloseAbnormalClosure {
211-
log.Printf(message, err)
234+
if e, ok := err.(*websocket.CloseError); !ok || e.Code == websocket.CloseAbnormalClosure {
235+
log.Printf("websocketproxy: Error when copying from client to backend: %v", err)
236+
}
237+
case <-w.Done:
212238
}
213239
}
214240

websocketproxy_test.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ func TestProxy(t *testing.T) {
3030
u, _ := url.Parse(backendURL)
3131
proxy := NewProxy(u)
3232
proxy.Upgrader = upgrader
33+
proxy.Done = make(chan struct{})
3334

3435
mux := http.NewServeMux()
3536
mux.Handle("/proxy", proxy)
@@ -121,4 +122,6 @@ func TestProxy(t *testing.T) {
121122
if msg != string(p) {
122123
t.Errorf("expecting: %s, got: %s", msg, string(p))
123124
}
125+
126+
close(proxy.Done)
124127
}

0 commit comments

Comments
 (0)