diff --git a/README.md b/README.md index 526bb43..83cf081 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# WebsocketProxy [![GoDoc](https://godoc.org/github.com/koding/websocketproxy?status.svg)](https://godoc.org/github.com/koding/websocketproxy) [![Build Status](https://travis-ci.org/koding/websocketproxy.svg)](https://travis-ci.org/koding/websocketproxy) +# WebsocketProxy [![GoDoc](https://godoc.org/github.com/sudiptadeb/websocketproxy?status.svg)](https://godoc.org/github.com/sudiptadeb/websocketproxy) [![Build Status](https://travis-ci.org/sudiptadeb/websocketproxy.svg)](https://travis-ci.org/sudiptadeb/websocketproxy) WebsocketProxy is an http.Handler interface build on top of [gorilla/websocket](https://github.com/gorilla/websocket) that you can plug @@ -7,7 +7,7 @@ into your existing Go webserver to provide WebSocket reverse proxy. ## Install ```bash -go get github.com/koding/websocketproxy +go get github.com/sudiptadeb/websocketproxy ``` ## Example @@ -22,7 +22,7 @@ import ( "net/http" "net/url" - "github.com/koding/websocketproxy" + "github.com/sudiptadeb/websocketproxy" ) var ( diff --git a/websocketproxy.go b/websocketproxy.go index 63d39ba..28cf7c3 100644 --- a/websocketproxy.go +++ b/websocketproxy.go @@ -2,6 +2,7 @@ package websocketproxy import ( + "errors" "fmt" "io" "log" @@ -23,6 +24,8 @@ var ( // DefaultDialer is a dialer with all fields set to the default zero values. DefaultDialer = websocket.DefaultDialer + + errNilChannelClose = errors.New("trying to close nil channel") ) // WebsocketProxy is an HTTP Handler that takes an incoming WebSocket @@ -45,6 +48,10 @@ type WebsocketProxy struct { // Dialer contains options for connecting to the backend WebSocket server. // If nil, DefaultDialer is used. Dialer *websocket.Dialer + + // Stop channels to close the websocket on demand + stopClientChan chan struct{} + stopBackendChan chan struct{} } // ProxyHandler returns a new http.Handler interface that reverse proxies the @@ -65,6 +72,22 @@ func NewProxy(target *url.URL) *WebsocketProxy { return &WebsocketProxy{Backend: backend} } +// Stop websocket proxy on demand +func (w *WebsocketProxy) CloseProxy() (err error) { + err = nil + if w.stopBackendChan != nil { + close(w.stopBackendChan) + } else { + err = errNilChannelClose + } + if w.stopClientChan != nil { + close(w.stopClientChan) + } else { + err = errNilChannelClose + } + return err +} + // ServeHTTP implements the http.Handler that proxies WebSocket connections. func (w *WebsocketProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { if w.Backend == nil { @@ -157,14 +180,14 @@ func (w *WebsocketProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { upgrader = DefaultUpgrader } - // Only pass those headers to the upgrader. + // passing all headers except those which can become duplicate upgradeHeader := http.Header{} - if hdr := resp.Header.Get("Sec-Websocket-Protocol"); hdr != "" { - upgradeHeader.Set("Sec-Websocket-Protocol", hdr) - } - if hdr := resp.Header.Get("Set-Cookie"); hdr != "" { - upgradeHeader.Set("Set-Cookie", hdr) - } + copyHeader(upgradeHeader, resp.Header) + + // These are extra header which the upgrader actually sets itself so need to remove these to avoid duplicate headers + upgradeHeader.Del("Connection") + upgradeHeader.Del("Upgrade") + upgradeHeader.Del("Sec-WebSocket-Accept") // Now upgrade the existing incoming request to a WebSocket connection. // Also pass the header that we gathered from the Dial handshake. @@ -177,30 +200,42 @@ func (w *WebsocketProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { errClient := make(chan error, 1) errBackend := make(chan error, 1) - replicateWebsocketConn := func(dst, src *websocket.Conn, errc chan error) { + replicateWebsocketConn := func(dst, src *websocket.Conn, errc chan error, stopChan chan struct{}) { for { - msgType, msg, err := src.ReadMessage() - if err != nil { - m := websocket.FormatCloseMessage(websocket.CloseNormalClosure, fmt.Sprintf("%v", err)) - if e, ok := err.(*websocket.CloseError); ok { - if e.Code != websocket.CloseNoStatusReceived { - m = websocket.FormatCloseMessage(e.Code, e.Text) + // do until stopChan gets any message + select { + default: + msgType, msg, err := src.ReadMessage() + if err != nil { + m := websocket.FormatCloseMessage(websocket.CloseNormalClosure, fmt.Sprintf("%v", err)) + if e, ok := err.(*websocket.CloseError); ok { + if e.Code != websocket.CloseNoStatusReceived { + m = websocket.FormatCloseMessage(e.Code, e.Text) + } } + errc <- err + dst.WriteMessage(websocket.CloseMessage, m) + break } - errc <- err - dst.WriteMessage(websocket.CloseMessage, m) - break - } - err = dst.WriteMessage(msgType, msg) - if err != nil { - errc <- err - break + err = dst.WriteMessage(msgType, msg) + if err != nil { + errc <- err + break + } + case <-stopChan: + dst.WriteMessage(websocket.CloseMessage, []byte("Closed by proxy")) + return } + } } - go replicateWebsocketConn(connPub, connBackend, errClient) - go replicateWebsocketConn(connBackend, connPub, errBackend) + // initiate the stop channels + w.stopClientChan = make(chan struct{}) + w.stopBackendChan = make(chan struct{}) + + go replicateWebsocketConn(connPub, connBackend, errClient, w.stopClientChan) + go replicateWebsocketConn(connBackend, connPub, errBackend, w.stopBackendChan) var message string select {