2222import java .security .Principal ;
2323import java .util .*;
2424import java .util .concurrent .ConcurrentHashMap ;
25- import java .util .concurrent . ConcurrentMap ;
25+ import java .util .stream . Collectors ;
2626
2727import org .apache .commons .logging .Log ;
2828import org .apache .commons .logging .LogFactory ;
2929import org .springframework .context .ApplicationEventPublisher ;
3030import org .springframework .context .ApplicationEventPublisherAware ;
31+ import org .springframework .context .SmartLifecycle ;
3132import org .springframework .messaging .Message ;
3233import org .springframework .messaging .MessageChannel ;
3334import org .springframework .web .reactive .socket .CloseStatus ;
5859import reactor .core .publisher .Mono ;
5960
6061public class WampWebSocketHandler
61- implements WebSocketHandler , ApplicationEventPublisherAware {
62+ implements WebSocketHandler , ApplicationEventPublisherAware , SmartLifecycle {
6263
6364 private static final Log logger = LogFactory .getLog (WampWebSocketHandler .class );
6465
66+ private static final String WAMP_SESSION_ID = "wamp2spring.session.id" ;
67+ private static final String WAMP_PRINCIPAL = "wamp2spring.principal" ;
68+
6569 public static final String JSON_PROTOCOL = "wamp.2.json" ;
6670
6771 public static final String MSGPACK_PROTOCOL = "wamp.2.msgpack" ;
@@ -83,14 +87,16 @@ public class WampWebSocketHandler
8387
8488 private final List <WampRole > roles ;
8589
86- private final ConcurrentMap <String , Long > webSocketId2WampSessionId = new ConcurrentHashMap <>();
87-
8890 private final MessageChannel clientInboundChannel ;
8991
9092 private final MessageChannel clientOutboundChannel ;
9193
9294 private ApplicationEventPublisher applicationEventPublisher ;
9395
96+ private volatile boolean isRunning ;
97+
98+ private final Set <WebSocketSession > webSocketSessions = ConcurrentHashMap .newKeySet ();
99+
94100 public WampWebSocketHandler (JsonFactory jsonFactory , JsonFactory msgpackFactory ,
95101 JsonFactory cborFactory , JsonFactory smileFactory ,
96102 MessageChannel clientOutboundChannel , MessageChannel clientInboundChannel ,
@@ -128,32 +134,56 @@ public List<String> getSubProtocols() {
128134
129135 @ Override
130136 public Mono <Void > handle (WebSocketSession session ) {
131- return session .getHandshakeInfo ().getPrincipal ()
132- .map (Optional ::of ).defaultIfEmpty (Optional .empty ())
133- .flatMap (optPrincipal -> {
134- Principal principal = optPrincipal .orElse (null );
135-
136- Mono <Void > receiveFlux = session .receive ()
137- .doOnNext (inMsg -> handleIncomingMessage (inMsg , session , principal ))
138- .then ();
139-
140- Mono <Void > sendFlux = session .send (Flux .from (MessageChannelReactiveUtils .toPublisher (this .clientOutboundChannel ))
141- .filter (msg -> resolveSessionId (msg ).equals (session .getId ()))
142- .map (msg -> handleOutgoingMessage (msg , session ))
143- ).doFinally (sig -> {
144- Long wampSessionId = this .webSocketId2WampSessionId .get (session .getId ());
145- if (wampSessionId != null ) {
146- this .applicationEventPublisher .publishEvent (new WampDisconnectEvent (wampSessionId , session .getId (), principal ));
147- this .webSocketId2WampSessionId .remove (session .getId ());
148- }
149- });
150-
151- return Mono .when (receiveFlux , sendFlux );
152- });
137+ if (!this .isRunning ) {
138+ return session .close (CloseStatus .GOING_AWAY );
139+ }
140+
141+ webSocketSessions .add (session );
142+
143+ return Mono .when (
144+ session .getHandshakeInfo ().getPrincipal ().doOnNext (p -> session .getAttributes ().put (WAMP_PRINCIPAL , p )),
145+ session .send (Flux .from (MessageChannelReactiveUtils .toPublisher (this .clientOutboundChannel ))
146+ .filter (msg -> resolveSessionId (msg ).equals (session .getId ()))
147+ .map (msg -> handleOutgoingMessage (msg , session ))
148+ ),
149+ session .receive ().doOnNext (inMsg -> handleIncomingMessage (inMsg , session ))
150+ ).doFinally (sig -> {
151+ webSocketSessions .remove (session );
152+
153+ Long wampSessionId = (Long ) session .getAttributes ().get (WAMP_SESSION_ID );
154+ Principal principalAttr = (Principal ) session .getAttributes ().get (WAMP_PRINCIPAL );
155+
156+ if (wampSessionId != null ) {
157+ this .applicationEventPublisher .publishEvent (new WampDisconnectEvent (wampSessionId , session .getId (), principalAttr ));
158+ }
159+ });
160+ }
161+
162+ @ Override
163+ public void start () {
164+ if (!this .isRunning ()) {
165+ this .isRunning = true ;
166+ }
153167 }
154168
155- public void handleIncomingMessage (WebSocketMessage inMsg , WebSocketSession session , Principal principal ) {
169+ @ Override
170+ public void stop () {
171+ if (this .isRunning ()) {
172+ this .isRunning = false ;
173+
174+ Flux .fromIterable (webSocketSessions )
175+ .flatMap (session -> session .close (CloseStatus .GOING_AWAY ))
176+ .doFinally (sig -> webSocketSessions .clear ())
177+ .subscribe ();
178+ }
179+ }
156180
181+ @ Override
182+ public boolean isRunning () {
183+ return this .isRunning ;
184+ }
185+
186+ private void handleIncomingMessage (WebSocketMessage inMsg , WebSocketSession session ) {
157187 try {
158188 WampMessage wampMessage = null ;
159189
@@ -200,11 +230,13 @@ else if (WampWebSocketHandler.CBOR_PROTOCOL.equals(acceptedProtocol)) {
200230 return ;
201231 }
202232
233+ Principal principal = (Principal ) session .getAttributes ().get (WAMP_PRINCIPAL );
234+
203235 wampMessage .setHeader (WampMessageHeader .WEBSOCKET_SESSION_ID ,
204236 session .getId ());
205237 wampMessage .setHeader (WampMessageHeader .PRINCIPAL , principal );
206238 wampMessage .setHeader (WampMessageHeader .WAMP_SESSION_ID ,
207- this . webSocketId2WampSessionId . get ( session .getId () ));
239+ session .getAttributes (). get ( WAMP_SESSION_ID ));
208240
209241 if (wampMessage instanceof HelloMessage ) {
210242 // If this is a helloMessage sent during a running session close the
@@ -215,8 +247,13 @@ else if (WampWebSocketHandler.CBOR_PROTOCOL.equals(acceptedProtocol)) {
215247 }
216248
217249 long newWampSessionId = IdGenerator .newRandomId (
218- new HashSet <>(this .webSocketId2WampSessionId .values ()));
219- this .webSocketId2WampSessionId .put (session .getId (), newWampSessionId );
250+ webSocketSessions .stream ()
251+ .map (webSocketSession -> (Long ) webSocketSession .getAttributes ().get (WAMP_SESSION_ID ))
252+ .filter (Objects ::nonNull )
253+ .collect (Collectors .toSet ())
254+ );
255+
256+ session .getAttributes ().put (WAMP_SESSION_ID , newWampSessionId );
220257
221258 WelcomeMessage welcomeMessage = new WelcomeMessage (
222259 (HelloMessage ) wampMessage , newWampSessionId , this .roles );
0 commit comments