Skip to content

Commit cbbcccb

Browse files
authored
Merge pull request #17 from vzhn/feature/16_graceful_shutdown
#16: close sessions when the application is shutting down
2 parents 894a79e + 35838b3 commit cbbcccb

File tree

1 file changed

+67
-30
lines changed

1 file changed

+67
-30
lines changed

wamp2spring-reactive/src/main/java/ch/rasc/wamp2spring/reactive/WampWebSocketHandler.java

Lines changed: 67 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,13 @@
2222
import java.security.Principal;
2323
import java.util.*;
2424
import java.util.concurrent.ConcurrentHashMap;
25-
import java.util.concurrent.ConcurrentMap;
25+
import java.util.stream.Collectors;
2626

2727
import org.apache.commons.logging.Log;
2828
import org.apache.commons.logging.LogFactory;
2929
import org.springframework.context.ApplicationEventPublisher;
3030
import org.springframework.context.ApplicationEventPublisherAware;
31+
import org.springframework.context.SmartLifecycle;
3132
import org.springframework.messaging.Message;
3233
import org.springframework.messaging.MessageChannel;
3334
import org.springframework.web.reactive.socket.CloseStatus;
@@ -58,10 +59,13 @@
5859
import reactor.core.publisher.Mono;
5960

6061
public class WampWebSocketHandler
61-
implements WebSocketHandler, ApplicationEventPublisherAware {
62+
implements WebSocketHandler, ApplicationEventPublisherAware, SmartLifecycle {
6263

6364
private static final Log logger = LogFactory.getLog(WampWebSocketHandler.class);
6465

66+
private static final String WAMP_SESSION_ID = "wamp2spring.session.id";
67+
private static final String WAMP_PRINCIPAL = "wamp2spring.principal";
68+
6569
public static final String JSON_PROTOCOL = "wamp.2.json";
6670

6771
public static final String MSGPACK_PROTOCOL = "wamp.2.msgpack";
@@ -83,14 +87,16 @@ public class WampWebSocketHandler
8387

8488
private final List<WampRole> roles;
8589

86-
private final ConcurrentMap<String, Long> webSocketId2WampSessionId = new ConcurrentHashMap<>();
87-
8890
private final MessageChannel clientInboundChannel;
8991

9092
private final MessageChannel clientOutboundChannel;
9193

9294
private ApplicationEventPublisher applicationEventPublisher;
9395

96+
private volatile boolean isRunning;
97+
98+
private final Set<WebSocketSession> webSocketSessions = ConcurrentHashMap.newKeySet();
99+
94100
public WampWebSocketHandler(JsonFactory jsonFactory, JsonFactory msgpackFactory,
95101
JsonFactory cborFactory, JsonFactory smileFactory,
96102
MessageChannel clientOutboundChannel, MessageChannel clientInboundChannel,
@@ -128,32 +134,56 @@ public List<String> getSubProtocols() {
128134

129135
@Override
130136
public Mono<Void> handle(WebSocketSession session) {
131-
return session.getHandshakeInfo().getPrincipal()
132-
.map(Optional::of).defaultIfEmpty(Optional.empty())
133-
.flatMap(optPrincipal -> {
134-
Principal principal = optPrincipal.orElse(null);
135-
136-
Mono<Void> receiveFlux = session.receive()
137-
.doOnNext(inMsg -> handleIncomingMessage(inMsg, session, principal))
138-
.then();
139-
140-
Mono<Void> sendFlux = session.send(Flux.from(MessageChannelReactiveUtils.toPublisher(this.clientOutboundChannel))
141-
.filter(msg -> resolveSessionId(msg).equals(session.getId()))
142-
.map(msg -> handleOutgoingMessage(msg, session))
143-
).doFinally(sig -> {
144-
Long wampSessionId = this.webSocketId2WampSessionId.get(session.getId());
145-
if (wampSessionId != null) {
146-
this.applicationEventPublisher.publishEvent(new WampDisconnectEvent(wampSessionId, session.getId(), principal));
147-
this.webSocketId2WampSessionId.remove(session.getId());
148-
}
149-
});
150-
151-
return Mono.when(receiveFlux, sendFlux);
152-
});
137+
if (!this.isRunning) {
138+
return session.close(CloseStatus.GOING_AWAY);
139+
}
140+
141+
webSocketSessions.add(session);
142+
143+
return Mono.when(
144+
session.getHandshakeInfo().getPrincipal().doOnNext(p -> session.getAttributes().put(WAMP_PRINCIPAL, p)),
145+
session.send(Flux.from(MessageChannelReactiveUtils.toPublisher(this.clientOutboundChannel))
146+
.filter(msg -> resolveSessionId(msg).equals(session.getId()))
147+
.map(msg -> handleOutgoingMessage(msg, session))
148+
),
149+
session.receive().doOnNext(inMsg -> handleIncomingMessage(inMsg, session))
150+
).doFinally(sig -> {
151+
webSocketSessions.remove(session);
152+
153+
Long wampSessionId = (Long) session.getAttributes().get(WAMP_SESSION_ID);
154+
Principal principalAttr = (Principal) session.getAttributes().get(WAMP_PRINCIPAL);
155+
156+
if (wampSessionId != null) {
157+
this.applicationEventPublisher.publishEvent(new WampDisconnectEvent(wampSessionId, session.getId(), principalAttr));
158+
}
159+
});
160+
}
161+
162+
@Override
163+
public void start() {
164+
if (!this.isRunning()) {
165+
this.isRunning = true;
166+
}
153167
}
154168

155-
public void handleIncomingMessage(WebSocketMessage inMsg, WebSocketSession session, Principal principal) {
169+
@Override
170+
public void stop() {
171+
if (this.isRunning()) {
172+
this.isRunning = false;
173+
174+
Flux.fromIterable(webSocketSessions)
175+
.flatMap(session -> session.close(CloseStatus.GOING_AWAY))
176+
.doFinally(sig -> webSocketSessions.clear())
177+
.subscribe();
178+
}
179+
}
156180

181+
@Override
182+
public boolean isRunning() {
183+
return this.isRunning;
184+
}
185+
186+
private void handleIncomingMessage(WebSocketMessage inMsg, WebSocketSession session) {
157187
try {
158188
WampMessage wampMessage = null;
159189

@@ -200,11 +230,13 @@ else if (WampWebSocketHandler.CBOR_PROTOCOL.equals(acceptedProtocol)) {
200230
return;
201231
}
202232

233+
Principal principal = (Principal) session.getAttributes().get(WAMP_PRINCIPAL);
234+
203235
wampMessage.setHeader(WampMessageHeader.WEBSOCKET_SESSION_ID,
204236
session.getId());
205237
wampMessage.setHeader(WampMessageHeader.PRINCIPAL, principal);
206238
wampMessage.setHeader(WampMessageHeader.WAMP_SESSION_ID,
207-
this.webSocketId2WampSessionId.get(session.getId()));
239+
session.getAttributes().get(WAMP_SESSION_ID));
208240

209241
if (wampMessage instanceof HelloMessage) {
210242
// If this is a helloMessage sent during a running session close the
@@ -215,8 +247,13 @@ else if (WampWebSocketHandler.CBOR_PROTOCOL.equals(acceptedProtocol)) {
215247
}
216248

217249
long newWampSessionId = IdGenerator.newRandomId(
218-
new HashSet<>(this.webSocketId2WampSessionId.values()));
219-
this.webSocketId2WampSessionId.put(session.getId(), newWampSessionId);
250+
webSocketSessions.stream()
251+
.map(webSocketSession -> (Long) webSocketSession.getAttributes().get(WAMP_SESSION_ID))
252+
.filter(Objects::nonNull)
253+
.collect(Collectors.toSet())
254+
);
255+
256+
session.getAttributes().put(WAMP_SESSION_ID, newWampSessionId);
220257

221258
WelcomeMessage welcomeMessage = new WelcomeMessage(
222259
(HelloMessage) wampMessage, newWampSessionId, this.roles);

0 commit comments

Comments
 (0)