/*
 * Copyright 2002-2019 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.web.socket.messaging;

import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.ReentrantLock;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import org.springframework.context.SmartLifecycle;
import org.springframework.lang.Nullable;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.MessagingException;
import org.springframework.messaging.SubscribableChannel;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.SubProtocolCapable;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
048import org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator; 049import org.springframework.web.socket.handler.SessionLimitExceededException; 050import org.springframework.web.socket.sockjs.transport.session.PollingSockJsSession; 051import org.springframework.web.socket.sockjs.transport.session.StreamingSockJsSession; 052 053/** 054 * An implementation of {@link WebSocketHandler} that delegates incoming WebSocket 055 * messages to a {@link SubProtocolHandler} along with a {@link MessageChannel} to which 056 * the sub-protocol handler can send messages from WebSocket clients to the application. 057 * 058 * <p>Also an implementation of {@link MessageHandler} that finds the WebSocket session 059 * associated with the {@link Message} and passes it, along with the message, to the 060 * sub-protocol handler to send messages from the application back to the client. 061 * 062 * @author Rossen Stoyanchev 063 * @author Juergen Hoeller 064 * @author Andy Wilkinson 065 * @author Artem Bilan 066 * @since 4.0 067 */ 068public class SubProtocolWebSocketHandler 069 implements WebSocketHandler, SubProtocolCapable, MessageHandler, SmartLifecycle { 070 071 /** The default value for {@link #setTimeToFirstMessage(int) timeToFirstMessage}. 
*/ 072 private static final int DEFAULT_TIME_TO_FIRST_MESSAGE = 60 * 1000; 073 074 075 private final Log logger = LogFactory.getLog(SubProtocolWebSocketHandler.class); 076 077 078 private final MessageChannel clientInboundChannel; 079 080 private final SubscribableChannel clientOutboundChannel; 081 082 private final Map<String, SubProtocolHandler> protocolHandlerLookup = 083 new TreeMap<>(String.CASE_INSENSITIVE_ORDER); 084 085 private final Set<SubProtocolHandler> protocolHandlers = new LinkedHashSet<>(); 086 087 @Nullable 088 private SubProtocolHandler defaultProtocolHandler; 089 090 private final Map<String, WebSocketSessionHolder> sessions = new ConcurrentHashMap<>(); 091 092 private int sendTimeLimit = 10 * 1000; 093 094 private int sendBufferSizeLimit = 512 * 1024; 095 096 private int timeToFirstMessage = DEFAULT_TIME_TO_FIRST_MESSAGE; 097 098 private volatile long lastSessionCheckTime = System.currentTimeMillis(); 099 100 private final ReentrantLock sessionCheckLock = new ReentrantLock(); 101 102 private final DefaultStats stats = new DefaultStats(); 103 104 private volatile boolean running = false; 105 106 private final Object lifecycleMonitor = new Object(); 107 108 109 /** 110 * Create a new {@code SubProtocolWebSocketHandler} for the given inbound and outbound channels. 
111 * @param clientInboundChannel the inbound {@code MessageChannel} 112 * @param clientOutboundChannel the outbound {@code MessageChannel} 113 */ 114 public SubProtocolWebSocketHandler(MessageChannel clientInboundChannel, SubscribableChannel clientOutboundChannel) { 115 Assert.notNull(clientInboundChannel, "Inbound MessageChannel must not be null"); 116 Assert.notNull(clientOutboundChannel, "Outbound MessageChannel must not be null"); 117 this.clientInboundChannel = clientInboundChannel; 118 this.clientOutboundChannel = clientOutboundChannel; 119 } 120 121 122 /** 123 * Configure one or more handlers to use depending on the sub-protocol requested by 124 * the client in the WebSocket handshake request. 125 * @param protocolHandlers the sub-protocol handlers to use 126 */ 127 public void setProtocolHandlers(List<SubProtocolHandler> protocolHandlers) { 128 this.protocolHandlerLookup.clear(); 129 this.protocolHandlers.clear(); 130 for (SubProtocolHandler handler : protocolHandlers) { 131 addProtocolHandler(handler); 132 } 133 } 134 135 public List<SubProtocolHandler> getProtocolHandlers() { 136 return new ArrayList<>(this.protocolHandlers); 137 } 138 139 /** 140 * Register a sub-protocol handler. 141 */ 142 public void addProtocolHandler(SubProtocolHandler handler) { 143 List<String> protocols = handler.getSupportedProtocols(); 144 if (CollectionUtils.isEmpty(protocols)) { 145 if (logger.isErrorEnabled()) { 146 logger.error("No sub-protocols for " + handler); 147 } 148 return; 149 } 150 for (String protocol : protocols) { 151 SubProtocolHandler replaced = this.protocolHandlerLookup.put(protocol, handler); 152 if (replaced != null && replaced != handler) { 153 throw new IllegalStateException("Cannot map " + handler + 154 " to protocol '" + protocol + "': already mapped to " + replaced + "."); 155 } 156 } 157 this.protocolHandlers.add(handler); 158 } 159 160 /** 161 * Return the sub-protocols keyed by protocol name. 
162 */ 163 public Map<String, SubProtocolHandler> getProtocolHandlerMap() { 164 return this.protocolHandlerLookup; 165 } 166 167 /** 168 * Set the {@link SubProtocolHandler} to use when the client did not request a 169 * sub-protocol. 170 * @param defaultProtocolHandler the default handler 171 */ 172 public void setDefaultProtocolHandler(@Nullable SubProtocolHandler defaultProtocolHandler) { 173 this.defaultProtocolHandler = defaultProtocolHandler; 174 if (this.protocolHandlerLookup.isEmpty()) { 175 setProtocolHandlers(Collections.singletonList(defaultProtocolHandler)); 176 } 177 } 178 179 /** 180 * Return the default sub-protocol handler to use. 181 */ 182 @Nullable 183 public SubProtocolHandler getDefaultProtocolHandler() { 184 return this.defaultProtocolHandler; 185 } 186 187 /** 188 * Return all supported protocols. 189 */ 190 @Override 191 public List<String> getSubProtocols() { 192 return new ArrayList<>(this.protocolHandlerLookup.keySet()); 193 } 194 195 /** 196 * Specify the send-time limit (milliseconds). 197 * @see ConcurrentWebSocketSessionDecorator 198 */ 199 public void setSendTimeLimit(int sendTimeLimit) { 200 this.sendTimeLimit = sendTimeLimit; 201 } 202 203 /** 204 * Return the send-time limit (milliseconds). 205 */ 206 public int getSendTimeLimit() { 207 return this.sendTimeLimit; 208 } 209 210 /** 211 * Specify the buffer-size limit (number of bytes). 212 * @see ConcurrentWebSocketSessionDecorator 213 */ 214 public void setSendBufferSizeLimit(int sendBufferSizeLimit) { 215 this.sendBufferSizeLimit = sendBufferSizeLimit; 216 } 217 218 /** 219 * Return the buffer-size limit (number of bytes). 220 */ 221 public int getSendBufferSizeLimit() { 222 return this.sendBufferSizeLimit; 223 } 224 225 /** 226 * Set the maximum time allowed in milliseconds after the WebSocket connection 227 * is established and before the first sub-protocol message is received. 228 * <p>This handler is for WebSocket connections that use a sub-protocol. 
229 * Therefore, we expect the client to send at least one sub-protocol message 230 * in the beginning, or else we assume the connection isn't doing well, e.g. 231 * proxy issue, slow network, and can be closed. 232 * <p>By default this is set to {@code 60,000} (1 minute). 233 * @param timeToFirstMessage the maximum time allowed in milliseconds 234 * @since 5.1 235 * @see #checkSessions() 236 */ 237 public void setTimeToFirstMessage(int timeToFirstMessage) { 238 this.timeToFirstMessage = timeToFirstMessage; 239 } 240 241 /** 242 * Return the maximum time allowed after the WebSocket connection is 243 * established and before the first sub-protocol message. 244 * @since 5.1 245 */ 246 public int getTimeToFirstMessage() { 247 return this.timeToFirstMessage; 248 } 249 250 /** 251 * Return a String describing internal state and counters. 252 * Effectively {@code toString()} on {@link #getStats() getStats()}. 253 */ 254 public String getStatsInfo() { 255 return this.stats.toString(); 256 } 257 258 /** 259 * Return a structured object with various session counters. 
260 * @since 5.2 261 */ 262 public Stats getStats() { 263 return this.stats; 264 } 265 266 267 268 @Override 269 public final void start() { 270 Assert.isTrue(this.defaultProtocolHandler != null || !this.protocolHandlers.isEmpty(), "No handlers"); 271 272 synchronized (this.lifecycleMonitor) { 273 this.clientOutboundChannel.subscribe(this); 274 this.running = true; 275 } 276 } 277 278 @Override 279 public final void stop() { 280 synchronized (this.lifecycleMonitor) { 281 this.running = false; 282 this.clientOutboundChannel.unsubscribe(this); 283 } 284 285 // Proactively notify all active WebSocket sessions 286 for (WebSocketSessionHolder holder : this.sessions.values()) { 287 try { 288 holder.getSession().close(CloseStatus.GOING_AWAY); 289 } 290 catch (Throwable ex) { 291 if (logger.isWarnEnabled()) { 292 logger.warn("Failed to close '" + holder.getSession() + "': " + ex); 293 } 294 } 295 } 296 } 297 298 @Override 299 public final void stop(Runnable callback) { 300 synchronized (this.lifecycleMonitor) { 301 stop(); 302 callback.run(); 303 } 304 } 305 306 @Override 307 public final boolean isRunning() { 308 return this.running; 309 } 310 311 312 @Override 313 public void afterConnectionEstablished(WebSocketSession session) throws Exception { 314 // WebSocketHandlerDecorator could close the session 315 if (!session.isOpen()) { 316 return; 317 } 318 319 this.stats.incrementSessionCount(session); 320 session = decorateSession(session); 321 this.sessions.put(session.getId(), new WebSocketSessionHolder(session)); 322 findProtocolHandler(session).afterSessionStarted(session, this.clientInboundChannel); 323 } 324 325 /** 326 * Handle an inbound message from a WebSocket client. 
327 */ 328 @Override 329 public void handleMessage(WebSocketSession session, WebSocketMessage<?> message) throws Exception { 330 WebSocketSessionHolder holder = this.sessions.get(session.getId()); 331 if (holder != null) { 332 session = holder.getSession(); 333 } 334 SubProtocolHandler protocolHandler = findProtocolHandler(session); 335 protocolHandler.handleMessageFromClient(session, message, this.clientInboundChannel); 336 if (holder != null) { 337 holder.setHasHandledMessages(); 338 } 339 checkSessions(); 340 } 341 342 /** 343 * Handle an outbound Spring Message to a WebSocket client. 344 */ 345 @Override 346 public void handleMessage(Message<?> message) throws MessagingException { 347 String sessionId = resolveSessionId(message); 348 if (sessionId == null) { 349 if (logger.isErrorEnabled()) { 350 logger.error("Could not find session id in " + message); 351 } 352 return; 353 } 354 355 WebSocketSessionHolder holder = this.sessions.get(sessionId); 356 if (holder == null) { 357 if (logger.isDebugEnabled()) { 358 // The broker may not have removed the session yet 359 logger.debug("No session for " + message); 360 } 361 return; 362 } 363 364 WebSocketSession session = holder.getSession(); 365 try { 366 findProtocolHandler(session).handleMessageToClient(session, message); 367 } 368 catch (SessionLimitExceededException ex) { 369 try { 370 if (logger.isDebugEnabled()) { 371 logger.debug("Terminating '" + session + "'", ex); 372 } 373 else if (logger.isWarnEnabled()) { 374 logger.warn("Terminating '" + session + "': " + ex.getMessage()); 375 } 376 this.stats.incrementLimitExceededCount(); 377 clearSession(session, ex.getStatus()); // clear first, session may be unresponsive 378 session.close(ex.getStatus()); 379 } 380 catch (Exception secondException) { 381 logger.debug("Failure while closing session " + sessionId + ".", secondException); 382 } 383 } 384 catch (Exception ex) { 385 // Could be part of normal workflow (e.g. 
browser tab closed) 386 if (logger.isDebugEnabled()) { 387 logger.debug("Failed to send message to client in " + session + ": " + message, ex); 388 } 389 } 390 } 391 392 @Override 393 public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception { 394 this.stats.incrementTransportError(); 395 } 396 397 @Override 398 public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception { 399 clearSession(session, closeStatus); 400 } 401 402 @Override 403 public boolean supportsPartialMessages() { 404 return false; 405 } 406 407 408 /** 409 * Decorate the given {@link WebSocketSession}, if desired. 410 * <p>The default implementation builds a {@link ConcurrentWebSocketSessionDecorator} 411 * with the configured {@link #getSendTimeLimit() send-time limit} and 412 * {@link #getSendBufferSizeLimit() buffer-size limit}. 413 * @param session the original {@code WebSocketSession} 414 * @return the decorated {@code WebSocketSession}, or potentially the given session as-is 415 * @since 4.3.13 416 */ 417 protected WebSocketSession decorateSession(WebSocketSession session) { 418 return new ConcurrentWebSocketSessionDecorator(session, getSendTimeLimit(), getSendBufferSizeLimit()); 419 } 420 421 /** 422 * Find a {@link SubProtocolHandler} for the given session. 
423 * @param session the {@code WebSocketSession} to find a handler for 424 */ 425 protected final SubProtocolHandler findProtocolHandler(WebSocketSession session) { 426 String protocol = null; 427 try { 428 protocol = session.getAcceptedProtocol(); 429 } 430 catch (Exception ex) { 431 // Shouldn't happen 432 logger.error("Failed to obtain session.getAcceptedProtocol(): " + 433 "will use the default protocol handler (if configured).", ex); 434 } 435 436 SubProtocolHandler handler; 437 if (StringUtils.hasLength(protocol)) { 438 handler = this.protocolHandlerLookup.get(protocol); 439 if (handler == null) { 440 throw new IllegalStateException( 441 "No handler for '" + protocol + "' among " + this.protocolHandlerLookup); 442 } 443 } 444 else { 445 if (this.defaultProtocolHandler != null) { 446 handler = this.defaultProtocolHandler; 447 } 448 else if (this.protocolHandlers.size() == 1) { 449 handler = this.protocolHandlers.iterator().next(); 450 } 451 else { 452 throw new IllegalStateException("Multiple protocol handlers configured and " + 453 "no protocol was negotiated. Consider configuring a default SubProtocolHandler."); 454 } 455 } 456 return handler; 457 } 458 459 @Nullable 460 private String resolveSessionId(Message<?> message) { 461 for (SubProtocolHandler handler : this.protocolHandlerLookup.values()) { 462 String sessionId = handler.resolveSessionId(message); 463 if (sessionId != null) { 464 return sessionId; 465 } 466 } 467 if (this.defaultProtocolHandler != null) { 468 String sessionId = this.defaultProtocolHandler.resolveSessionId(message); 469 if (sessionId != null) { 470 return sessionId; 471 } 472 } 473 return null; 474 } 475 476 /** 477 * When a session is connected through a higher-level protocol it has a chance 478 * to use heartbeat management to shut down sessions that are too slow to send 479 * or receive messages. 
However, after a WebSocketSession is established and 480 * before the higher level protocol is fully connected there is a possibility for 481 * sessions to hang. This method checks and closes any sessions that have been 482 * connected for more than 60 seconds without having received a single message. 483 */ 484 private void checkSessions() { 485 long currentTime = System.currentTimeMillis(); 486 if (!isRunning() || (currentTime - this.lastSessionCheckTime < getTimeToFirstMessage())) { 487 return; 488 } 489 490 if (this.sessionCheckLock.tryLock()) { 491 try { 492 for (WebSocketSessionHolder holder : this.sessions.values()) { 493 if (holder.hasHandledMessages()) { 494 continue; 495 } 496 long timeSinceCreated = currentTime - holder.getCreateTime(); 497 if (timeSinceCreated < getTimeToFirstMessage()) { 498 continue; 499 } 500 WebSocketSession session = holder.getSession(); 501 if (logger.isInfoEnabled()) { 502 logger.info("No messages received after " + timeSinceCreated + " ms. " + 503 "Closing " + holder.getSession() + "."); 504 } 505 try { 506 this.stats.incrementNoMessagesReceivedCount(); 507 session.close(CloseStatus.SESSION_NOT_RELIABLE); 508 } 509 catch (Throwable ex) { 510 if (logger.isWarnEnabled()) { 511 logger.warn("Failed to close unreliable " + session, ex); 512 } 513 } 514 } 515 } 516 finally { 517 this.lastSessionCheckTime = currentTime; 518 this.sessionCheckLock.unlock(); 519 } 520 } 521 } 522 523 private void clearSession(WebSocketSession session, CloseStatus closeStatus) throws Exception { 524 if (logger.isDebugEnabled()) { 525 logger.debug("Clearing session " + session.getId()); 526 } 527 if (this.sessions.remove(session.getId()) != null) { 528 this.stats.decrementSessionCount(session); 529 } 530 findProtocolHandler(session).afterSessionEnded(session, closeStatus, this.clientInboundChannel); 531 } 532 533 534 @Override 535 public String toString() { 536 return "SubProtocolWebSocketHandler" + this.protocolHandlers; 537 } 538 539 540 private static 
class WebSocketSessionHolder { 541 542 private final WebSocketSession session; 543 544 private final long createTime; 545 546 private volatile boolean hasHandledMessages; 547 548 public WebSocketSessionHolder(WebSocketSession session) { 549 this.session = session; 550 this.createTime = System.currentTimeMillis(); 551 } 552 553 public WebSocketSession getSession() { 554 return this.session; 555 } 556 557 public long getCreateTime() { 558 return this.createTime; 559 } 560 561 public void setHasHandledMessages() { 562 this.hasHandledMessages = true; 563 } 564 565 public boolean hasHandledMessages() { 566 return this.hasHandledMessages; 567 } 568 569 @Override 570 public String toString() { 571 return "WebSocketSessionHolder[session=" + this.session + ", createTime=" + 572 this.createTime + ", hasHandledMessages=" + this.hasHandledMessages + "]"; 573 } 574 } 575 576 577 /** 578 * Contract for access to session counters. 579 * @since 5.2 580 */ 581 public interface Stats { 582 583 int getTotalSessions(); 584 585 int getWebSocketSessions(); 586 587 int getHttpStreamingSessions(); 588 589 int getHttpPollingSessions(); 590 591 int getLimitExceededSessions(); 592 593 int getNoMessagesReceivedSessions(); 594 595 int getTransportErrorSessions(); 596 } 597 598 599 private class DefaultStats implements Stats { 600 601 private final AtomicInteger total = new AtomicInteger(); 602 603 private final AtomicInteger webSocket = new AtomicInteger(); 604 605 private final AtomicInteger httpStreaming = new AtomicInteger(); 606 607 private final AtomicInteger httpPolling = new AtomicInteger(); 608 609 private final AtomicInteger limitExceeded = new AtomicInteger(); 610 611 private final AtomicInteger noMessagesReceived = new AtomicInteger(); 612 613 private final AtomicInteger transportError = new AtomicInteger(); 614 615 @Override 616 public int getTotalSessions() { 617 return this.total.get(); 618 } 619 620 @Override 621 public int getWebSocketSessions() { 622 return 
this.webSocket.get(); 623 } 624 625 @Override 626 public int getHttpStreamingSessions() { 627 return this.httpStreaming.get(); 628 } 629 630 @Override 631 public int getHttpPollingSessions() { 632 return this.httpPolling.get(); 633 } 634 635 @Override 636 public int getLimitExceededSessions() { 637 return this.limitExceeded.get(); 638 } 639 640 @Override 641 public int getNoMessagesReceivedSessions() { 642 return this.noMessagesReceived.get(); 643 } 644 645 @Override 646 public int getTransportErrorSessions() { 647 return this.transportError.get(); 648 } 649 650 void incrementSessionCount(WebSocketSession session) { 651 getCountFor(session).incrementAndGet(); 652 this.total.incrementAndGet(); 653 } 654 655 void decrementSessionCount(WebSocketSession session) { 656 getCountFor(session).decrementAndGet(); 657 } 658 659 void incrementLimitExceededCount() { 660 this.limitExceeded.incrementAndGet(); 661 } 662 663 void incrementNoMessagesReceivedCount() { 664 this.noMessagesReceived.incrementAndGet(); 665 } 666 667 void incrementTransportError() { 668 this.transportError.incrementAndGet(); 669 } 670 671 AtomicInteger getCountFor(WebSocketSession session) { 672 if (session instanceof PollingSockJsSession) { 673 return this.httpPolling; 674 } 675 else if (session instanceof StreamingSockJsSession) { 676 return this.httpStreaming; 677 } 678 else { 679 return this.webSocket; 680 } 681 } 682 683 @Override 684 public String toString() { 685 return SubProtocolWebSocketHandler.this.sessions.size() + 686 " current WS(" + this.webSocket.get() + 687 ")-HttpStream(" + this.httpStreaming.get() + 688 ")-HttpPoll(" + this.httpPolling.get() + "), " + 689 this.total.get() + " total, " + 690 (this.limitExceeded.get() + this.noMessagesReceived.get()) + " closed abnormally (" + 691 this.noMessagesReceived.get() + " connect failure, " + 692 this.limitExceeded.get() + " send limit, " + 693 this.transportError.get() + " transport error)"; 694 } 695 } 696 697}