001/* 002 * Copyright 2002-2018 the original author or authors. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * https://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.springframework.web.socket.messaging; 018 019import java.util.ArrayList; 020import java.util.Collections; 021import java.util.LinkedHashSet; 022import java.util.List; 023import java.util.Map; 024import java.util.Set; 025import java.util.TreeMap; 026import java.util.concurrent.ConcurrentHashMap; 027import java.util.concurrent.atomic.AtomicInteger; 028import java.util.concurrent.locks.ReentrantLock; 029 030import org.apache.commons.logging.Log; 031import org.apache.commons.logging.LogFactory; 032 033import org.springframework.context.SmartLifecycle; 034import org.springframework.messaging.Message; 035import org.springframework.messaging.MessageChannel; 036import org.springframework.messaging.MessageHandler; 037import org.springframework.messaging.MessagingException; 038import org.springframework.messaging.SubscribableChannel; 039import org.springframework.util.Assert; 040import org.springframework.util.CollectionUtils; 041import org.springframework.util.StringUtils; 042import org.springframework.web.socket.CloseStatus; 043import org.springframework.web.socket.SubProtocolCapable; 044import org.springframework.web.socket.WebSocketHandler; 045import org.springframework.web.socket.WebSocketMessage; 046import org.springframework.web.socket.WebSocketSession; 047import org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator; 048import org.springframework.web.socket.handler.SessionLimitExceededException; 049import org.springframework.web.socket.sockjs.transport.session.PollingSockJsSession; 050import org.springframework.web.socket.sockjs.transport.session.StreamingSockJsSession; 051 052/** 053 * An implementation of {@link WebSocketHandler} that delegates incoming WebSocket 054 * messages to a {@link SubProtocolHandler} along with a {@link MessageChannel} to which 055 * the sub-protocol handler can send messages from WebSocket clients to the application. 056 * 057 * <p>Also an implementation of {@link MessageHandler} that finds the WebSocket session 058 * associated with the {@link Message} and passes it, along with the message, to the 059 * sub-protocol handler to send messages from the application back to the client. 060 * 061 * @author Rossen Stoyanchev 062 * @author Juergen Hoeller 063 * @author Andy Wilkinson 064 * @author Artem Bilan 065 * @since 4.0 066 */ 067public class SubProtocolWebSocketHandler 068 implements WebSocketHandler, SubProtocolCapable, MessageHandler, SmartLifecycle { 069 070 /** 071 * Sessions connected to this handler use a sub-protocol. Hence we expect to 072 * receive some client messages. If we don't receive any within a minute, the 073 * connection isn't doing well (proxy issue, slow network?) and can be closed. 074 * @see #checkSessions() 075 */ 076 private static final int TIME_TO_FIRST_MESSAGE = 60 * 1000; 077 078 079 private final Log logger = LogFactory.getLog(SubProtocolWebSocketHandler.class); 080 081 082 private final MessageChannel clientInboundChannel; 083 084 private final SubscribableChannel clientOutboundChannel; 085 086 private final Map<String, SubProtocolHandler> protocolHandlerLookup = 087 new TreeMap<String, SubProtocolHandler>(String.CASE_INSENSITIVE_ORDER); 088 089 private final Set<SubProtocolHandler> protocolHandlers = new LinkedHashSet<SubProtocolHandler>(); 090 091 private SubProtocolHandler defaultProtocolHandler; 092 093 private final Map<String, WebSocketSessionHolder> sessions = new ConcurrentHashMap<String, WebSocketSessionHolder>(); 094 095 private int sendTimeLimit = 10 * 1000; 096 097 private int sendBufferSizeLimit = 512 * 1024; 098 099 private volatile long lastSessionCheckTime = System.currentTimeMillis(); 100 101 private final ReentrantLock sessionCheckLock = new ReentrantLock(); 102 103 private final Stats stats = new Stats(); 104 105 private volatile boolean running = false; 106 107 private final Object lifecycleMonitor = new Object(); 108 109 110 /** 111 * Create a new {@code SubProtocolWebSocketHandler} for the given inbound and outbound channels. 112 * @param clientInboundChannel the inbound {@code MessageChannel} 113 * @param clientOutboundChannel the outbound {@code MessageChannel} 114 */ 115 public SubProtocolWebSocketHandler(MessageChannel clientInboundChannel, SubscribableChannel clientOutboundChannel) { 116 Assert.notNull(clientInboundChannel, "Inbound MessageChannel must not be null"); 117 Assert.notNull(clientOutboundChannel, "Outbound MessageChannel must not be null"); 118 this.clientInboundChannel = clientInboundChannel; 119 this.clientOutboundChannel = clientOutboundChannel; 120 } 121 122 123 /** 124 * Configure one or more handlers to use depending on the sub-protocol requested by 125 * the client in the WebSocket handshake request. 126 * @param protocolHandlers the sub-protocol handlers to use 127 */ 128 public void setProtocolHandlers(List<SubProtocolHandler> protocolHandlers) { 129 this.protocolHandlerLookup.clear(); 130 this.protocolHandlers.clear(); 131 for (SubProtocolHandler handler : protocolHandlers) { 132 addProtocolHandler(handler); 133 } 134 } 135 136 public List<SubProtocolHandler> getProtocolHandlers() { 137 return new ArrayList<SubProtocolHandler>(this.protocolHandlers); 138 } 139 140 /** 141 * Register a sub-protocol handler. 142 */ 143 public void addProtocolHandler(SubProtocolHandler handler) { 144 List<String> protocols = handler.getSupportedProtocols(); 145 if (CollectionUtils.isEmpty(protocols)) { 146 if (logger.isErrorEnabled()) { 147 logger.error("No sub-protocols for " + handler); 148 } 149 return; 150 } 151 for (String protocol : protocols) { 152 SubProtocolHandler replaced = this.protocolHandlerLookup.put(protocol, handler); 153 if (replaced != null && replaced != handler) { 154 throw new IllegalStateException("Cannot map " + handler + 155 " to protocol '" + protocol + "': already mapped to " + replaced + "."); 156 } 157 } 158 this.protocolHandlers.add(handler); 159 } 160 161 /** 162 * Return the sub-protocols keyed by protocol name. 163 */ 164 public Map<String, SubProtocolHandler> getProtocolHandlerMap() { 165 return this.protocolHandlerLookup; 166 } 167 168 /** 169 * Set the {@link SubProtocolHandler} to use when the client did not request a 170 * sub-protocol. 171 * @param defaultProtocolHandler the default handler 172 */ 173 public void setDefaultProtocolHandler(SubProtocolHandler defaultProtocolHandler) { 174 this.defaultProtocolHandler = defaultProtocolHandler; 175 if (this.protocolHandlerLookup.isEmpty()) { 176 setProtocolHandlers(Collections.singletonList(defaultProtocolHandler)); 177 } 178 } 179 180 /** 181 * Return the default sub-protocol handler to use. 182 */ 183 public SubProtocolHandler getDefaultProtocolHandler() { 184 return this.defaultProtocolHandler; 185 } 186 187 /** 188 * Return all supported protocols. 189 */ 190 public List<String> getSubProtocols() { 191 return new ArrayList<String>(this.protocolHandlerLookup.keySet()); 192 } 193 194 /** 195 * Specify the send-time limit (milliseconds). 196 * @see ConcurrentWebSocketSessionDecorator 197 */ 198 public void setSendTimeLimit(int sendTimeLimit) { 199 this.sendTimeLimit = sendTimeLimit; 200 } 201 202 /** 203 * Return the send-time limit (milliseconds). 204 */ 205 public int getSendTimeLimit() { 206 return this.sendTimeLimit; 207 } 208 209 /** 210 * Specify the buffer-size limit (number of bytes). 211 * @see ConcurrentWebSocketSessionDecorator 212 */ 213 public void setSendBufferSizeLimit(int sendBufferSizeLimit) { 214 this.sendBufferSizeLimit = sendBufferSizeLimit; 215 } 216 217 /** 218 * Return the buffer-size limit (number of bytes). 219 */ 220 public int getSendBufferSizeLimit() { 221 return this.sendBufferSizeLimit; 222 } 223 224 /** 225 * Return a String describing internal state and counters. 226 */ 227 public String getStatsInfo() { 228 return this.stats.toString(); 229 } 230 231 232 @Override 233 public boolean isAutoStartup() { 234 return true; 235 } 236 237 @Override 238 public int getPhase() { 239 return Integer.MAX_VALUE; 240 } 241 242 @Override 243 public final void start() { 244 Assert.isTrue(this.defaultProtocolHandler != null || !this.protocolHandlers.isEmpty(), "No handlers"); 245 246 synchronized (this.lifecycleMonitor) { 247 this.clientOutboundChannel.subscribe(this); 248 this.running = true; 249 } 250 } 251 252 @Override 253 public final void stop() { 254 synchronized (this.lifecycleMonitor) { 255 this.running = false; 256 this.clientOutboundChannel.unsubscribe(this); 257 } 258 259 // Proactively notify all active WebSocket sessions 260 for (WebSocketSessionHolder holder : this.sessions.values()) { 261 try { 262 holder.getSession().close(CloseStatus.GOING_AWAY); 263 } 264 catch (Throwable ex) { 265 if (logger.isWarnEnabled()) { 266 logger.warn("Failed to close '" + holder.getSession() + "': " + ex); 267 } 268 } 269 } 270 } 271 272 @Override 273 public final void stop(Runnable callback) { 274 synchronized (this.lifecycleMonitor) { 275 stop(); 276 callback.run(); 277 } 278 } 279 280 @Override 281 public final boolean isRunning() { 282 return this.running; 283 } 284 285 286 @Override 287 public void afterConnectionEstablished(WebSocketSession session) throws Exception { 288 // WebSocketHandlerDecorator could close the session 289 if (!session.isOpen()) { 290 return; 291 } 292 293 this.stats.incrementSessionCount(session); 294 session = decorateSession(session); 295 this.sessions.put(session.getId(), new WebSocketSessionHolder(session)); 296 findProtocolHandler(session).afterSessionStarted(session, this.clientInboundChannel); 297 } 298 299 /** 300 * Handle an inbound message from a WebSocket client. 301 */ 302 @Override 303 public void handleMessage(WebSocketSession session, WebSocketMessage<?> message) throws Exception { 304 WebSocketSessionHolder holder = this.sessions.get(session.getId()); 305 if (holder != null) { 306 session = holder.getSession(); 307 } 308 SubProtocolHandler protocolHandler = findProtocolHandler(session); 309 protocolHandler.handleMessageFromClient(session, message, this.clientInboundChannel); 310 if (holder != null) { 311 holder.setHasHandledMessages(); 312 } 313 checkSessions(); 314 } 315 316 /** 317 * Handle an outbound Spring Message to a WebSocket client. 318 */ 319 @Override 320 public void handleMessage(Message<?> message) throws MessagingException { 321 String sessionId = resolveSessionId(message); 322 if (sessionId == null) { 323 if (logger.isErrorEnabled()) { 324 logger.error("Could not find session id in " + message); 325 } 326 return; 327 } 328 329 WebSocketSessionHolder holder = this.sessions.get(sessionId); 330 if (holder == null) { 331 if (logger.isDebugEnabled()) { 332 // The broker may not have removed the session yet 333 logger.debug("No session for " + message); 334 } 335 return; 336 } 337 338 WebSocketSession session = holder.getSession(); 339 try { 340 findProtocolHandler(session).handleMessageToClient(session, message); 341 } 342 catch (SessionLimitExceededException ex) { 343 try { 344 if (logger.isDebugEnabled()) { 345 logger.debug("Terminating '" + session + "'", ex); 346 } 347 this.stats.incrementLimitExceededCount(); 348 clearSession(session, ex.getStatus()); // clear first, session may be unresponsive 349 session.close(ex.getStatus()); 350 } 351 catch (Exception secondException) { 352 logger.debug("Failure while closing session " + sessionId + ".", secondException); 353 } 354 } 355 catch (Exception ex) { 356 // Could be part of normal workflow (e.g. browser tab closed) 357 if (logger.isDebugEnabled()) { 358 logger.debug("Failed to send message to client in " + session + ": " + message, ex); 359 } 360 } 361 } 362 363 @Override 364 public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception { 365 this.stats.incrementTransportError(); 366 } 367 368 @Override 369 public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception { 370 clearSession(session, closeStatus); 371 } 372 373 @Override 374 public boolean supportsPartialMessages() { 375 return false; 376 } 377 378 379 /** 380 * Decorate the given {@link WebSocketSession}, if desired. 381 * <p>The default implementation builds a {@link ConcurrentWebSocketSessionDecorator} 382 * with the configured {@link #getSendTimeLimit() send-time limit} and 383 * {@link #getSendBufferSizeLimit() buffer-size limit}. 384 * @param session the original {@code WebSocketSession} 385 * @return the decorated {@code WebSocketSession}, or potentially the given session as-is 386 * @since 4.3.13 387 */ 388 protected WebSocketSession decorateSession(WebSocketSession session) { 389 return new ConcurrentWebSocketSessionDecorator(session, getSendTimeLimit(), getSendBufferSizeLimit()); 390 } 391 392 /** 393 * Find a {@link SubProtocolHandler} for the given session. 394 * @param session the {@code WebSocketSession} to find a handler for 395 */ 396 protected final SubProtocolHandler findProtocolHandler(WebSocketSession session) { 397 String protocol = null; 398 try { 399 protocol = session.getAcceptedProtocol(); 400 } 401 catch (Exception ex) { 402 // Shouldn't happen 403 logger.error("Failed to obtain session.getAcceptedProtocol(): " + 404 "will use the default protocol handler (if configured).", ex); 405 } 406 407 SubProtocolHandler handler; 408 if (!StringUtils.isEmpty(protocol)) { 409 handler = this.protocolHandlerLookup.get(protocol); 410 if (handler == null) { 411 throw new IllegalStateException( 412 "No handler for '" + protocol + "' among " + this.protocolHandlerLookup); 413 } 414 } 415 else { 416 if (this.defaultProtocolHandler != null) { 417 handler = this.defaultProtocolHandler; 418 } 419 else if (this.protocolHandlers.size() == 1) { 420 handler = this.protocolHandlers.iterator().next(); 421 } 422 else { 423 throw new IllegalStateException("Multiple protocol handlers configured and " + 424 "no protocol was negotiated. Consider configuring a default SubProtocolHandler."); 425 } 426 } 427 return handler; 428 } 429 430 private String resolveSessionId(Message<?> message) { 431 for (SubProtocolHandler handler : this.protocolHandlerLookup.values()) { 432 String sessionId = handler.resolveSessionId(message); 433 if (sessionId != null) { 434 return sessionId; 435 } 436 } 437 if (this.defaultProtocolHandler != null) { 438 String sessionId = this.defaultProtocolHandler.resolveSessionId(message); 439 if (sessionId != null) { 440 return sessionId; 441 } 442 } 443 return null; 444 } 445 446 /** 447 * When a session is connected through a higher-level protocol it has a chance 448 * to use heartbeat management to shut down sessions that are too slow to send 449 * or receive messages. However, after a WebSocketSession is established and 450 * before the higher level protocol is fully connected there is a possibility for 451 * sessions to hang. This method checks and closes any sessions that have been 452 * connected for more than 60 seconds without having received a single message. 453 */ 454 private void checkSessions() { 455 long currentTime = System.currentTimeMillis(); 456 if (!isRunning() || (currentTime - this.lastSessionCheckTime < TIME_TO_FIRST_MESSAGE)) { 457 return; 458 } 459 460 if (this.sessionCheckLock.tryLock()) { 461 try { 462 for (WebSocketSessionHolder holder : this.sessions.values()) { 463 if (holder.hasHandledMessages()) { 464 continue; 465 } 466 long timeSinceCreated = currentTime - holder.getCreateTime(); 467 if (timeSinceCreated < TIME_TO_FIRST_MESSAGE) { 468 continue; 469 } 470 WebSocketSession session = holder.getSession(); 471 if (logger.isInfoEnabled()) { 472 logger.info("No messages received after " + timeSinceCreated + " ms. " + 473 "Closing " + holder.getSession() + "."); 474 } 475 try { 476 this.stats.incrementNoMessagesReceivedCount(); 477 session.close(CloseStatus.SESSION_NOT_RELIABLE); 478 } 479 catch (Throwable ex) { 480 if (logger.isWarnEnabled()) { 481 logger.warn("Failed to close unreliable " + session, ex); 482 } 483 } 484 } 485 } 486 finally { 487 this.lastSessionCheckTime = currentTime; 488 this.sessionCheckLock.unlock(); 489 } 490 } 491 } 492 493 private void clearSession(WebSocketSession session, CloseStatus closeStatus) throws Exception { 494 if (logger.isDebugEnabled()) { 495 logger.debug("Clearing session " + session.getId()); 496 } 497 if (this.sessions.remove(session.getId()) != null) { 498 this.stats.decrementSessionCount(session); 499 } 500 findProtocolHandler(session).afterSessionEnded(session, closeStatus, this.clientInboundChannel); 501 } 502 503 504 @Override 505 public String toString() { 506 return "SubProtocolWebSocketHandler" + this.protocolHandlers; 507 } 508 509 510 private static class WebSocketSessionHolder { 511 512 private final WebSocketSession session; 513 514 private final long createTime; 515 516 private volatile boolean hasHandledMessages; 517 518 public WebSocketSessionHolder(WebSocketSession session) { 519 this.session = session; 520 this.createTime = System.currentTimeMillis(); 521 } 522 523 public WebSocketSession getSession() { 524 return this.session; 525 } 526 527 public long getCreateTime() { 528 return this.createTime; 529 } 530 531 public void setHasHandledMessages() { 532 this.hasHandledMessages = true; 533 } 534 535 public boolean hasHandledMessages() { 536 return this.hasHandledMessages; 537 } 538 539 @Override 540 public String toString() { 541 return "WebSocketSessionHolder[session=" + this.session + ", createTime=" + 542 this.createTime + ", hasHandledMessages=" + this.hasHandledMessages + "]"; 543 } 544 } 545 546 547 private class Stats { 548 549 private final AtomicInteger total = new AtomicInteger(); 550 551 private final AtomicInteger webSocket = new AtomicInteger(); 552 553 private final AtomicInteger httpStreaming = new AtomicInteger(); 554 555 private final AtomicInteger httpPolling = new AtomicInteger(); 556 557 private final AtomicInteger limitExceeded = new AtomicInteger(); 558 559 private final AtomicInteger noMessagesReceived = new AtomicInteger(); 560 561 private final AtomicInteger transportError = new AtomicInteger(); 562 563 public void incrementSessionCount(WebSocketSession session) { 564 getCountFor(session).incrementAndGet(); 565 this.total.incrementAndGet(); 566 } 567 568 public void decrementSessionCount(WebSocketSession session) { 569 getCountFor(session).decrementAndGet(); 570 } 571 572 public void incrementLimitExceededCount() { 573 this.limitExceeded.incrementAndGet(); 574 } 575 576 public void incrementNoMessagesReceivedCount() { 577 this.noMessagesReceived.incrementAndGet(); 578 } 579 580 public void incrementTransportError() { 581 this.transportError.incrementAndGet(); 582 } 583 584 private AtomicInteger getCountFor(WebSocketSession session) { 585 if (session instanceof PollingSockJsSession) { 586 return this.httpPolling; 587 } 588 else if (session instanceof StreamingSockJsSession) { 589 return this.httpStreaming; 590 } 591 else { 592 return this.webSocket; 593 } 594 } 595 596 public String toString() { 597 return SubProtocolWebSocketHandler.this.sessions.size() + 598 " current WS(" + this.webSocket.get() + 599 ")-HttpStream(" + this.httpStreaming.get() + 600 ")-HttpPoll(" + this.httpPolling.get() + "), " + 601 this.total.get() + " total, " + 602 (this.limitExceeded.get() + this.noMessagesReceived.get()) + " closed abnormally (" + 603 this.noMessagesReceived.get() + " connect failure, " + 604 this.limitExceeded.get() + " send limit, " + 605 this.transportError.get() + " transport error)"; 606 } 607 } 608 609}