001/*
002 * Copyright 2002-2020 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.web.socket.messaging;
018
019import java.io.IOException;
020import java.nio.ByteBuffer;
021import java.security.Principal;
022import java.util.Arrays;
023import java.util.List;
024import java.util.Map;
025import java.util.Set;
026import java.util.concurrent.ConcurrentHashMap;
027import java.util.concurrent.atomic.AtomicInteger;
028
029import org.apache.commons.logging.Log;
030import org.apache.commons.logging.LogFactory;
031
032import org.springframework.context.ApplicationEvent;
033import org.springframework.context.ApplicationEventPublisher;
034import org.springframework.context.ApplicationEventPublisherAware;
035import org.springframework.lang.Nullable;
036import org.springframework.messaging.Message;
037import org.springframework.messaging.MessageChannel;
038import org.springframework.messaging.simp.SimpAttributes;
039import org.springframework.messaging.simp.SimpAttributesContextHolder;
040import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
041import org.springframework.messaging.simp.SimpMessageType;
042import org.springframework.messaging.simp.stomp.BufferingStompDecoder;
043import org.springframework.messaging.simp.stomp.StompCommand;
044import org.springframework.messaging.simp.stomp.StompDecoder;
045import org.springframework.messaging.simp.stomp.StompEncoder;
046import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
047import org.springframework.messaging.support.AbstractMessageChannel;
048import org.springframework.messaging.support.ChannelInterceptor;
049import org.springframework.messaging.support.ImmutableMessageChannelInterceptor;
050import org.springframework.messaging.support.MessageBuilder;
051import org.springframework.messaging.support.MessageHeaderAccessor;
052import org.springframework.messaging.support.MessageHeaderInitializer;
053import org.springframework.util.Assert;
054import org.springframework.util.MimeTypeUtils;
055import org.springframework.web.socket.BinaryMessage;
056import org.springframework.web.socket.CloseStatus;
057import org.springframework.web.socket.TextMessage;
058import org.springframework.web.socket.WebSocketMessage;
059import org.springframework.web.socket.WebSocketSession;
060import org.springframework.web.socket.handler.SessionLimitExceededException;
061import org.springframework.web.socket.handler.WebSocketSessionDecorator;
062import org.springframework.web.socket.sockjs.transport.SockJsSession;
063
064/**
065 * A {@link SubProtocolHandler} for STOMP that supports versions 1.0, 1.1, and 1.2
066 * of the STOMP specification.
067 *
068 * @author Rossen Stoyanchev
069 * @author Andy Wilkinson
070 * @since 4.0
071 */
072public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationEventPublisherAware {
073
        /**
         * This handler supports assembling large STOMP messages split into multiple
         * WebSocket messages and STOMP clients (like stomp.js) indeed split large STOMP
         * messages at 16K boundaries. Therefore the WebSocket server input message
         * buffer size must allow 16K at least plus a little extra for SockJS framing.
         */
        public static final int MINIMUM_WEBSOCKET_MESSAGE_SIZE = 16 * 1024 + 256;

        /**
         * The name of the header set on the CONNECTED frame indicating the name
         * of the user authenticated on the WebSocket session.
         */
        public static final String CONNECTED_USER_HEADER = "user-name";

        // STOMP versions this handler can negotiate, listed highest first; used to
        // pick the version reported in CONNECTED frames built from a CONNECT_ACK.
        private static final String[] SUPPORTED_VERSIONS = {"1.2", "1.1", "1.0"};

        private static final Log logger = LogFactory.getLog(StompSubProtocolHandler.class);

        // Shared zero-length payload for frames without a body (default ERROR frames
        // and the DISCONNECT message created when a session ends).
        private static final byte[] EMPTY_PAYLOAD = new byte[0];


        // Optional customizer for error messages sent to clients; when null a default
        // ERROR frame is sent and the session is closed.
        @Nullable
        private StompSubProtocolErrorHandler errorHandler;

        // Maximum size in bytes of an inbound STOMP message, which may span several
        // WebSocket messages; used as the buffer limit of each session's decoder.
        private int messageSizeLimit = 64 * 1024;

        private StompEncoder stompEncoder = new StompEncoder();

        // Template decoder wrapped by each session's BufferingStompDecoder.
        private StompDecoder stompDecoder = new StompDecoder();

        // Per-session buffering decoders, keyed by WebSocket session id; added in
        // afterSessionStarted and removed in afterSessionEnded.
        private final Map<String, BufferingStompDecoder> decoders = new ConcurrentHashMap<>();

        @Nullable
        private MessageHeaderInitializer headerInitializer;

        // Users recorded through the CONNECT user-change callback when they differ from
        // the WebSocket session principal, keyed by session id; getUser() prefers these.
        private final Map<String, Principal> stompAuthentications = new ConcurrentHashMap<>();

        // Cached result of detectImmutableMessageInterceptor(); null until first computed.
        @Nullable
        private Boolean immutableMessageInterceptorPresent;

        @Nullable
        private ApplicationEventPublisher eventPublisher;

        // Counters exposed through getStats() and getStatsInfo().
        private final DefaultStats stats = new DefaultStats();
118
119
120        /**
121         * Configure a handler for error messages sent to clients which allows
122         * customizing the error messages or preventing them from being sent.
123         * <p>By default this isn't configured in which case an ERROR frame is sent
124         * with a message header reflecting the error.
125         * @param errorHandler the error handler
126         */
127        public void setErrorHandler(StompSubProtocolErrorHandler errorHandler) {
128                this.errorHandler = errorHandler;
129        }
130
131        /**
132         * Return the configured error handler.
133         */
134        @Nullable
135        public StompSubProtocolErrorHandler getErrorHandler() {
136                return this.errorHandler;
137        }
138
139        /**
140         * Configure the maximum size allowed for an incoming STOMP message.
141         * Since a STOMP message can be received in multiple WebSocket messages,
142         * buffering may be required and therefore it is necessary to know the maximum
143         * allowed message size.
144         * <p>By default this property is set to 64K.
145         * @since 4.0.3
146         */
147        public void setMessageSizeLimit(int messageSizeLimit) {
148                this.messageSizeLimit = messageSizeLimit;
149        }
150
151        /**
152         * Get the configured message buffer size limit in bytes.
153         * @since 4.0.3
154         */
155        public int getMessageSizeLimit() {
156                return this.messageSizeLimit;
157        }
158
159        /**
160         * Configure a {@link StompEncoder} for encoding STOMP frames.
161         * @since 4.3.5
162         */
163        public void setEncoder(StompEncoder encoder) {
164                this.stompEncoder = encoder;
165        }
166
167        /**
168         * Configure a {@link StompDecoder} for decoding STOMP frames.
169         * @since 4.3.5
170         */
171        public void setDecoder(StompDecoder decoder) {
172                this.stompDecoder = decoder;
173        }
174
175        /**
176         * Configure a {@link MessageHeaderInitializer} to apply to the headers of all
177         * messages created from decoded STOMP frames and other messages sent to the
178         * client inbound channel.
179         * <p>By default this property is not set.
180         */
181        public void setHeaderInitializer(@Nullable MessageHeaderInitializer headerInitializer) {
182                this.headerInitializer = headerInitializer;
183                this.stompDecoder.setHeaderInitializer(headerInitializer);
184        }
185
186        /**
187         * Return the configured header initializer.
188         */
189        @Nullable
190        public MessageHeaderInitializer getHeaderInitializer() {
191                return this.headerInitializer;
192        }
193
194        @Override
195        public List<String> getSupportedProtocols() {
196                return Arrays.asList("v10.stomp", "v11.stomp", "v12.stomp");
197        }
198
199        @Override
200        public void setApplicationEventPublisher(ApplicationEventPublisher applicationEventPublisher) {
201                this.eventPublisher = applicationEventPublisher;
202        }
203
204        /**
205         * Return a String describing internal state and counters.
206         * Effectively {@code toString()} on {@link #getStats() getStats()}.
207         */
208        public String getStatsInfo() {
209                return this.stats.toString();
210        }
211
212        /**
213         * Return a structured object with internal state and counters.
214         * @since 5.2
215         */
216        public Stats getStats() {
217                return this.stats;
218        }
219
220
        /**
         * Handle incoming WebSocket messages from clients.
         * <p>Text and binary WebSocket messages are fed to the session's
         * {@link BufferingStompDecoder}; any other WebSocket message type is ignored.
         * Every fully decoded STOMP frame gets the session id, session attributes,
         * user and heart-beat headers and is then sent to {@code outputChannel}. If the
         * send succeeds and an event publisher is set, a {@link SessionConnectEvent},
         * {@link SessionSubscribeEvent} or {@link SessionUnsubscribeEvent} is published
         * for CONNECT/STOMP, SUBSCRIBE and UNSUBSCRIBE frames respectively.
         * <p>Decoding failures, and failures while processing an individual frame, are
         * logged and reported to the client via {@code handleError}; they are not
         * propagated to the caller.
         */
        @Override
        public void handleMessageFromClient(WebSocketSession session,
                        WebSocketMessage<?> webSocketMessage, MessageChannel outputChannel) {

                List<Message<byte[]>> messages;
                try {
                        ByteBuffer byteBuffer;
                        if (webSocketMessage instanceof TextMessage) {
                                byteBuffer = ByteBuffer.wrap(((TextMessage) webSocketMessage).asBytes());
                        }
                        else if (webSocketMessage instanceof BinaryMessage) {
                                byteBuffer = ((BinaryMessage) webSocketMessage).getPayload();
                        }
                        else {
                                // Any other WebSocket message type carries no STOMP content for us
                                return;
                        }

                        // Decoders are registered in afterSessionStarted and removed in afterSessionEnded
                        BufferingStompDecoder decoder = this.decoders.get(session.getId());
                        if (decoder == null) {
                                if (!session.isOpen()) {
                                        logger.trace("Dropped inbound WebSocket message due to closed session");
                                        return;
                                }
                                // Caught below and reported to the client like a parse failure
                                throw new IllegalStateException("No decoder for session id '" + session.getId() + "'");
                        }

                        messages = decoder.decode(byteBuffer);
                        if (messages.isEmpty()) {
                                // No complete frame yet: nothing to dispatch until more content arrives
                                if (logger.isTraceEnabled()) {
                                        logger.trace("Incomplete STOMP frame content received in session " +
                                                        session + ", bufferSize=" + decoder.getBufferSize() +
                                                        ", bufferSizeLimit=" + decoder.getBufferSizeLimit() + ".");
                                }
                                return;
                        }
                }
                catch (Throwable ex) {
                        if (logger.isErrorEnabled()) {
                                logger.error("Failed to parse " + webSocketMessage +
                                                " in session " + session.getId() + ". Sending STOMP ERROR to client.", ex);
                        }
                        handleError(session, ex, null);
                        return;
                }

                // Each frame is processed independently: a failure is reported for that frame
                // and processing continues with the next one.
                for (Message<byte[]> message : messages) {
                        try {
                                StompHeaderAccessor headerAccessor =
                                                MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
                                Assert.state(headerAccessor != null, "No StompHeaderAccessor");

                                StompCommand command = headerAccessor.getCommand();
                                boolean isConnect = StompCommand.CONNECT.equals(command) || StompCommand.STOMP.equals(command);

                                headerAccessor.setSessionId(session.getId());
                                headerAccessor.setSessionAttributes(session.getAttributes());
                                headerAccessor.setUser(getUser(session));
                                if (isConnect) {
                                        // If the user is changed while the CONNECT is handled (presumably by
                                        // a channel interceptor) and differs from the session principal,
                                        // remember it so getUser() returns it for the rest of the session.
                                        headerAccessor.setUserChangeCallback(user -> {
                                                if (user != null && user != session.getPrincipal()) {
                                                        this.stompAuthentications.put(session.getId(), user);
                                                }
                                        });
                                }
                                headerAccessor.setHeader(SimpMessageHeaderAccessor.HEART_BEAT_HEADER, headerAccessor.getHeartbeat());
                                // Seal the headers now (after all mutations above) unless the channel
                                // has an ImmutableMessageChannelInterceptor, presumably left to do it.
                                if (!detectImmutableMessageInterceptor(outputChannel)) {
                                        headerAccessor.setImmutable();
                                }

                                if (logger.isTraceEnabled()) {
                                        logger.trace("From client: " + headerAccessor.getShortLogMessage(message.getPayload()));
                                }

                                if (isConnect) {
                                        this.stats.incrementConnectCount();
                                }
                                else if (StompCommand.DISCONNECT.equals(command)) {
                                        this.stats.incrementDisconnectCount();
                                }

                                try {
                                        // Expose session attributes to code running on this thread during the send
                                        SimpAttributesContextHolder.setAttributesFromMessage(message);
                                        boolean sent = outputChannel.send(message);

                                        if (sent) {
                                                if (this.eventPublisher != null) {
                                                        Principal user = getUser(session);
                                                        if (isConnect) {
                                                                publishEvent(this.eventPublisher, new SessionConnectEvent(this, message, user));
                                                        }
                                                        else if (StompCommand.SUBSCRIBE.equals(command)) {
                                                                publishEvent(this.eventPublisher, new SessionSubscribeEvent(this, message, user));
                                                        }
                                                        else if (StompCommand.UNSUBSCRIBE.equals(command)) {
                                                                publishEvent(this.eventPublisher, new SessionUnsubscribeEvent(this, message, user));
                                                        }
                                                }
                                        }
                                }
                                finally {
                                        SimpAttributesContextHolder.resetAttributes();
                                }
                        }
                        catch (Throwable ex) {
                                // With debug enabled the failure is logged at debug level with its stack
                                // trace; otherwise only its message is logged at error level.
                                if (logger.isErrorEnabled()) {
                                        String errorText = "Failed to send message to MessageChannel in session " + session.getId();
                                        if (logger.isDebugEnabled()) {
                                                logger.debug(errorText, ex);
                                        }
                                        else {
                                                logger.error(errorText + ":" + ex.getMessage());
                                        }
                                }
                                handleError(session, ex, message);
                        }
                }
        }
341
342        @Nullable
343        private Principal getUser(WebSocketSession session) {
344                Principal user = this.stompAuthentications.get(session.getId());
345                return (user != null ? user : session.getPrincipal());
346        }
347
348        private void handleError(WebSocketSession session, Throwable ex, @Nullable Message<byte[]> clientMessage) {
349                if (getErrorHandler() == null) {
350                        sendErrorMessage(session, ex);
351                        return;
352                }
353                Message<byte[]> message = getErrorHandler().handleClientMessageProcessingError(clientMessage, ex);
354                if (message == null) {
355                        return;
356                }
357                StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
358                Assert.state(accessor != null, "No StompHeaderAccessor");
359                sendToClient(session, accessor, message.getPayload());
360        }
361
362        /**
363         * Invoked when no
364         * {@link #setErrorHandler(StompSubProtocolErrorHandler) errorHandler}
365         * is configured to send an ERROR frame to the client.
366         */
367        private void sendErrorMessage(WebSocketSession session, Throwable error) {
368                StompHeaderAccessor headerAccessor = StompHeaderAccessor.create(StompCommand.ERROR);
369                headerAccessor.setMessage(error.getMessage());
370
371                byte[] bytes = this.stompEncoder.encode(headerAccessor.getMessageHeaders(), EMPTY_PAYLOAD);
372                try {
373                        session.sendMessage(new TextMessage(bytes));
374                }
375                catch (Throwable ex) {
376                        // Could be part of normal workflow (e.g. browser tab closed)
377                        logger.debug("Failed to send STOMP ERROR to client", ex);
378                }
379                finally {
380                        try {
381                                session.close(CloseStatus.PROTOCOL_ERROR);
382                        }
383                        catch (IOException ex) {
384                                // Ignore
385                        }
386                }
387        }
388
389        private boolean detectImmutableMessageInterceptor(MessageChannel channel) {
390                if (this.immutableMessageInterceptorPresent != null) {
391                        return this.immutableMessageInterceptorPresent;
392                }
393
394                if (channel instanceof AbstractMessageChannel) {
395                        for (ChannelInterceptor interceptor : ((AbstractMessageChannel) channel).getInterceptors()) {
396                                if (interceptor instanceof ImmutableMessageChannelInterceptor) {
397                                        this.immutableMessageInterceptorPresent = true;
398                                        return true;
399                                }
400                        }
401                }
402                this.immutableMessageInterceptorPresent = false;
403                return false;
404        }
405
406        private void publishEvent(ApplicationEventPublisher publisher, ApplicationEvent event) {
407                try {
408                        publisher.publishEvent(event);
409                }
410                catch (Throwable ex) {
411                        if (logger.isErrorEnabled()) {
412                                logger.error("Error publishing " + event, ex);
413                        }
414                }
415        }
416
        /**
         * Handle STOMP messages going back out to WebSocket clients.
         * <p>Only {@code byte[]} payloads are handled; anything else is logged and dropped.
         * MESSAGE frames get their destination replaced by the value of the
         * {@code ORIGINAL_DESTINATION} native header when present. CONNECTED frames are
         * counted, adjusted by {@code afterStompSessionConnected}, and trigger a
         * {@link SessionConnectedEvent} when an event publisher is set. ERROR frames are
         * passed to the configured error handler, if any, which may return a replacement
         * message; a {@code null} result keeps the original frame.
         */
        @Override
        @SuppressWarnings("unchecked")
        public void handleMessageToClient(WebSocketSession session, Message<?> message) {
                if (!(message.getPayload() instanceof byte[])) {
                        if (logger.isErrorEnabled()) {
                                logger.error("Expected byte[] payload. Ignoring " + message + ".");
                        }
                        return;
                }

                // Reuse the message's STOMP accessor, or derive one from its simp headers
                StompHeaderAccessor accessor = getStompHeaderAccessor(message);
                StompCommand command = accessor.getCommand();

                if (StompCommand.MESSAGE.equals(command)) {
                        if (accessor.getSubscriptionId() == null && logger.isWarnEnabled()) {
                                logger.warn("No STOMP \"subscription\" header in " + message);
                        }
                        // Swap in the recorded original destination (presumably the one the client
                        // subscribed to) and drop the helper header before it reaches the client.
                        String origDestination = accessor.getFirstNativeHeader(SimpMessageHeaderAccessor.ORIGINAL_DESTINATION);
                        if (origDestination != null) {
                                accessor = toMutableAccessor(accessor, message);
                                accessor.removeNativeHeader(SimpMessageHeaderAccessor.ORIGINAL_DESTINATION);
                                accessor.setDestination(origDestination);
                        }
                }
                else if (StompCommand.CONNECTED.equals(command)) {
                        this.stats.incrementConnectedCount();
                        accessor = afterStompSessionConnected(message, accessor, session);
                        if (this.eventPublisher != null) {
                                // Expose session attributes to listeners while the event is published
                                try {
                                        SimpAttributes simpAttributes = new SimpAttributes(session.getId(), session.getAttributes());
                                        SimpAttributesContextHolder.setAttributes(simpAttributes);
                                        Principal user = getUser(session);
                                        publishEvent(this.eventPublisher, new SessionConnectedEvent(this, (Message<byte[]>) message, user));
                                }
                                finally {
                                        SimpAttributesContextHolder.resetAttributes();
                                }
                        }
                }

                byte[] payload = (byte[]) message.getPayload();
                if (StompCommand.ERROR.equals(command) && getErrorHandler() != null) {
                        Message<byte[]> errorMessage = getErrorHandler().handleErrorMessageToClient((Message<byte[]>) message);
                        if (errorMessage != null) {
                                accessor = MessageHeaderAccessor.getAccessor(errorMessage, StompHeaderAccessor.class);
                                Assert.state(accessor != null, "No StompHeaderAccessor");
                                payload = errorMessage.getPayload();
                        }
                }
                sendToClient(session, accessor, payload);
        }
471
472        private void sendToClient(WebSocketSession session, StompHeaderAccessor stompAccessor, byte[] payload) {
473                StompCommand command = stompAccessor.getCommand();
474                try {
475                        byte[] bytes = this.stompEncoder.encode(stompAccessor.getMessageHeaders(), payload);
476                        boolean useBinary = (payload.length > 0 && !(session instanceof SockJsSession) &&
477                                        MimeTypeUtils.APPLICATION_OCTET_STREAM.isCompatibleWith(stompAccessor.getContentType()));
478                        if (useBinary) {
479                                session.sendMessage(new BinaryMessage(bytes));
480                        }
481                        else {
482                                session.sendMessage(new TextMessage(bytes));
483                        }
484                }
485                catch (SessionLimitExceededException ex) {
486                        // Bad session, just get out
487                        throw ex;
488                }
489                catch (Throwable ex) {
490                        // Could be part of normal workflow (e.g. browser tab closed)
491                        if (logger.isDebugEnabled()) {
492                                logger.debug("Failed to send WebSocket message to client in session " + session.getId(), ex);
493                        }
494                        command = StompCommand.ERROR;
495                }
496                finally {
497                        if (StompCommand.ERROR.equals(command)) {
498                                try {
499                                        session.close(CloseStatus.PROTOCOL_ERROR);
500                                }
501                                catch (IOException ex) {
502                                        // Ignore
503                                }
504                        }
505                }
506        }
507
508        private StompHeaderAccessor getStompHeaderAccessor(Message<?> message) {
509                MessageHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, MessageHeaderAccessor.class);
510                if (accessor instanceof StompHeaderAccessor) {
511                        return (StompHeaderAccessor) accessor;
512                }
513                else {
514                        StompHeaderAccessor stompAccessor = StompHeaderAccessor.wrap(message);
515                        SimpMessageType messageType = SimpMessageHeaderAccessor.getMessageType(message.getHeaders());
516                        if (SimpMessageType.CONNECT_ACK.equals(messageType)) {
517                                stompAccessor = convertConnectAcktoStompConnected(stompAccessor);
518                        }
519                        else if (SimpMessageType.DISCONNECT_ACK.equals(messageType)) {
520                                String receipt = getDisconnectReceipt(stompAccessor);
521                                if (receipt != null) {
522                                        stompAccessor = StompHeaderAccessor.create(StompCommand.RECEIPT);
523                                        stompAccessor.setReceiptId(receipt);
524                                }
525                                else {
526                                        stompAccessor = StompHeaderAccessor.create(StompCommand.ERROR);
527                                        stompAccessor.setMessage("Session closed.");
528                                }
529                        }
530                        else if (SimpMessageType.HEARTBEAT.equals(messageType)) {
531                                stompAccessor = StompHeaderAccessor.createForHeartbeat();
532                        }
533                        else if (stompAccessor.getCommand() == null || StompCommand.SEND.equals(stompAccessor.getCommand())) {
534                                stompAccessor.updateStompCommandAsServerMessage();
535                        }
536                        return stompAccessor;
537                }
538        }
539
540        /**
541         * The simple broker produces {@code SimpMessageType.CONNECT_ACK} that's not STOMP
542         * specific and needs to be turned into a STOMP CONNECTED frame.
543         */
544        private StompHeaderAccessor convertConnectAcktoStompConnected(StompHeaderAccessor connectAckHeaders) {
545                String name = StompHeaderAccessor.CONNECT_MESSAGE_HEADER;
546                Message<?> message = (Message<?>) connectAckHeaders.getHeader(name);
547                if (message == null) {
548                        throw new IllegalStateException("Original STOMP CONNECT not found in " + connectAckHeaders);
549                }
550
551                StompHeaderAccessor connectHeaders = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
552                StompHeaderAccessor connectedHeaders = StompHeaderAccessor.create(StompCommand.CONNECTED);
553
554                if (connectHeaders != null) {
555                        Set<String> acceptVersions = connectHeaders.getAcceptVersion();
556                        connectedHeaders.setVersion(
557                                        Arrays.stream(SUPPORTED_VERSIONS)
558                                                        .filter(acceptVersions::contains)
559                                                        .findAny()
560                                                        .orElseThrow(() -> new IllegalArgumentException(
561                                                                        "Unsupported STOMP version '" + acceptVersions + "'")));
562                }
563
564                long[] heartbeat = (long[]) connectAckHeaders.getHeader(SimpMessageHeaderAccessor.HEART_BEAT_HEADER);
565                if (heartbeat != null) {
566                        connectedHeaders.setHeartbeat(heartbeat[0], heartbeat[1]);
567                }
568                else {
569                        connectedHeaders.setHeartbeat(0, 0);
570                }
571
572                return connectedHeaders;
573        }
574
575        @Nullable
576        private String getDisconnectReceipt(SimpMessageHeaderAccessor simpHeaders) {
577                String name = StompHeaderAccessor.DISCONNECT_MESSAGE_HEADER;
578                Message<?> message = (Message<?>) simpHeaders.getHeader(name);
579                if (message != null) {
580                        StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
581                        if (accessor != null) {
582                                return accessor.getReceipt();
583                        }
584                }
585                return null;
586        }
587
588        protected StompHeaderAccessor toMutableAccessor(StompHeaderAccessor headerAccessor, Message<?> message) {
589                return (headerAccessor.isMutable() ? headerAccessor : StompHeaderAccessor.wrap(message));
590        }
591
592        private StompHeaderAccessor afterStompSessionConnected(Message<?> message, StompHeaderAccessor accessor,
593                        WebSocketSession session) {
594
595                Principal principal = getUser(session);
596                if (principal != null) {
597                        accessor = toMutableAccessor(accessor, message);
598                        accessor.setNativeHeader(CONNECTED_USER_HEADER, principal.getName());
599                }
600
601                long[] heartbeat = accessor.getHeartbeat();
602                if (heartbeat[1] > 0) {
603                        session = WebSocketSessionDecorator.unwrap(session);
604                        if (session instanceof SockJsSession) {
605                                ((SockJsSession) session).disableHeartbeat();
606                        }
607                }
608
609                return accessor;
610        }
611
612        @Override
613        @Nullable
614        public String resolveSessionId(Message<?> message) {
615                return SimpMessageHeaderAccessor.getSessionId(message.getHeaders());
616        }
617
618        @Override
619        public void afterSessionStarted(WebSocketSession session, MessageChannel outputChannel) {
620                if (session.getTextMessageSizeLimit() < MINIMUM_WEBSOCKET_MESSAGE_SIZE) {
621                        session.setTextMessageSizeLimit(MINIMUM_WEBSOCKET_MESSAGE_SIZE);
622                }
623                this.decoders.put(session.getId(), new BufferingStompDecoder(this.stompDecoder, getMessageSizeLimit()));
624        }
625
626        @Override
627        public void afterSessionEnded(WebSocketSession session, CloseStatus closeStatus, MessageChannel outputChannel) {
628                this.decoders.remove(session.getId());
629
630                Message<byte[]> message = createDisconnectMessage(session);
631                SimpAttributes simpAttributes = SimpAttributes.fromMessage(message);
632                try {
633                        SimpAttributesContextHolder.setAttributes(simpAttributes);
634                        if (this.eventPublisher != null) {
635                                Principal user = getUser(session);
636                                publishEvent(this.eventPublisher, new SessionDisconnectEvent(this, message, session.getId(), closeStatus, user));
637                        }
638                        outputChannel.send(message);
639                }
640                finally {
641                        this.stompAuthentications.remove(session.getId());
642                        SimpAttributesContextHolder.resetAttributes();
643                        simpAttributes.sessionCompleted();
644                }
645        }
646
647        private Message<byte[]> createDisconnectMessage(WebSocketSession session) {
648                StompHeaderAccessor headerAccessor = StompHeaderAccessor.create(StompCommand.DISCONNECT);
649                if (getHeaderInitializer() != null) {
650                        getHeaderInitializer().initHeaders(headerAccessor);
651                }
652
653                headerAccessor.setSessionId(session.getId());
654                headerAccessor.setSessionAttributes(session.getAttributes());
655
656                Principal user = getUser(session);
657                if (user != null) {
658                        headerAccessor.setUser(user);
659                }
660
661                return MessageBuilder.createMessage(EMPTY_PAYLOAD, headerAccessor.getMessageHeaders());
662        }
663
664
665        @Override
666        public String toString() {
667                return "StompSubProtocolHandler" + getSupportedProtocols();
668        }
669
670
671        /**
672         * Contract for access to session counters.
673         * @since 5.2
674         */
675        public interface Stats {
676
677                /**
678                 * The number of CONNECT frames processed.
679                 */
680                int getTotalConnect();
681
682                /**
683                 * The number of CONNECTED frames processed.
684                 */
685                int getTotalConnected();
686
687                /**
688                 * The number of DISCONNECT frames processed.
689                 */
690                int getTotalDisconnect();
691        }
692
693
694        private static class DefaultStats implements Stats {
695
696                private final AtomicInteger connect = new AtomicInteger();
697
698                private final AtomicInteger connected = new AtomicInteger();
699
700                private final AtomicInteger disconnect = new AtomicInteger();
701
702                public void incrementConnectCount() {
703                        this.connect.incrementAndGet();
704                }
705
706                public void incrementConnectedCount() {
707                        this.connected.incrementAndGet();
708                }
709
710                public void incrementDisconnectCount() {
711                        this.disconnect.incrementAndGet();
712                }
713
714                @Override
715                public int getTotalConnect() {
716                        return this.connect.get();
717                }
718
719                @Override
720                public int getTotalConnected() {
721                        return this.connected.get();
722                }
723
724                @Override
725                public int getTotalDisconnect() {
726                        return this.disconnect.get();
727                }
728
729                @Override
730                public String toString() {
731                        return "processed CONNECT(" + this.connect.get() + ")-CONNECTED(" +
732                                        this.connected.get() + ")-DISCONNECT(" + this.disconnect.get() + ")";
733                }
734        }
735
736}