001/*
002 * Copyright 2002-2018 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.web.socket.messaging;
018
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.ReentrantLock;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import org.springframework.context.SmartLifecycle;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.MessagingException;
import org.springframework.messaging.SubscribableChannel;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.SubProtocolCapable;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator;
import org.springframework.web.socket.handler.SessionLimitExceededException;
import org.springframework.web.socket.handler.WebSocketSessionDecorator;
import org.springframework.web.socket.sockjs.transport.session.PollingSockJsSession;
import org.springframework.web.socket.sockjs.transport.session.StreamingSockJsSession;
051
052/**
053 * An implementation of {@link WebSocketHandler} that delegates incoming WebSocket
054 * messages to a {@link SubProtocolHandler} along with a {@link MessageChannel} to which
055 * the sub-protocol handler can send messages from WebSocket clients to the application.
056 *
057 * <p>Also an implementation of {@link MessageHandler} that finds the WebSocket session
058 * associated with the {@link Message} and passes it, along with the message, to the
059 * sub-protocol handler to send messages from the application back to the client.
060 *
061 * @author Rossen Stoyanchev
062 * @author Juergen Hoeller
063 * @author Andy Wilkinson
064 * @author Artem Bilan
065 * @since 4.0
066 */
067public class SubProtocolWebSocketHandler
068                implements WebSocketHandler, SubProtocolCapable, MessageHandler, SmartLifecycle {
069
070        /**
071         * Sessions connected to this handler use a sub-protocol. Hence we expect to
072         * receive some client messages. If we don't receive any within a minute, the
073         * connection isn't doing well (proxy issue, slow network?) and can be closed.
074         * @see #checkSessions()
075         */
076        private static final int TIME_TO_FIRST_MESSAGE = 60 * 1000;
077
078
079        private final Log logger = LogFactory.getLog(SubProtocolWebSocketHandler.class);
080
081
082        private final MessageChannel clientInboundChannel;
083
084        private final SubscribableChannel clientOutboundChannel;
085
086        private final Map<String, SubProtocolHandler> protocolHandlerLookup =
087                        new TreeMap<String, SubProtocolHandler>(String.CASE_INSENSITIVE_ORDER);
088
089        private final Set<SubProtocolHandler> protocolHandlers = new LinkedHashSet<SubProtocolHandler>();
090
091        private SubProtocolHandler defaultProtocolHandler;
092
093        private final Map<String, WebSocketSessionHolder> sessions = new ConcurrentHashMap<String, WebSocketSessionHolder>();
094
095        private int sendTimeLimit = 10 * 1000;
096
097        private int sendBufferSizeLimit = 512 * 1024;
098
099        private volatile long lastSessionCheckTime = System.currentTimeMillis();
100
101        private final ReentrantLock sessionCheckLock = new ReentrantLock();
102
103        private final Stats stats = new Stats();
104
105        private volatile boolean running = false;
106
107        private final Object lifecycleMonitor = new Object();
108
109
110        /**
111         * Create a new {@code SubProtocolWebSocketHandler} for the given inbound and outbound channels.
112         * @param clientInboundChannel the inbound {@code MessageChannel}
113         * @param clientOutboundChannel the outbound {@code MessageChannel}
114         */
115        public SubProtocolWebSocketHandler(MessageChannel clientInboundChannel, SubscribableChannel clientOutboundChannel) {
116                Assert.notNull(clientInboundChannel, "Inbound MessageChannel must not be null");
117                Assert.notNull(clientOutboundChannel, "Outbound MessageChannel must not be null");
118                this.clientInboundChannel = clientInboundChannel;
119                this.clientOutboundChannel = clientOutboundChannel;
120        }
121
122
123        /**
124         * Configure one or more handlers to use depending on the sub-protocol requested by
125         * the client in the WebSocket handshake request.
126         * @param protocolHandlers the sub-protocol handlers to use
127         */
128        public void setProtocolHandlers(List<SubProtocolHandler> protocolHandlers) {
129                this.protocolHandlerLookup.clear();
130                this.protocolHandlers.clear();
131                for (SubProtocolHandler handler : protocolHandlers) {
132                        addProtocolHandler(handler);
133                }
134        }
135
136        public List<SubProtocolHandler> getProtocolHandlers() {
137                return new ArrayList<SubProtocolHandler>(this.protocolHandlers);
138        }
139
140        /**
141         * Register a sub-protocol handler.
142         */
143        public void addProtocolHandler(SubProtocolHandler handler) {
144                List<String> protocols = handler.getSupportedProtocols();
145                if (CollectionUtils.isEmpty(protocols)) {
146                        if (logger.isErrorEnabled()) {
147                                logger.error("No sub-protocols for " + handler);
148                        }
149                        return;
150                }
151                for (String protocol : protocols) {
152                        SubProtocolHandler replaced = this.protocolHandlerLookup.put(protocol, handler);
153                        if (replaced != null && replaced != handler) {
154                                throw new IllegalStateException("Cannot map " + handler +
155                                                " to protocol '" + protocol + "': already mapped to " + replaced + ".");
156                        }
157                }
158                this.protocolHandlers.add(handler);
159        }
160
161        /**
162         * Return the sub-protocols keyed by protocol name.
163         */
164        public Map<String, SubProtocolHandler> getProtocolHandlerMap() {
165                return this.protocolHandlerLookup;
166        }
167
168        /**
169         * Set the {@link SubProtocolHandler} to use when the client did not request a
170         * sub-protocol.
171         * @param defaultProtocolHandler the default handler
172         */
173        public void setDefaultProtocolHandler(SubProtocolHandler defaultProtocolHandler) {
174                this.defaultProtocolHandler = defaultProtocolHandler;
175                if (this.protocolHandlerLookup.isEmpty()) {
176                        setProtocolHandlers(Collections.singletonList(defaultProtocolHandler));
177                }
178        }
179
180        /**
181         * Return the default sub-protocol handler to use.
182         */
183        public SubProtocolHandler getDefaultProtocolHandler() {
184                return this.defaultProtocolHandler;
185        }
186
187        /**
188         * Return all supported protocols.
189         */
190        public List<String> getSubProtocols() {
191                return new ArrayList<String>(this.protocolHandlerLookup.keySet());
192        }
193
194        /**
195         * Specify the send-time limit (milliseconds).
196         * @see ConcurrentWebSocketSessionDecorator
197         */
198        public void setSendTimeLimit(int sendTimeLimit) {
199                this.sendTimeLimit = sendTimeLimit;
200        }
201
202        /**
203         * Return the send-time limit (milliseconds).
204         */
205        public int getSendTimeLimit() {
206                return this.sendTimeLimit;
207        }
208
209        /**
210         * Specify the buffer-size limit (number of bytes).
211         * @see ConcurrentWebSocketSessionDecorator
212         */
213        public void setSendBufferSizeLimit(int sendBufferSizeLimit) {
214                this.sendBufferSizeLimit = sendBufferSizeLimit;
215        }
216
217        /**
218         * Return the buffer-size limit (number of bytes).
219         */
220        public int getSendBufferSizeLimit() {
221                return this.sendBufferSizeLimit;
222        }
223
224        /**
225         * Return a String describing internal state and counters.
226         */
227        public String getStatsInfo() {
228                return this.stats.toString();
229        }
230
231
232        @Override
233        public boolean isAutoStartup() {
234                return true;
235        }
236
237        @Override
238        public int getPhase() {
239                return Integer.MAX_VALUE;
240        }
241
242        @Override
243        public final void start() {
244                Assert.isTrue(this.defaultProtocolHandler != null || !this.protocolHandlers.isEmpty(), "No handlers");
245
246                synchronized (this.lifecycleMonitor) {
247                        this.clientOutboundChannel.subscribe(this);
248                        this.running = true;
249                }
250        }
251
252        @Override
253        public final void stop() {
254                synchronized (this.lifecycleMonitor) {
255                        this.running = false;
256                        this.clientOutboundChannel.unsubscribe(this);
257                }
258
259                // Proactively notify all active WebSocket sessions
260                for (WebSocketSessionHolder holder : this.sessions.values()) {
261                        try {
262                                holder.getSession().close(CloseStatus.GOING_AWAY);
263                        }
264                        catch (Throwable ex) {
265                                if (logger.isWarnEnabled()) {
266                                        logger.warn("Failed to close '" + holder.getSession() + "': " + ex);
267                                }
268                        }
269                }
270        }
271
272        @Override
273        public final void stop(Runnable callback) {
274                synchronized (this.lifecycleMonitor) {
275                        stop();
276                        callback.run();
277                }
278        }
279
280        @Override
281        public final boolean isRunning() {
282                return this.running;
283        }
284
285
286        @Override
287        public void afterConnectionEstablished(WebSocketSession session) throws Exception {
288                // WebSocketHandlerDecorator could close the session
289                if (!session.isOpen()) {
290                        return;
291                }
292
293                this.stats.incrementSessionCount(session);
294                session = decorateSession(session);
295                this.sessions.put(session.getId(), new WebSocketSessionHolder(session));
296                findProtocolHandler(session).afterSessionStarted(session, this.clientInboundChannel);
297        }
298
299        /**
300         * Handle an inbound message from a WebSocket client.
301         */
302        @Override
303        public void handleMessage(WebSocketSession session, WebSocketMessage<?> message) throws Exception {
304                WebSocketSessionHolder holder = this.sessions.get(session.getId());
305                if (holder != null) {
306                        session = holder.getSession();
307                }
308                SubProtocolHandler protocolHandler = findProtocolHandler(session);
309                protocolHandler.handleMessageFromClient(session, message, this.clientInboundChannel);
310                if (holder != null) {
311                        holder.setHasHandledMessages();
312                }
313                checkSessions();
314        }
315
316        /**
317         * Handle an outbound Spring Message to a WebSocket client.
318         */
319        @Override
320        public void handleMessage(Message<?> message) throws MessagingException {
321                String sessionId = resolveSessionId(message);
322                if (sessionId == null) {
323                        if (logger.isErrorEnabled()) {
324                                logger.error("Could not find session id in " + message);
325                        }
326                        return;
327                }
328
329                WebSocketSessionHolder holder = this.sessions.get(sessionId);
330                if (holder == null) {
331                        if (logger.isDebugEnabled()) {
332                                // The broker may not have removed the session yet
333                                logger.debug("No session for " + message);
334                        }
335                        return;
336                }
337
338                WebSocketSession session = holder.getSession();
339                try {
340                        findProtocolHandler(session).handleMessageToClient(session, message);
341                }
342                catch (SessionLimitExceededException ex) {
343                        try {
344                                if (logger.isDebugEnabled()) {
345                                        logger.debug("Terminating '" + session + "'", ex);
346                                }
347                                this.stats.incrementLimitExceededCount();
348                                clearSession(session, ex.getStatus()); // clear first, session may be unresponsive
349                                session.close(ex.getStatus());
350                        }
351                        catch (Exception secondException) {
352                                logger.debug("Failure while closing session " + sessionId + ".", secondException);
353                        }
354                }
355                catch (Exception ex) {
356                        // Could be part of normal workflow (e.g. browser tab closed)
357                        if (logger.isDebugEnabled()) {
358                                logger.debug("Failed to send message to client in " + session + ": " + message, ex);
359                        }
360                }
361        }
362
363        @Override
364        public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
365                this.stats.incrementTransportError();
366        }
367
368        @Override
369        public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception {
370                clearSession(session, closeStatus);
371        }
372
373        @Override
374        public boolean supportsPartialMessages() {
375                return false;
376        }
377
378
379        /**
380         * Decorate the given {@link WebSocketSession}, if desired.
381         * <p>The default implementation builds a {@link ConcurrentWebSocketSessionDecorator}
382         * with the configured {@link #getSendTimeLimit() send-time limit} and
383         * {@link #getSendBufferSizeLimit() buffer-size limit}.
384         * @param session the original {@code WebSocketSession}
385         * @return the decorated {@code WebSocketSession}, or potentially the given session as-is
386         * @since 4.3.13
387         */
388        protected WebSocketSession decorateSession(WebSocketSession session) {
389                return new ConcurrentWebSocketSessionDecorator(session, getSendTimeLimit(), getSendBufferSizeLimit());
390        }
391
392        /**
393         * Find a {@link SubProtocolHandler} for the given session.
394         * @param session the {@code WebSocketSession} to find a handler for
395         */
396        protected final SubProtocolHandler findProtocolHandler(WebSocketSession session) {
397                String protocol = null;
398                try {
399                        protocol = session.getAcceptedProtocol();
400                }
401                catch (Exception ex) {
402                        // Shouldn't happen
403                        logger.error("Failed to obtain session.getAcceptedProtocol(): " +
404                                        "will use the default protocol handler (if configured).", ex);
405                }
406
407                SubProtocolHandler handler;
408                if (!StringUtils.isEmpty(protocol)) {
409                        handler = this.protocolHandlerLookup.get(protocol);
410                        if (handler == null) {
411                                throw new IllegalStateException(
412                                                "No handler for '" + protocol + "' among " + this.protocolHandlerLookup);
413                        }
414                }
415                else {
416                        if (this.defaultProtocolHandler != null) {
417                                handler = this.defaultProtocolHandler;
418                        }
419                        else if (this.protocolHandlers.size() == 1) {
420                                handler = this.protocolHandlers.iterator().next();
421                        }
422                        else {
423                                throw new IllegalStateException("Multiple protocol handlers configured and " +
424                                                "no protocol was negotiated. Consider configuring a default SubProtocolHandler.");
425                        }
426                }
427                return handler;
428        }
429
430        private String resolveSessionId(Message<?> message) {
431                for (SubProtocolHandler handler : this.protocolHandlerLookup.values()) {
432                        String sessionId = handler.resolveSessionId(message);
433                        if (sessionId != null) {
434                                return sessionId;
435                        }
436                }
437                if (this.defaultProtocolHandler != null) {
438                        String sessionId = this.defaultProtocolHandler.resolveSessionId(message);
439                        if (sessionId != null) {
440                                return sessionId;
441                        }
442                }
443                return null;
444        }
445
446        /**
447         * When a session is connected through a higher-level protocol it has a chance
448         * to use heartbeat management to shut down sessions that are too slow to send
449         * or receive messages. However, after a WebSocketSession is established and
450         * before the higher level protocol is fully connected there is a possibility for
451         * sessions to hang. This method checks and closes any sessions that have been
452         * connected for more than 60 seconds without having received a single message.
453         */
454        private void checkSessions() {
455                long currentTime = System.currentTimeMillis();
456                if (!isRunning() || (currentTime - this.lastSessionCheckTime < TIME_TO_FIRST_MESSAGE)) {
457                        return;
458                }
459
460                if (this.sessionCheckLock.tryLock()) {
461                        try {
462                                for (WebSocketSessionHolder holder : this.sessions.values()) {
463                                        if (holder.hasHandledMessages()) {
464                                                continue;
465                                        }
466                                        long timeSinceCreated = currentTime - holder.getCreateTime();
467                                        if (timeSinceCreated < TIME_TO_FIRST_MESSAGE) {
468                                                continue;
469                                        }
470                                        WebSocketSession session = holder.getSession();
471                                        if (logger.isInfoEnabled()) {
472                                                logger.info("No messages received after " + timeSinceCreated + " ms. " +
473                                                                "Closing " + holder.getSession() + ".");
474                                        }
475                                        try {
476                                                this.stats.incrementNoMessagesReceivedCount();
477                                                session.close(CloseStatus.SESSION_NOT_RELIABLE);
478                                        }
479                                        catch (Throwable ex) {
480                                                if (logger.isWarnEnabled()) {
481                                                        logger.warn("Failed to close unreliable " + session, ex);
482                                                }
483                                        }
484                                }
485                        }
486                        finally {
487                                this.lastSessionCheckTime = currentTime;
488                                this.sessionCheckLock.unlock();
489                        }
490                }
491        }
492
493        private void clearSession(WebSocketSession session, CloseStatus closeStatus) throws Exception {
494                if (logger.isDebugEnabled()) {
495                        logger.debug("Clearing session " + session.getId());
496                }
497                if (this.sessions.remove(session.getId()) != null) {
498                        this.stats.decrementSessionCount(session);
499                }
500                findProtocolHandler(session).afterSessionEnded(session, closeStatus, this.clientInboundChannel);
501        }
502
503
504        @Override
505        public String toString() {
506                return "SubProtocolWebSocketHandler" + this.protocolHandlers;
507        }
508
509
510        private static class WebSocketSessionHolder {
511
512                private final WebSocketSession session;
513
514                private final long createTime;
515
516                private volatile boolean hasHandledMessages;
517
518                public WebSocketSessionHolder(WebSocketSession session) {
519                        this.session = session;
520                        this.createTime = System.currentTimeMillis();
521                }
522
523                public WebSocketSession getSession() {
524                        return this.session;
525                }
526
527                public long getCreateTime() {
528                        return this.createTime;
529                }
530
531                public void setHasHandledMessages() {
532                        this.hasHandledMessages = true;
533                }
534
535                public boolean hasHandledMessages() {
536                        return this.hasHandledMessages;
537                }
538
539                @Override
540                public String toString() {
541                        return "WebSocketSessionHolder[session=" + this.session + ", createTime=" +
542                                        this.createTime + ", hasHandledMessages=" + this.hasHandledMessages + "]";
543                }
544        }
545
546
547        private class Stats {
548
549                private final AtomicInteger total = new AtomicInteger();
550
551                private final AtomicInteger webSocket = new AtomicInteger();
552
553                private final AtomicInteger httpStreaming = new AtomicInteger();
554
555                private final AtomicInteger httpPolling = new AtomicInteger();
556
557                private final AtomicInteger limitExceeded = new AtomicInteger();
558
559                private final AtomicInteger noMessagesReceived = new AtomicInteger();
560
561                private final AtomicInteger transportError = new AtomicInteger();
562
563                public void incrementSessionCount(WebSocketSession session) {
564                        getCountFor(session).incrementAndGet();
565                        this.total.incrementAndGet();
566                }
567
568                public void decrementSessionCount(WebSocketSession session) {
569                        getCountFor(session).decrementAndGet();
570                }
571
572                public void incrementLimitExceededCount() {
573                        this.limitExceeded.incrementAndGet();
574                }
575
576                public void incrementNoMessagesReceivedCount() {
577                        this.noMessagesReceived.incrementAndGet();
578                }
579
580                public void incrementTransportError() {
581                        this.transportError.incrementAndGet();
582                }
583
584                private AtomicInteger getCountFor(WebSocketSession session) {
585                        if (session instanceof PollingSockJsSession) {
586                                return this.httpPolling;
587                        }
588                        else if (session instanceof StreamingSockJsSession) {
589                                return this.httpStreaming;
590                        }
591                        else {
592                                return this.webSocket;
593                        }
594                }
595
596                public String toString() {
597                        return SubProtocolWebSocketHandler.this.sessions.size() +
598                                        " current WS(" + this.webSocket.get() +
599                                        ")-HttpStream(" + this.httpStreaming.get() +
600                                        ")-HttpPoll(" + this.httpPolling.get() + "), " +
601                                        this.total.get() + " total, " +
602                                        (this.limitExceeded.get() + this.noMessagesReceived.get()) + " closed abnormally (" +
603                                        this.noMessagesReceived.get() + " connect failure, " +
604                                        this.limitExceeded.get() + " send limit, " +
605                                        this.transportError.get() + " transport error)";
606                }
607        }
608
609}