001/* 002 * Copyright 2002-2019 the original author or authors. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * https://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.springframework.web.socket.messaging; 018 019import java.io.IOException; 020import java.net.URI; 021import java.nio.ByteBuffer; 022import java.util.ArrayList; 023import java.util.Collections; 024import java.util.List; 025import java.util.concurrent.ScheduledFuture; 026 027import org.apache.commons.logging.Log; 028import org.apache.commons.logging.LogFactory; 029 030import org.springframework.context.Lifecycle; 031import org.springframework.context.SmartLifecycle; 032import org.springframework.lang.Nullable; 033import org.springframework.messaging.Message; 034import org.springframework.messaging.simp.stomp.BufferingStompDecoder; 035import org.springframework.messaging.simp.stomp.ConnectionHandlingStompSession; 036import org.springframework.messaging.simp.stomp.StompClientSupport; 037import org.springframework.messaging.simp.stomp.StompDecoder; 038import org.springframework.messaging.simp.stomp.StompEncoder; 039import org.springframework.messaging.simp.stomp.StompHeaderAccessor; 040import org.springframework.messaging.simp.stomp.StompHeaders; 041import org.springframework.messaging.simp.stomp.StompSession; 042import org.springframework.messaging.simp.stomp.StompSessionHandler; 043import org.springframework.messaging.support.MessageHeaderAccessor; 044import org.springframework.messaging.tcp.TcpConnection; 045import org.springframework.messaging.tcp.TcpConnectionHandler; 046import org.springframework.scheduling.TaskScheduler; 047import org.springframework.util.Assert; 048import org.springframework.util.MimeTypeUtils; 049import org.springframework.util.concurrent.ListenableFuture; 050import org.springframework.util.concurrent.ListenableFutureCallback; 051import org.springframework.util.concurrent.SettableListenableFuture; 052import org.springframework.web.socket.BinaryMessage; 053import org.springframework.web.socket.CloseStatus; 054import org.springframework.web.socket.TextMessage; 055import org.springframework.web.socket.WebSocketHandler; 056import org.springframework.web.socket.WebSocketHttpHeaders; 057import org.springframework.web.socket.WebSocketMessage; 058import org.springframework.web.socket.WebSocketSession; 059import org.springframework.web.socket.client.WebSocketClient; 060import org.springframework.web.socket.handler.LoggingWebSocketHandlerDecorator; 061import org.springframework.web.socket.sockjs.transport.SockJsSession; 062import org.springframework.web.util.UriComponentsBuilder; 063 064/** 065 * A STOMP over WebSocket client that connects using an implementation of 066 * {@link org.springframework.web.socket.client.WebSocketClient WebSocketClient} 067 * including {@link org.springframework.web.socket.sockjs.client.SockJsClient 068 * SockJsClient}. 069 * 070 * @author Rossen Stoyanchev 071 * @since 4.2 072 */ 073public class WebSocketStompClient extends StompClientSupport implements SmartLifecycle { 074 075 private static final Log logger = LogFactory.getLog(WebSocketStompClient.class); 076 077 private final WebSocketClient webSocketClient; 078 079 private int inboundMessageSizeLimit = 64 * 1024; 080 081 private boolean autoStartup = true; 082 083 private int phase = DEFAULT_PHASE; 084 085 private volatile boolean running = false; 086 087 088 /** 089 * Class constructor. Sets {@link #setDefaultHeartbeat} to "0,0" but will 090 * reset it back to the preferred "10000,10000" when a 091 * {@link #setTaskScheduler} is configured. 092 * @param webSocketClient the WebSocket client to connect with 093 */ 094 public WebSocketStompClient(WebSocketClient webSocketClient) { 095 Assert.notNull(webSocketClient, "WebSocketClient is required"); 096 this.webSocketClient = webSocketClient; 097 setDefaultHeartbeat(new long[] {0, 0}); 098 } 099 100 101 /** 102 * Return the configured WebSocketClient. 103 */ 104 public WebSocketClient getWebSocketClient() { 105 return this.webSocketClient; 106 } 107 108 /** 109 * {@inheritDoc} 110 * <p>Also automatically sets the {@link #setDefaultHeartbeat defaultHeartbeat} 111 * property to "10000,10000" if it is currently set to "0,0". 112 */ 113 @Override 114 public void setTaskScheduler(@Nullable TaskScheduler taskScheduler) { 115 if (!isDefaultHeartbeatEnabled()) { 116 setDefaultHeartbeat(new long[] {10000, 10000}); 117 } 118 super.setTaskScheduler(taskScheduler); 119 } 120 121 /** 122 * Configure the maximum size allowed for inbound STOMP message. 123 * Since a STOMP message can be received in multiple WebSocket messages, 124 * buffering may be required and this property determines the maximum buffer 125 * size per message. 126 * <p>By default this is set to 64 * 1024 (64K). 127 */ 128 public void setInboundMessageSizeLimit(int inboundMessageSizeLimit) { 129 this.inboundMessageSizeLimit = inboundMessageSizeLimit; 130 } 131 132 /** 133 * Get the configured inbound message buffer size in bytes. 134 */ 135 public int getInboundMessageSizeLimit() { 136 return this.inboundMessageSizeLimit; 137 } 138 139 /** 140 * Set whether to auto-start the contained WebSocketClient when the Spring 141 * context has been refreshed. 142 * <p>Default is "true". 143 */ 144 public void setAutoStartup(boolean autoStartup) { 145 this.autoStartup = autoStartup; 146 } 147 148 /** 149 * Return the value for the 'autoStartup' property. If "true", this client 150 * will automatically start and stop the contained WebSocketClient. 151 */ 152 @Override 153 public boolean isAutoStartup() { 154 return this.autoStartup; 155 } 156 157 /** 158 * Specify the phase in which the WebSocket client should be started and 159 * subsequently closed. The startup order proceeds from lowest to highest, 160 * and the shutdown order is the reverse of that. 161 * <p>By default this is Integer.MAX_VALUE meaning that the WebSocket client 162 * is started as late as possible and stopped as soon as possible. 163 */ 164 public void setPhase(int phase) { 165 this.phase = phase; 166 } 167 168 /** 169 * Return the configured phase. 170 */ 171 @Override 172 public int getPhase() { 173 return this.phase; 174 } 175 176 177 @Override 178 public void start() { 179 if (!isRunning()) { 180 this.running = true; 181 if (getWebSocketClient() instanceof Lifecycle) { 182 ((Lifecycle) getWebSocketClient()).start(); 183 } 184 } 185 186 } 187 188 @Override 189 public void stop() { 190 if (isRunning()) { 191 this.running = false; 192 if (getWebSocketClient() instanceof Lifecycle) { 193 ((Lifecycle) getWebSocketClient()).stop(); 194 } 195 } 196 } 197 198 @Override 199 public boolean isRunning() { 200 return this.running; 201 } 202 203 204 /** 205 * Connect to the given WebSocket URL and notify the given 206 * {@link org.springframework.messaging.simp.stomp.StompSessionHandler} 207 * when connected on the STOMP level after the CONNECTED frame is received. 208 * @param url the url to connect to 209 * @param handler the session handler 210 * @param uriVars the URI variables to expand into the URL 211 * @return a ListenableFuture for access to the session when ready for use 212 */ 213 public ListenableFuture<StompSession> connect(String url, StompSessionHandler handler, Object... uriVars) { 214 return connect(url, null, handler, uriVars); 215 } 216 217 /** 218 * An overloaded version of 219 * {@link #connect(String, StompSessionHandler, Object...)} that also 220 * accepts {@link WebSocketHttpHeaders} to use for the WebSocket handshake. 221 * @param url the url to connect to 222 * @param handshakeHeaders the headers for the WebSocket handshake 223 * @param handler the session handler 224 * @param uriVariables the URI variables to expand into the URL 225 * @return a ListenableFuture for access to the session when ready for use 226 */ 227 public ListenableFuture<StompSession> connect(String url, @Nullable WebSocketHttpHeaders handshakeHeaders, 228 StompSessionHandler handler, Object... uriVariables) { 229 230 return connect(url, handshakeHeaders, null, handler, uriVariables); 231 } 232 233 /** 234 * An overloaded version of 235 * {@link #connect(String, StompSessionHandler, Object...)} that also accepts 236 * {@link WebSocketHttpHeaders} to use for the WebSocket handshake and 237 * {@link StompHeaders} for the STOMP CONNECT frame. 238 * @param url the url to connect to 239 * @param handshakeHeaders headers for the WebSocket handshake 240 * @param connectHeaders headers for the STOMP CONNECT frame 241 * @param handler the session handler 242 * @param uriVariables the URI variables to expand into the URL 243 * @return a ListenableFuture for access to the session when ready for use 244 */ 245 public ListenableFuture<StompSession> connect(String url, @Nullable WebSocketHttpHeaders handshakeHeaders, 246 @Nullable StompHeaders connectHeaders, StompSessionHandler handler, Object... uriVariables) { 247 248 Assert.notNull(url, "'url' must not be null"); 249 URI uri = UriComponentsBuilder.fromUriString(url).buildAndExpand(uriVariables).encode().toUri(); 250 return connect(uri, handshakeHeaders, connectHeaders, handler); 251 } 252 253 /** 254 * An overloaded version of 255 * {@link #connect(String, WebSocketHttpHeaders, StompSessionHandler, Object...)} 256 * that accepts a fully prepared {@link java.net.URI}. 257 * @param url the url to connect to 258 * @param handshakeHeaders the headers for the WebSocket handshake 259 * @param connectHeaders headers for the STOMP CONNECT frame 260 * @param sessionHandler the STOMP session handler 261 * @return a ListenableFuture for access to the session when ready for use 262 */ 263 public ListenableFuture<StompSession> connect(URI url, @Nullable WebSocketHttpHeaders handshakeHeaders, 264 @Nullable StompHeaders connectHeaders, StompSessionHandler sessionHandler) { 265 266 Assert.notNull(url, "'url' must not be null"); 267 ConnectionHandlingStompSession session = createSession(connectHeaders, sessionHandler); 268 WebSocketTcpConnectionHandlerAdapter adapter = new WebSocketTcpConnectionHandlerAdapter(session); 269 getWebSocketClient() 270 .doHandshake(new LoggingWebSocketHandlerDecorator(adapter), handshakeHeaders, url) 271 .addCallback(adapter); 272 return session.getSessionFuture(); 273 } 274 275 @Override 276 protected StompHeaders processConnectHeaders(@Nullable StompHeaders connectHeaders) { 277 connectHeaders = super.processConnectHeaders(connectHeaders); 278 if (connectHeaders.isHeartbeatEnabled()) { 279 Assert.state(getTaskScheduler() != null, "TaskScheduler must be set if heartbeats are enabled"); 280 } 281 return connectHeaders; 282 } 283 284 285 /** 286 * Adapt WebSocket to the TcpConnectionHandler and TcpConnection contracts. 287 */ 288 private class WebSocketTcpConnectionHandlerAdapter implements ListenableFutureCallback<WebSocketSession>, 289 WebSocketHandler, TcpConnection<byte[]> { 290 291 private final TcpConnectionHandler<byte[]> connectionHandler; 292 293 private final StompWebSocketMessageCodec codec = new StompWebSocketMessageCodec(getInboundMessageSizeLimit()); 294 295 @Nullable 296 private volatile WebSocketSession session; 297 298 private volatile long lastReadTime = -1; 299 300 private volatile long lastWriteTime = -1; 301 302 private final List<ScheduledFuture<?>> inactivityTasks = new ArrayList<>(2); 303 304 public WebSocketTcpConnectionHandlerAdapter(TcpConnectionHandler<byte[]> connectionHandler) { 305 Assert.notNull(connectionHandler, "TcpConnectionHandler must not be null"); 306 this.connectionHandler = connectionHandler; 307 } 308 309 // ListenableFutureCallback implementation: handshake outcome 310 311 @Override 312 public void onSuccess(@Nullable WebSocketSession webSocketSession) { 313 } 314 315 @Override 316 public void onFailure(Throwable ex) { 317 this.connectionHandler.afterConnectFailure(ex); 318 } 319 320 // WebSocketHandler implementation 321 322 @Override 323 public void afterConnectionEstablished(WebSocketSession session) { 324 this.session = session; 325 this.connectionHandler.afterConnected(this); 326 } 327 328 @Override 329 public void handleMessage(WebSocketSession session, WebSocketMessage<?> webSocketMessage) { 330 this.lastReadTime = (this.lastReadTime != -1 ? System.currentTimeMillis() : -1); 331 List<Message<byte[]>> messages; 332 try { 333 messages = this.codec.decode(webSocketMessage); 334 } 335 catch (Throwable ex) { 336 this.connectionHandler.handleFailure(ex); 337 return; 338 } 339 for (Message<byte[]> message : messages) { 340 this.connectionHandler.handleMessage(message); 341 } 342 } 343 344 @Override 345 public void handleTransportError(WebSocketSession session, Throwable ex) throws Exception { 346 this.connectionHandler.handleFailure(ex); 347 } 348 349 @Override 350 public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception { 351 cancelInactivityTasks(); 352 this.connectionHandler.afterConnectionClosed(); 353 } 354 355 private void cancelInactivityTasks() { 356 for (ScheduledFuture<?> task : this.inactivityTasks) { 357 try { 358 task.cancel(true); 359 } 360 catch (Throwable ex) { 361 // Ignore 362 } 363 } 364 this.lastReadTime = -1; 365 this.lastWriteTime = -1; 366 this.inactivityTasks.clear(); 367 } 368 369 @Override 370 public boolean supportsPartialMessages() { 371 return false; 372 } 373 374 // TcpConnection implementation 375 376 @Override 377 public ListenableFuture<Void> send(Message<byte[]> message) { 378 updateLastWriteTime(); 379 SettableListenableFuture<Void> future = new SettableListenableFuture<>(); 380 try { 381 WebSocketSession session = this.session; 382 Assert.state(session != null, "No WebSocketSession available"); 383 session.sendMessage(this.codec.encode(message, session.getClass())); 384 future.set(null); 385 } 386 catch (Throwable ex) { 387 future.setException(ex); 388 } 389 finally { 390 updateLastWriteTime(); 391 } 392 return future; 393 } 394 395 private void updateLastWriteTime() { 396 long lastWriteTime = this.lastWriteTime; 397 if (lastWriteTime != -1) { 398 this.lastWriteTime = System.currentTimeMillis(); 399 } 400 } 401 402 @Override 403 public void onReadInactivity(final Runnable runnable, final long duration) { 404 Assert.state(getTaskScheduler() != null, "No TaskScheduler configured"); 405 this.lastReadTime = System.currentTimeMillis(); 406 this.inactivityTasks.add(getTaskScheduler().scheduleWithFixedDelay(() -> { 407 if (System.currentTimeMillis() - this.lastReadTime > duration) { 408 try { 409 runnable.run(); 410 } 411 catch (Throwable ex) { 412 if (logger.isDebugEnabled()) { 413 logger.debug("ReadInactivityTask failure", ex); 414 } 415 } 416 } 417 }, duration / 2)); 418 } 419 420 @Override 421 public void onWriteInactivity(final Runnable runnable, final long duration) { 422 Assert.state(getTaskScheduler() != null, "No TaskScheduler configured"); 423 this.lastWriteTime = System.currentTimeMillis(); 424 this.inactivityTasks.add(getTaskScheduler().scheduleWithFixedDelay(() -> { 425 if (System.currentTimeMillis() - this.lastWriteTime > duration) { 426 try { 427 runnable.run(); 428 } 429 catch (Throwable ex) { 430 if (logger.isDebugEnabled()) { 431 logger.debug("WriteInactivityTask failure", ex); 432 } 433 } 434 } 435 }, duration / 2)); 436 } 437 438 @Override 439 public void close() { 440 WebSocketSession session = this.session; 441 if (session != null) { 442 try { 443 session.close(); 444 } 445 catch (IOException ex) { 446 if (logger.isDebugEnabled()) { 447 logger.debug("Failed to close session: " + session.getId(), ex); 448 } 449 } 450 } 451 } 452 } 453 454 455 /** 456 * Encode and decode STOMP WebSocket messages. 457 */ 458 private static class StompWebSocketMessageCodec { 459 460 private static final StompEncoder ENCODER = new StompEncoder(); 461 462 private static final StompDecoder DECODER = new StompDecoder(); 463 464 private final BufferingStompDecoder bufferingDecoder; 465 466 public StompWebSocketMessageCodec(int messageSizeLimit) { 467 this.bufferingDecoder = new BufferingStompDecoder(DECODER, messageSizeLimit); 468 } 469 470 public List<Message<byte[]>> decode(WebSocketMessage<?> webSocketMessage) { 471 List<Message<byte[]>> result = Collections.emptyList(); 472 ByteBuffer byteBuffer; 473 if (webSocketMessage instanceof TextMessage) { 474 byteBuffer = ByteBuffer.wrap(((TextMessage) webSocketMessage).asBytes()); 475 } 476 else if (webSocketMessage instanceof BinaryMessage) { 477 byteBuffer = ((BinaryMessage) webSocketMessage).getPayload(); 478 } 479 else { 480 return result; 481 } 482 result = this.bufferingDecoder.decode(byteBuffer); 483 if (result.isEmpty()) { 484 if (logger.isTraceEnabled()) { 485 logger.trace("Incomplete STOMP frame content received, bufferSize=" + 486 this.bufferingDecoder.getBufferSize() + ", bufferSizeLimit=" + 487 this.bufferingDecoder.getBufferSizeLimit() + "."); 488 } 489 } 490 return result; 491 } 492 493 public WebSocketMessage<?> encode(Message<byte[]> message, Class<? extends WebSocketSession> sessionType) { 494 StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class); 495 Assert.notNull(accessor, "No StompHeaderAccessor available"); 496 byte[] payload = message.getPayload(); 497 byte[] bytes = ENCODER.encode(accessor.getMessageHeaders(), payload); 498 499 boolean useBinary = (payload.length > 0 && 500 !(SockJsSession.class.isAssignableFrom(sessionType)) && 501 MimeTypeUtils.APPLICATION_OCTET_STREAM.isCompatibleWith(accessor.getContentType())); 502 503 return (useBinary ? new BinaryMessage(bytes) : new TextMessage(bytes)); 504 } 505 } 506 507}