diff --git a/hapi-plugin-websocket.js b/hapi-plugin-websocket.js index dc4a399..a181664 100644 --- a/hapi-plugin-websocket.js +++ b/hapi-plugin-websocket.js @@ -215,9 +215,10 @@ const register = async (server, pluginOptions) => { delete headers["accept-encoding"] /* optionally inject an empty initial message */ + let initially; if (routeOptions.initially) { /* inject incoming WebSocket message as a simulated HTTP request */ - const response = await server.inject({ + initially = server.inject({ /* simulate the hard-coded POST request */ method: "POST", @@ -233,25 +234,32 @@ const register = async (server, pluginOptions) => { plugins: { websocket: { mode: "websocket", ctx, wss, ws, wsf, req, peers, initially: true } } - }) - - /* any HTTP redirection, client error or server error response - leads to an immediate WebSocket connection drop */ - if (response.statusCode >= 300) { - const annotation = `(HAPI handler responded with HTTP status ${response.statusCode})` - if (response.statusCode < 400) - ws.close(1002, `Protocol Error ${annotation}`) - else if (response.statusCode < 500) - ws.close(1008, `Policy Violation ${annotation}`) - else - ws.close(1011, `Server Error ${annotation}`) - } + }).then(response => { + /* any HTTP redirection, client error or server error response + leads to an immediate WebSocket connection drop */ + if (response.statusCode < 300) { + return true; + } else { + const annotation = `(HAPI handler responded with HTTP status ${response.statusCode})` + if (response.statusCode < 400) + ws.close(1002, `Protocol Error ${annotation}`) + else if (response.statusCode < 500) + ws.close(1008, `Policy Violation ${annotation}`) + else + ws.close(1011, `Server Error ${annotation}`) + return false; + } + }); } /* hook into WebSocket message retrieval */ if (routeOptions.frame === true) { /* framed WebSocket communication (correlated request/reply) */ wsf.on("message", async (ev) => { + if (initially && !(await initially)) { + return; + } + /* allow application to hook into raw WebSocket frame processing */ routeOptions.frameMessage.call(ctx, { ctx, wss, ws, wsf, req, peers }, ev.frame) @@ -294,6 +302,10 @@ const register = async (server, pluginOptions) => { else { /* plain WebSocket communication (uncorrelated request/response) */ ws.on("message", async (message) => { + if (initially && !(await initially)) { + return; + } + /* inject incoming WebSocket message as a simulated HTTP request */ const response = await server.inject({ /* simulate the hard-coded POST request */ @@ -322,6 +334,7 @@ const register = async (server, pluginOptions) => { /* hook into WebSocket disconnection */ ws.on("close", () => { /* allow application to hook into WebSocket disconnection */ + /* note that this is done even if the `initially` handler closes the connection */ routeOptions.disconnect.call(ctx, { ctx, wss, ws, wsf, req, peers }) /* stop tracking the peer */ @@ -330,6 +343,7 @@ const register = async (server, pluginOptions) => { }) /* allow application to hook into WebSocket error processing */ + /* note that this is done even if the `initially` handler closes the connection */ ws.on("error", (error) => { routeOptions.error.call(ctx, { ctx, wss, ws, wsf, req, peers }, error) })