001/*
002 * Copyright 2002-2019 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.web.socket.messaging;
018
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.ReentrantLock;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import org.springframework.context.SmartLifecycle;
import org.springframework.lang.Nullable;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.MessagingException;
import org.springframework.messaging.SubscribableChannel;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.SubProtocolCapable;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator;
import org.springframework.web.socket.handler.SessionLimitExceededException;
import org.springframework.web.socket.handler.WebSocketSessionDecorator;
import org.springframework.web.socket.sockjs.transport.session.PollingSockJsSession;
import org.springframework.web.socket.sockjs.transport.session.StreamingSockJsSession;
052
053/**
054 * An implementation of {@link WebSocketHandler} that delegates incoming WebSocket
055 * messages to a {@link SubProtocolHandler} along with a {@link MessageChannel} to which
056 * the sub-protocol handler can send messages from WebSocket clients to the application.
057 *
058 * <p>Also an implementation of {@link MessageHandler} that finds the WebSocket session
059 * associated with the {@link Message} and passes it, along with the message, to the
060 * sub-protocol handler to send messages from the application back to the client.
061 *
062 * @author Rossen Stoyanchev
063 * @author Juergen Hoeller
064 * @author Andy Wilkinson
065 * @author Artem Bilan
066 * @since 4.0
067 */
068public class SubProtocolWebSocketHandler
069                implements WebSocketHandler, SubProtocolCapable, MessageHandler, SmartLifecycle {
070
071        /** The default value for {@link #setTimeToFirstMessage(int) timeToFirstMessage}. */
072        private static final int DEFAULT_TIME_TO_FIRST_MESSAGE = 60 * 1000;
073
074
075        private final Log logger = LogFactory.getLog(SubProtocolWebSocketHandler.class);
076
077
078        private final MessageChannel clientInboundChannel;
079
080        private final SubscribableChannel clientOutboundChannel;
081
082        private final Map<String, SubProtocolHandler> protocolHandlerLookup =
083                        new TreeMap<>(String.CASE_INSENSITIVE_ORDER);
084
085        private final Set<SubProtocolHandler> protocolHandlers = new LinkedHashSet<>();
086
087        @Nullable
088        private SubProtocolHandler defaultProtocolHandler;
089
090        private final Map<String, WebSocketSessionHolder> sessions = new ConcurrentHashMap<>();
091
092        private int sendTimeLimit = 10 * 1000;
093
094        private int sendBufferSizeLimit = 512 * 1024;
095
096        private int timeToFirstMessage = DEFAULT_TIME_TO_FIRST_MESSAGE;
097
098        private volatile long lastSessionCheckTime = System.currentTimeMillis();
099
100        private final ReentrantLock sessionCheckLock = new ReentrantLock();
101
102        private final DefaultStats stats = new DefaultStats();
103
104        private volatile boolean running = false;
105
106        private final Object lifecycleMonitor = new Object();
107
108
109        /**
110         * Create a new {@code SubProtocolWebSocketHandler} for the given inbound and outbound channels.
111         * @param clientInboundChannel the inbound {@code MessageChannel}
112         * @param clientOutboundChannel the outbound {@code MessageChannel}
113         */
114        public SubProtocolWebSocketHandler(MessageChannel clientInboundChannel, SubscribableChannel clientOutboundChannel) {
115                Assert.notNull(clientInboundChannel, "Inbound MessageChannel must not be null");
116                Assert.notNull(clientOutboundChannel, "Outbound MessageChannel must not be null");
117                this.clientInboundChannel = clientInboundChannel;
118                this.clientOutboundChannel = clientOutboundChannel;
119        }
120
121
122        /**
123         * Configure one or more handlers to use depending on the sub-protocol requested by
124         * the client in the WebSocket handshake request.
125         * @param protocolHandlers the sub-protocol handlers to use
126         */
127        public void setProtocolHandlers(List<SubProtocolHandler> protocolHandlers) {
128                this.protocolHandlerLookup.clear();
129                this.protocolHandlers.clear();
130                for (SubProtocolHandler handler : protocolHandlers) {
131                        addProtocolHandler(handler);
132                }
133        }
134
135        public List<SubProtocolHandler> getProtocolHandlers() {
136                return new ArrayList<>(this.protocolHandlers);
137        }
138
139        /**
140         * Register a sub-protocol handler.
141         */
142        public void addProtocolHandler(SubProtocolHandler handler) {
143                List<String> protocols = handler.getSupportedProtocols();
144                if (CollectionUtils.isEmpty(protocols)) {
145                        if (logger.isErrorEnabled()) {
146                                logger.error("No sub-protocols for " + handler);
147                        }
148                        return;
149                }
150                for (String protocol : protocols) {
151                        SubProtocolHandler replaced = this.protocolHandlerLookup.put(protocol, handler);
152                        if (replaced != null && replaced != handler) {
153                                throw new IllegalStateException("Cannot map " + handler +
154                                                " to protocol '" + protocol + "': already mapped to " + replaced + ".");
155                        }
156                }
157                this.protocolHandlers.add(handler);
158        }
159
160        /**
161         * Return the sub-protocols keyed by protocol name.
162         */
163        public Map<String, SubProtocolHandler> getProtocolHandlerMap() {
164                return this.protocolHandlerLookup;
165        }
166
167        /**
168         * Set the {@link SubProtocolHandler} to use when the client did not request a
169         * sub-protocol.
170         * @param defaultProtocolHandler the default handler
171         */
172        public void setDefaultProtocolHandler(@Nullable SubProtocolHandler defaultProtocolHandler) {
173                this.defaultProtocolHandler = defaultProtocolHandler;
174                if (this.protocolHandlerLookup.isEmpty()) {
175                        setProtocolHandlers(Collections.singletonList(defaultProtocolHandler));
176                }
177        }
178
179        /**
180         * Return the default sub-protocol handler to use.
181         */
182        @Nullable
183        public SubProtocolHandler getDefaultProtocolHandler() {
184                return this.defaultProtocolHandler;
185        }
186
187        /**
188         * Return all supported protocols.
189         */
190        @Override
191        public List<String> getSubProtocols() {
192                return new ArrayList<>(this.protocolHandlerLookup.keySet());
193        }
194
195        /**
196         * Specify the send-time limit (milliseconds).
197         * @see ConcurrentWebSocketSessionDecorator
198         */
199        public void setSendTimeLimit(int sendTimeLimit) {
200                this.sendTimeLimit = sendTimeLimit;
201        }
202
203        /**
204         * Return the send-time limit (milliseconds).
205         */
206        public int getSendTimeLimit() {
207                return this.sendTimeLimit;
208        }
209
210        /**
211         * Specify the buffer-size limit (number of bytes).
212         * @see ConcurrentWebSocketSessionDecorator
213         */
214        public void setSendBufferSizeLimit(int sendBufferSizeLimit) {
215                this.sendBufferSizeLimit = sendBufferSizeLimit;
216        }
217
218        /**
219         * Return the buffer-size limit (number of bytes).
220         */
221        public int getSendBufferSizeLimit() {
222                return this.sendBufferSizeLimit;
223        }
224
225        /**
226         * Set the maximum time allowed in milliseconds after the WebSocket connection
227         * is established and before the first sub-protocol message is received.
228         * <p>This handler is for WebSocket connections that use a sub-protocol.
229         * Therefore, we expect the client to send at least one sub-protocol message
230         * in the beginning, or else we assume the connection isn't doing well, e.g.
231         * proxy issue, slow network, and can be closed.
232         * <p>By default this is set to {@code 60,000} (1 minute).
233         * @param timeToFirstMessage the maximum time allowed in milliseconds
234         * @since 5.1
235         * @see #checkSessions()
236         */
237        public void setTimeToFirstMessage(int timeToFirstMessage) {
238                this.timeToFirstMessage = timeToFirstMessage;
239        }
240
241        /**
242         * Return the maximum time allowed after the WebSocket connection is
243         * established and before the first sub-protocol message.
244         * @since 5.1
245         */
246        public int getTimeToFirstMessage() {
247                return this.timeToFirstMessage;
248        }
249
250        /**
251         * Return a String describing internal state and counters.
252         * Effectively {@code toString()} on {@link #getStats() getStats()}.
253         */
254        public String getStatsInfo() {
255                return this.stats.toString();
256        }
257
258        /**
259         * Return a structured object with various session counters.
260         * @since 5.2
261         */
262        public Stats getStats() {
263                return this.stats;
264        }
265
266
267
268        @Override
269        public final void start() {
270                Assert.isTrue(this.defaultProtocolHandler != null || !this.protocolHandlers.isEmpty(), "No handlers");
271
272                synchronized (this.lifecycleMonitor) {
273                        this.clientOutboundChannel.subscribe(this);
274                        this.running = true;
275                }
276        }
277
278        @Override
279        public final void stop() {
280                synchronized (this.lifecycleMonitor) {
281                        this.running = false;
282                        this.clientOutboundChannel.unsubscribe(this);
283                }
284
285                // Proactively notify all active WebSocket sessions
286                for (WebSocketSessionHolder holder : this.sessions.values()) {
287                        try {
288                                holder.getSession().close(CloseStatus.GOING_AWAY);
289                        }
290                        catch (Throwable ex) {
291                                if (logger.isWarnEnabled()) {
292                                        logger.warn("Failed to close '" + holder.getSession() + "': " + ex);
293                                }
294                        }
295                }
296        }
297
298        @Override
299        public final void stop(Runnable callback) {
300                synchronized (this.lifecycleMonitor) {
301                        stop();
302                        callback.run();
303                }
304        }
305
306        @Override
307        public final boolean isRunning() {
308                return this.running;
309        }
310
311
312        @Override
313        public void afterConnectionEstablished(WebSocketSession session) throws Exception {
314                // WebSocketHandlerDecorator could close the session
315                if (!session.isOpen()) {
316                        return;
317                }
318
319                this.stats.incrementSessionCount(session);
320                session = decorateSession(session);
321                this.sessions.put(session.getId(), new WebSocketSessionHolder(session));
322                findProtocolHandler(session).afterSessionStarted(session, this.clientInboundChannel);
323        }
324
325        /**
326         * Handle an inbound message from a WebSocket client.
327         */
328        @Override
329        public void handleMessage(WebSocketSession session, WebSocketMessage<?> message) throws Exception {
330                WebSocketSessionHolder holder = this.sessions.get(session.getId());
331                if (holder != null) {
332                        session = holder.getSession();
333                }
334                SubProtocolHandler protocolHandler = findProtocolHandler(session);
335                protocolHandler.handleMessageFromClient(session, message, this.clientInboundChannel);
336                if (holder != null) {
337                        holder.setHasHandledMessages();
338                }
339                checkSessions();
340        }
341
342        /**
343         * Handle an outbound Spring Message to a WebSocket client.
344         */
345        @Override
346        public void handleMessage(Message<?> message) throws MessagingException {
347                String sessionId = resolveSessionId(message);
348                if (sessionId == null) {
349                        if (logger.isErrorEnabled()) {
350                                logger.error("Could not find session id in " + message);
351                        }
352                        return;
353                }
354
355                WebSocketSessionHolder holder = this.sessions.get(sessionId);
356                if (holder == null) {
357                        if (logger.isDebugEnabled()) {
358                                // The broker may not have removed the session yet
359                                logger.debug("No session for " + message);
360                        }
361                        return;
362                }
363
364                WebSocketSession session = holder.getSession();
365                try {
366                        findProtocolHandler(session).handleMessageToClient(session, message);
367                }
368                catch (SessionLimitExceededException ex) {
369                        try {
370                                if (logger.isDebugEnabled()) {
371                                        logger.debug("Terminating '" + session + "'", ex);
372                                }
373                                else if (logger.isWarnEnabled()) {
374                                        logger.warn("Terminating '" + session + "': " + ex.getMessage());
375                                }
376                                this.stats.incrementLimitExceededCount();
377                                clearSession(session, ex.getStatus()); // clear first, session may be unresponsive
378                                session.close(ex.getStatus());
379                        }
380                        catch (Exception secondException) {
381                                logger.debug("Failure while closing session " + sessionId + ".", secondException);
382                        }
383                }
384                catch (Exception ex) {
385                        // Could be part of normal workflow (e.g. browser tab closed)
386                        if (logger.isDebugEnabled()) {
387                                logger.debug("Failed to send message to client in " + session + ": " + message, ex);
388                        }
389                }
390        }
391
392        @Override
393        public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
394                this.stats.incrementTransportError();
395        }
396
397        @Override
398        public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception {
399                clearSession(session, closeStatus);
400        }
401
402        @Override
403        public boolean supportsPartialMessages() {
404                return false;
405        }
406
407
408        /**
409         * Decorate the given {@link WebSocketSession}, if desired.
410         * <p>The default implementation builds a {@link ConcurrentWebSocketSessionDecorator}
411         * with the configured {@link #getSendTimeLimit() send-time limit} and
412         * {@link #getSendBufferSizeLimit() buffer-size limit}.
413         * @param session the original {@code WebSocketSession}
414         * @return the decorated {@code WebSocketSession}, or potentially the given session as-is
415         * @since 4.3.13
416         */
417        protected WebSocketSession decorateSession(WebSocketSession session) {
418                return new ConcurrentWebSocketSessionDecorator(session, getSendTimeLimit(), getSendBufferSizeLimit());
419        }
420
421        /**
422         * Find a {@link SubProtocolHandler} for the given session.
423         * @param session the {@code WebSocketSession} to find a handler for
424         */
425        protected final SubProtocolHandler findProtocolHandler(WebSocketSession session) {
426                String protocol = null;
427                try {
428                        protocol = session.getAcceptedProtocol();
429                }
430                catch (Exception ex) {
431                        // Shouldn't happen
432                        logger.error("Failed to obtain session.getAcceptedProtocol(): " +
433                                        "will use the default protocol handler (if configured).", ex);
434                }
435
436                SubProtocolHandler handler;
437                if (StringUtils.hasLength(protocol)) {
438                        handler = this.protocolHandlerLookup.get(protocol);
439                        if (handler == null) {
440                                throw new IllegalStateException(
441                                                "No handler for '" + protocol + "' among " + this.protocolHandlerLookup);
442                        }
443                }
444                else {
445                        if (this.defaultProtocolHandler != null) {
446                                handler = this.defaultProtocolHandler;
447                        }
448                        else if (this.protocolHandlers.size() == 1) {
449                                handler = this.protocolHandlers.iterator().next();
450                        }
451                        else {
452                                throw new IllegalStateException("Multiple protocol handlers configured and " +
453                                                "no protocol was negotiated. Consider configuring a default SubProtocolHandler.");
454                        }
455                }
456                return handler;
457        }
458
459        @Nullable
460        private String resolveSessionId(Message<?> message) {
461                for (SubProtocolHandler handler : this.protocolHandlerLookup.values()) {
462                        String sessionId = handler.resolveSessionId(message);
463                        if (sessionId != null) {
464                                return sessionId;
465                        }
466                }
467                if (this.defaultProtocolHandler != null) {
468                        String sessionId = this.defaultProtocolHandler.resolveSessionId(message);
469                        if (sessionId != null) {
470                                return sessionId;
471                        }
472                }
473                return null;
474        }
475
476        /**
477         * When a session is connected through a higher-level protocol it has a chance
478         * to use heartbeat management to shut down sessions that are too slow to send
479         * or receive messages. However, after a WebSocketSession is established and
480         * before the higher level protocol is fully connected there is a possibility for
481         * sessions to hang. This method checks and closes any sessions that have been
482         * connected for more than 60 seconds without having received a single message.
483         */
484        private void checkSessions() {
485                long currentTime = System.currentTimeMillis();
486                if (!isRunning() || (currentTime - this.lastSessionCheckTime < getTimeToFirstMessage())) {
487                        return;
488                }
489
490                if (this.sessionCheckLock.tryLock()) {
491                        try {
492                                for (WebSocketSessionHolder holder : this.sessions.values()) {
493                                        if (holder.hasHandledMessages()) {
494                                                continue;
495                                        }
496                                        long timeSinceCreated = currentTime - holder.getCreateTime();
497                                        if (timeSinceCreated < getTimeToFirstMessage()) {
498                                                continue;
499                                        }
500                                        WebSocketSession session = holder.getSession();
501                                        if (logger.isInfoEnabled()) {
502                                                logger.info("No messages received after " + timeSinceCreated + " ms. " +
503                                                                "Closing " + holder.getSession() + ".");
504                                        }
505                                        try {
506                                                this.stats.incrementNoMessagesReceivedCount();
507                                                session.close(CloseStatus.SESSION_NOT_RELIABLE);
508                                        }
509                                        catch (Throwable ex) {
510                                                if (logger.isWarnEnabled()) {
511                                                        logger.warn("Failed to close unreliable " + session, ex);
512                                                }
513                                        }
514                                }
515                        }
516                        finally {
517                                this.lastSessionCheckTime = currentTime;
518                                this.sessionCheckLock.unlock();
519                        }
520                }
521        }
522
523        private void clearSession(WebSocketSession session, CloseStatus closeStatus) throws Exception {
524                if (logger.isDebugEnabled()) {
525                        logger.debug("Clearing session " + session.getId());
526                }
527                if (this.sessions.remove(session.getId()) != null) {
528                        this.stats.decrementSessionCount(session);
529                }
530                findProtocolHandler(session).afterSessionEnded(session, closeStatus, this.clientInboundChannel);
531        }
532
533
534        @Override
535        public String toString() {
536                return "SubProtocolWebSocketHandler" + this.protocolHandlers;
537        }
538
539
540        private static class WebSocketSessionHolder {
541
542                private final WebSocketSession session;
543
544                private final long createTime;
545
546                private volatile boolean hasHandledMessages;
547
548                public WebSocketSessionHolder(WebSocketSession session) {
549                        this.session = session;
550                        this.createTime = System.currentTimeMillis();
551                }
552
553                public WebSocketSession getSession() {
554                        return this.session;
555                }
556
557                public long getCreateTime() {
558                        return this.createTime;
559                }
560
561                public void setHasHandledMessages() {
562                        this.hasHandledMessages = true;
563                }
564
565                public boolean hasHandledMessages() {
566                        return this.hasHandledMessages;
567                }
568
569                @Override
570                public String toString() {
571                        return "WebSocketSessionHolder[session=" + this.session + ", createTime=" +
572                                        this.createTime + ", hasHandledMessages=" + this.hasHandledMessages + "]";
573                }
574        }
575
576
577        /**
578         * Contract for access to session counters.
579         * @since 5.2
580         */
581        public interface Stats {
582
583                int getTotalSessions();
584
585                int getWebSocketSessions();
586
587                int getHttpStreamingSessions();
588
589                int getHttpPollingSessions();
590
591                int getLimitExceededSessions();
592
593                int getNoMessagesReceivedSessions();
594
595                int getTransportErrorSessions();
596        }
597
598
599        private class DefaultStats implements Stats {
600
601                private final AtomicInteger total = new AtomicInteger();
602
603                private final AtomicInteger webSocket = new AtomicInteger();
604
605                private final AtomicInteger httpStreaming = new AtomicInteger();
606
607                private final AtomicInteger httpPolling = new AtomicInteger();
608
609                private final AtomicInteger limitExceeded = new AtomicInteger();
610
611                private final AtomicInteger noMessagesReceived = new AtomicInteger();
612
613                private final AtomicInteger transportError = new AtomicInteger();
614
615                @Override
616                public int getTotalSessions() {
617                        return this.total.get();
618                }
619
620                @Override
621                public int getWebSocketSessions() {
622                        return this.webSocket.get();
623                }
624
625                @Override
626                public int getHttpStreamingSessions() {
627                        return this.httpStreaming.get();
628                }
629
630                @Override
631                public int getHttpPollingSessions() {
632                        return this.httpPolling.get();
633                }
634
635                @Override
636                public int getLimitExceededSessions() {
637                        return this.limitExceeded.get();
638                }
639
640                @Override
641                public int getNoMessagesReceivedSessions() {
642                        return this.noMessagesReceived.get();
643                }
644
645                @Override
646                public int getTransportErrorSessions() {
647                        return this.transportError.get();
648                }
649
650                void incrementSessionCount(WebSocketSession session) {
651                        getCountFor(session).incrementAndGet();
652                        this.total.incrementAndGet();
653                }
654
655                void decrementSessionCount(WebSocketSession session) {
656                        getCountFor(session).decrementAndGet();
657                }
658
659                void incrementLimitExceededCount() {
660                        this.limitExceeded.incrementAndGet();
661                }
662
663                void incrementNoMessagesReceivedCount() {
664                        this.noMessagesReceived.incrementAndGet();
665                }
666
667                void incrementTransportError() {
668                        this.transportError.incrementAndGet();
669                }
670
671                AtomicInteger getCountFor(WebSocketSession session) {
672                        if (session instanceof PollingSockJsSession) {
673                                return this.httpPolling;
674                        }
675                        else if (session instanceof StreamingSockJsSession) {
676                                return this.httpStreaming;
677                        }
678                        else {
679                                return this.webSocket;
680                        }
681                }
682
683                @Override
684                public String toString() {
685                        return SubProtocolWebSocketHandler.this.sessions.size() +
686                                        " current WS(" + this.webSocket.get() +
687                                        ")-HttpStream(" + this.httpStreaming.get() +
688                                        ")-HttpPoll(" + this.httpPolling.get() + "), " +
689                                        this.total.get() + " total, " +
690                                        (this.limitExceeded.get() + this.noMessagesReceived.get()) + " closed abnormally (" +
691                                        this.noMessagesReceived.get() + " connect failure, " +
692                                        this.limitExceeded.get() + " send limit, " +
693                                        this.transportError.get() + " transport error)";
694                }
695        }
696
697}