001/* 002 * Copyright 2002-2018 the original author or authors. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * https://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.springframework.web.socket.messaging; 018 019import java.io.IOException; 020import java.net.URI; 021import java.nio.ByteBuffer; 022import java.util.ArrayList; 023import java.util.Collections; 024import java.util.List; 025import java.util.concurrent.ScheduledFuture; 026 027import org.apache.commons.logging.Log; 028import org.apache.commons.logging.LogFactory; 029 030import org.springframework.context.Lifecycle; 031import org.springframework.context.SmartLifecycle; 032import org.springframework.messaging.Message; 033import org.springframework.messaging.simp.stomp.BufferingStompDecoder; 034import org.springframework.messaging.simp.stomp.ConnectionHandlingStompSession; 035import org.springframework.messaging.simp.stomp.StompClientSupport; 036import org.springframework.messaging.simp.stomp.StompDecoder; 037import org.springframework.messaging.simp.stomp.StompEncoder; 038import org.springframework.messaging.simp.stomp.StompHeaderAccessor; 039import org.springframework.messaging.simp.stomp.StompHeaders; 040import org.springframework.messaging.simp.stomp.StompSession; 041import org.springframework.messaging.simp.stomp.StompSessionHandler; 042import org.springframework.messaging.support.MessageHeaderAccessor; 043import org.springframework.messaging.tcp.TcpConnection; 044import org.springframework.messaging.tcp.TcpConnectionHandler; 045import org.springframework.scheduling.TaskScheduler; 046import org.springframework.util.Assert; 047import org.springframework.util.MimeTypeUtils; 048import org.springframework.util.concurrent.ListenableFuture; 049import org.springframework.util.concurrent.ListenableFutureCallback; 050import org.springframework.util.concurrent.SettableListenableFuture; 051import org.springframework.web.socket.BinaryMessage; 052import org.springframework.web.socket.CloseStatus; 053import org.springframework.web.socket.TextMessage; 054import org.springframework.web.socket.WebSocketHandler; 055import org.springframework.web.socket.WebSocketHttpHeaders; 056import org.springframework.web.socket.WebSocketMessage; 057import org.springframework.web.socket.WebSocketSession; 058import org.springframework.web.socket.client.WebSocketClient; 059import org.springframework.web.socket.sockjs.transport.SockJsSession; 060import org.springframework.web.util.UriComponentsBuilder; 061 062/** 063 * A STOMP over WebSocket client that connects using an implementation of 064 * {@link org.springframework.web.socket.client.WebSocketClient WebSocketClient} 065 * including {@link org.springframework.web.socket.sockjs.client.SockJsClient 066 * SockJsClient}. 067 * 068 * @author Rossen Stoyanchev 069 * @since 4.2 070 */ 071public class WebSocketStompClient extends StompClientSupport implements SmartLifecycle { 072 073 private static final Log logger = LogFactory.getLog(WebSocketStompClient.class); 074 075 private final WebSocketClient webSocketClient; 076 077 private int inboundMessageSizeLimit = 64 * 1024; 078 079 private boolean autoStartup = true; 080 081 private int phase = Integer.MAX_VALUE; 082 083 private volatile boolean running = false; 084 085 086 /** 087 * Class constructor. Sets {@link #setDefaultHeartbeat} to "0,0" but will 088 * reset it back to the preferred "10000,10000" when a 089 * {@link #setTaskScheduler} is configured. 090 * @param webSocketClient the WebSocket client to connect with 091 */ 092 public WebSocketStompClient(WebSocketClient webSocketClient) { 093 Assert.notNull(webSocketClient, "WebSocketClient is required"); 094 this.webSocketClient = webSocketClient; 095 setDefaultHeartbeat(new long[] {0, 0}); 096 } 097 098 099 /** 100 * Return the configured WebSocketClient. 101 */ 102 public WebSocketClient getWebSocketClient() { 103 return this.webSocketClient; 104 } 105 106 /** 107 * {@inheritDoc} 108 * <p>Also automatically sets the {@link #setDefaultHeartbeat defaultHeartbeat} 109 * property to "10000,10000" if it is currently set to "0,0". 110 */ 111 @Override 112 public void setTaskScheduler(TaskScheduler taskScheduler) { 113 if (taskScheduler != null && !isDefaultHeartbeatEnabled()) { 114 setDefaultHeartbeat(new long[] {10000, 10000}); 115 } 116 super.setTaskScheduler(taskScheduler); 117 } 118 119 /** 120 * Configure the maximum size allowed for inbound STOMP message. 121 * Since a STOMP message can be received in multiple WebSocket messages, 122 * buffering may be required and this property determines the maximum buffer 123 * size per message. 124 * <p>By default this is set to 64 * 1024 (64K). 125 */ 126 public void setInboundMessageSizeLimit(int inboundMessageSizeLimit) { 127 this.inboundMessageSizeLimit = inboundMessageSizeLimit; 128 } 129 130 /** 131 * Get the configured inbound message buffer size in bytes. 132 */ 133 public int getInboundMessageSizeLimit() { 134 return this.inboundMessageSizeLimit; 135 } 136 137 /** 138 * Set whether to auto-start the contained WebSocketClient when the Spring 139 * context has been refreshed. 140 * <p>Default is "true". 141 */ 142 public void setAutoStartup(boolean autoStartup) { 143 this.autoStartup = autoStartup; 144 } 145 146 /** 147 * Return the value for the 'autoStartup' property. If "true", this client 148 * will automatically start and stop the contained WebSocketClient. 149 */ 150 @Override 151 public boolean isAutoStartup() { 152 return this.autoStartup; 153 } 154 155 /** 156 * Specify the phase in which the WebSocket client should be started and 157 * subsequently closed. The startup order proceeds from lowest to highest, 158 * and the shutdown order is the reverse of that. 159 * <p>By default this is Integer.MAX_VALUE meaning that the WebSocket client 160 * is started as late as possible and stopped as soon as possible. 161 */ 162 public void setPhase(int phase) { 163 this.phase = phase; 164 } 165 166 /** 167 * Return the configured phase. 168 */ 169 @Override 170 public int getPhase() { 171 return this.phase; 172 } 173 174 175 @Override 176 public void start() { 177 if (!isRunning()) { 178 this.running = true; 179 if (getWebSocketClient() instanceof Lifecycle) { 180 ((Lifecycle) getWebSocketClient()).start(); 181 } 182 } 183 184 } 185 186 @Override 187 public void stop() { 188 if (isRunning()) { 189 this.running = false; 190 if (getWebSocketClient() instanceof Lifecycle) { 191 ((Lifecycle) getWebSocketClient()).stop(); 192 } 193 } 194 } 195 196 @Override 197 public void stop(Runnable callback) { 198 stop(); 199 callback.run(); 200 } 201 202 @Override 203 public boolean isRunning() { 204 return this.running; 205 } 206 207 208 /** 209 * Connect to the given WebSocket URL and notify the given 210 * {@link org.springframework.messaging.simp.stomp.StompSessionHandler} 211 * when connected on the STOMP level after the CONNECTED frame is received. 212 * @param url the url to connect to 213 * @param handler the session handler 214 * @param uriVars the URI variables to expand into the URL 215 * @return a ListenableFuture for access to the session when ready for use 216 */ 217 public ListenableFuture<StompSession> connect(String url, StompSessionHandler handler, Object... uriVars) { 218 return connect(url, null, handler, uriVars); 219 } 220 221 /** 222 * An overloaded version of 223 * {@link #connect(String, StompSessionHandler, Object...)} that also 224 * accepts {@link WebSocketHttpHeaders} to use for the WebSocket handshake. 225 * @param url the url to connect to 226 * @param handshakeHeaders the headers for the WebSocket handshake 227 * @param handler the session handler 228 * @param uriVariables the URI variables to expand into the URL 229 * @return a ListenableFuture for access to the session when ready for use 230 */ 231 public ListenableFuture<StompSession> connect(String url, WebSocketHttpHeaders handshakeHeaders, 232 StompSessionHandler handler, Object... uriVariables) { 233 234 return connect(url, handshakeHeaders, null, handler, uriVariables); 235 } 236 237 /** 238 * An overloaded version of 239 * {@link #connect(String, StompSessionHandler, Object...)} that also accepts 240 * {@link WebSocketHttpHeaders} to use for the WebSocket handshake and 241 * {@link StompHeaders} for the STOMP CONNECT frame. 242 * @param url the url to connect to 243 * @param handshakeHeaders headers for the WebSocket handshake 244 * @param connectHeaders headers for the STOMP CONNECT frame 245 * @param handler the session handler 246 * @param uriVariables the URI variables to expand into the URL 247 * @return a ListenableFuture for access to the session when ready for use 248 */ 249 public ListenableFuture<StompSession> connect(String url, WebSocketHttpHeaders handshakeHeaders, 250 StompHeaders connectHeaders, StompSessionHandler handler, Object... uriVariables) { 251 252 Assert.notNull(url, "'url' must not be null"); 253 URI uri = UriComponentsBuilder.fromUriString(url).buildAndExpand(uriVariables).encode().toUri(); 254 return connect(uri, handshakeHeaders, connectHeaders, handler); 255 } 256 257 /** 258 * An overloaded version of 259 * {@link #connect(String, WebSocketHttpHeaders, StompSessionHandler, Object...)} 260 * that accepts a fully prepared {@link java.net.URI}. 261 * @param url the url to connect to 262 * @param handshakeHeaders the headers for the WebSocket handshake 263 * @param connectHeaders headers for the STOMP CONNECT frame 264 * @param sessionHandler the STOMP session handler 265 * @return a ListenableFuture for access to the session when ready for use 266 */ 267 public ListenableFuture<StompSession> connect(URI url, WebSocketHttpHeaders handshakeHeaders, 268 StompHeaders connectHeaders, StompSessionHandler sessionHandler) { 269 270 Assert.notNull(url, "'url' must not be null"); 271 ConnectionHandlingStompSession session = createSession(connectHeaders, sessionHandler); 272 WebSocketTcpConnectionHandlerAdapter adapter = new WebSocketTcpConnectionHandlerAdapter(session); 273 getWebSocketClient().doHandshake(adapter, handshakeHeaders, url).addCallback(adapter); 274 return session.getSessionFuture(); 275 } 276 277 @Override 278 protected StompHeaders processConnectHeaders(StompHeaders connectHeaders) { 279 connectHeaders = super.processConnectHeaders(connectHeaders); 280 if (connectHeaders.isHeartbeatEnabled()) { 281 Assert.state(getTaskScheduler() != null, "TaskScheduler must be set if heartbeats are enabled"); 282 } 283 return connectHeaders; 284 } 285 286 287 /** 288 * Adapt WebSocket to the TcpConnectionHandler and TcpConnection contracts. 289 */ 290 private class WebSocketTcpConnectionHandlerAdapter implements ListenableFutureCallback<WebSocketSession>, 291 WebSocketHandler, TcpConnection<byte[]> { 292 293 private final TcpConnectionHandler<byte[]> connectionHandler; 294 295 private final StompWebSocketMessageCodec codec = new StompWebSocketMessageCodec(getInboundMessageSizeLimit()); 296 297 private volatile WebSocketSession session; 298 299 private volatile long lastReadTime = -1; 300 301 private volatile long lastWriteTime = -1; 302 303 private final List<ScheduledFuture<?>> inactivityTasks = new ArrayList<ScheduledFuture<?>>(2); 304 305 public WebSocketTcpConnectionHandlerAdapter(TcpConnectionHandler<byte[]> connectionHandler) { 306 Assert.notNull(connectionHandler, "TcpConnectionHandler must not be null"); 307 this.connectionHandler = connectionHandler; 308 } 309 310 // ListenableFutureCallback implementation: handshake outcome 311 312 @Override 313 public void onSuccess(WebSocketSession webSocketSession) { 314 } 315 316 @Override 317 public void onFailure(Throwable ex) { 318 this.connectionHandler.afterConnectFailure(ex); 319 } 320 321 // WebSocketHandler implementation 322 323 @Override 324 public void afterConnectionEstablished(WebSocketSession session) { 325 this.session = session; 326 this.connectionHandler.afterConnected(this); 327 } 328 329 @Override 330 public void handleMessage(WebSocketSession session, WebSocketMessage<?> webSocketMessage) { 331 this.lastReadTime = (this.lastReadTime != -1 ? System.currentTimeMillis() : -1); 332 List<Message<byte[]>> messages; 333 try { 334 messages = this.codec.decode(webSocketMessage); 335 } 336 catch (Throwable ex) { 337 this.connectionHandler.handleFailure(ex); 338 return; 339 } 340 for (Message<byte[]> message : messages) { 341 this.connectionHandler.handleMessage(message); 342 } 343 } 344 345 @Override 346 public void handleTransportError(WebSocketSession session, Throwable ex) throws Exception { 347 this.connectionHandler.handleFailure(ex); 348 } 349 350 @Override 351 public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception { 352 cancelInactivityTasks(); 353 this.connectionHandler.afterConnectionClosed(); 354 } 355 356 private void cancelInactivityTasks() { 357 for (ScheduledFuture<?> task : this.inactivityTasks) { 358 try { 359 task.cancel(true); 360 } 361 catch (Throwable ex) { 362 // Ignore 363 } 364 } 365 this.lastReadTime = -1; 366 this.lastWriteTime = -1; 367 this.inactivityTasks.clear(); 368 } 369 370 @Override 371 public boolean supportsPartialMessages() { 372 return false; 373 } 374 375 // TcpConnection implementation 376 377 @Override 378 public ListenableFuture<Void> send(Message<byte[]> message) { 379 updateLastWriteTime(); 380 SettableListenableFuture<Void> future = new SettableListenableFuture<Void>(); 381 try { 382 this.session.sendMessage(this.codec.encode(message, this.session.getClass())); 383 future.set(null); 384 } 385 catch (Throwable ex) { 386 future.setException(ex); 387 } 388 finally { 389 updateLastWriteTime(); 390 } 391 return future; 392 } 393 394 private void updateLastWriteTime() { 395 long lastWriteTime = this.lastWriteTime; 396 if (lastWriteTime != -1) { 397 this.lastWriteTime = System.currentTimeMillis(); 398 } 399 } 400 401 @Override 402 public void onReadInactivity(final Runnable runnable, final long duration) { 403 Assert.state(getTaskScheduler() != null, "No TaskScheduler configured"); 404 this.lastReadTime = System.currentTimeMillis(); 405 this.inactivityTasks.add(getTaskScheduler().scheduleWithFixedDelay(new Runnable() { 406 @Override 407 public void run() { 408 if (System.currentTimeMillis() - lastReadTime > duration) { 409 try { 410 runnable.run(); 411 } 412 catch (Throwable ex) { 413 if (logger.isDebugEnabled()) { 414 logger.debug("ReadInactivityTask failure", ex); 415 } 416 } 417 } 418 } 419 }, duration / 2)); 420 } 421 422 @Override 423 public void onWriteInactivity(final Runnable runnable, final long duration) { 424 Assert.state(getTaskScheduler() != null, "No TaskScheduler configured"); 425 this.lastWriteTime = System.currentTimeMillis(); 426 this.inactivityTasks.add(getTaskScheduler().scheduleWithFixedDelay(new Runnable() { 427 @Override 428 public void run() { 429 if (System.currentTimeMillis() - lastWriteTime > duration) { 430 try { 431 runnable.run(); 432 } 433 catch (Throwable ex) { 434 if (logger.isDebugEnabled()) { 435 logger.debug("WriteInactivityTask failure", ex); 436 } 437 } 438 } 439 } 440 }, duration / 2)); 441 } 442 443 @Override 444 public void close() { 445 try { 446 this.session.close(); 447 } 448 catch (IOException ex) { 449 if (logger.isDebugEnabled()) { 450 logger.debug("Failed to close session: " + this.session.getId(), ex); 451 } 452 } 453 } 454 } 455 456 457 /** 458 * Encode and decode STOMP WebSocket messages. 459 */ 460 private static class StompWebSocketMessageCodec { 461 462 private static final StompEncoder ENCODER = new StompEncoder(); 463 464 private static final StompDecoder DECODER = new StompDecoder(); 465 466 private final BufferingStompDecoder bufferingDecoder; 467 468 public StompWebSocketMessageCodec(int messageSizeLimit) { 469 this.bufferingDecoder = new BufferingStompDecoder(DECODER, messageSizeLimit); 470 } 471 472 public List<Message<byte[]>> decode(WebSocketMessage<?> webSocketMessage) { 473 List<Message<byte[]>> result = Collections.<Message<byte[]>>emptyList(); 474 ByteBuffer byteBuffer; 475 if (webSocketMessage instanceof TextMessage) { 476 byteBuffer = ByteBuffer.wrap(((TextMessage) webSocketMessage).asBytes()); 477 } 478 else if (webSocketMessage instanceof BinaryMessage) { 479 byteBuffer = ((BinaryMessage) webSocketMessage).getPayload(); 480 } 481 else { 482 return result; 483 } 484 result = this.bufferingDecoder.decode(byteBuffer); 485 if (result.isEmpty()) { 486 if (logger.isTraceEnabled()) { 487 logger.trace("Incomplete STOMP frame content received, bufferSize=" + 488 this.bufferingDecoder.getBufferSize() + ", bufferSizeLimit=" + 489 this.bufferingDecoder.getBufferSizeLimit() + "."); 490 } 491 } 492 return result; 493 } 494 495 public WebSocketMessage<?> encode(Message<byte[]> message, Class<? extends WebSocketSession> sessionType) { 496 StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class); 497 Assert.notNull(accessor, "No StompHeaderAccessor available"); 498 byte[] payload = message.getPayload(); 499 byte[] bytes = ENCODER.encode(accessor.getMessageHeaders(), payload); 500 501 boolean useBinary = (payload.length > 0 && 502 !(SockJsSession.class.isAssignableFrom(sessionType)) && 503 MimeTypeUtils.APPLICATION_OCTET_STREAM.isCompatibleWith(accessor.getContentType())); 504 505 return (useBinary ? new BinaryMessage(bytes) : new TextMessage(bytes)); 506 } 507 } 508 509}