001/* 002 * Copyright 2002-2020 the original author or authors. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * https://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.springframework.messaging.simp.stomp; 018 019import java.lang.reflect.Type; 020import java.util.ArrayList; 021import java.util.Collections; 022import java.util.Date; 023import java.util.List; 024import java.util.Map; 025import java.util.concurrent.ConcurrentHashMap; 026import java.util.concurrent.ExecutionException; 027import java.util.concurrent.ScheduledFuture; 028import java.util.concurrent.atomic.AtomicInteger; 029 030import org.apache.commons.logging.Log; 031import org.apache.commons.logging.LogFactory; 032 033import org.springframework.core.ResolvableType; 034import org.springframework.messaging.Message; 035import org.springframework.messaging.MessageDeliveryException; 036import org.springframework.messaging.converter.MessageConversionException; 037import org.springframework.messaging.converter.MessageConverter; 038import org.springframework.messaging.converter.SimpleMessageConverter; 039import org.springframework.messaging.support.MessageBuilder; 040import org.springframework.messaging.support.MessageHeaderAccessor; 041import org.springframework.messaging.tcp.TcpConnection; 042import org.springframework.scheduling.TaskScheduler; 043import org.springframework.util.AlternativeJdkIdGenerator; 044import org.springframework.util.Assert; 045import org.springframework.util.IdGenerator; 046import 
org.springframework.util.StringUtils;
import org.springframework.util.concurrent.ListenableFuture;
import org.springframework.util.concurrent.ListenableFutureCallback;
import org.springframework.util.concurrent.SettableListenableFuture;

/**
 * Default implementation of {@link ConnectionHandlingStompSession}.
 *
 * @author Rossen Stoyanchev
 * @since 4.2
 */
public class DefaultStompSession implements ConnectionHandlingStompSession {

	private static final Log logger = LogFactory.getLog(DefaultStompSession.class);

	private static final IdGenerator idGenerator = new AlternativeJdkIdGenerator();

	/** Zero-length payload used for frames sent without a body. */
	public static final byte[] EMPTY_PAYLOAD = new byte[0];

	/* STOMP spec: receiver SHOULD take into account an error margin */
	private static final long HEARTBEAT_MULTIPLIER = 3;

	/** Shared heartbeat message, sent by {@link WriteInactivityTask}. */
	private static final Message<byte[]> HEARTBEAT;

	static {
		StompHeaderAccessor accessor = StompHeaderAccessor.createForHeartbeat();
		HEARTBEAT = MessageBuilder.createMessage(StompDecoder.HEARTBEAT_PAYLOAD, accessor.getMessageHeaders());
	}


	private final String sessionId;

	private final StompSessionHandler sessionHandler;

	private final StompHeaders connectHeaders;

	private final SettableListenableFuture<StompSession> sessionFuture =
			new SettableListenableFuture<StompSession>();

	private MessageConverter converter = new SimpleMessageConverter();

	private TaskScheduler taskScheduler;

	private long receiptTimeLimit = 15 * 1000;

	private volatile boolean autoReceiptEnabled;


	private volatile TcpConnection<byte[]> connection;

	private volatile String version;

	private final AtomicInteger subscriptionIndex = new AtomicInteger();

	private final Map<String, DefaultSubscription> subscriptions =
			new ConcurrentHashMap<String, DefaultSubscription>(4);

	private final AtomicInteger receiptIndex = new AtomicInteger();

	private final Map<String, ReceiptHandler> receiptHandlers =
			new ConcurrentHashMap<String, ReceiptHandler>(4);

	/* Whether the client is willfully closing the connection */
	private volatile boolean closing = false;


	/**
	 * Create a new session.
	 * @param sessionHandler the application handler for the session
	 * @param connectHeaders headers for the STOMP CONNECT frame
	 */
	public DefaultStompSession(StompSessionHandler sessionHandler, StompHeaders connectHeaders) {
		Assert.notNull(sessionHandler, "StompSessionHandler must not be null");
		Assert.notNull(connectHeaders, "StompHeaders must not be null");
		this.sessionId = idGenerator.generateId().toString();
		this.sessionHandler = sessionHandler;
		this.connectHeaders = connectHeaders;
	}


	@Override
	public String getSessionId() {
		return this.sessionId;
	}

	/**
	 * Return the configured session handler.
	 */
	public StompSessionHandler getSessionHandler() {
		return this.sessionHandler;
	}

	@Override
	public ListenableFuture<StompSession> getSessionFuture() {
		return this.sessionFuture;
	}

	/**
	 * Set the {@link MessageConverter} to use to convert the payload of incoming
	 * and outgoing messages to and from {@code byte[]} based on object type, or
	 * expected object type, and the "content-type" header.
	 * <p>By default, {@link SimpleMessageConverter} is configured.
	 * @param messageConverter the message converter to use
	 */
	public void setMessageConverter(MessageConverter messageConverter) {
		Assert.notNull(messageConverter, "MessageConverter must not be null");
		this.converter = messageConverter;
	}

	/**
	 * Return the configured {@link MessageConverter}.
	 */
	public MessageConverter getMessageConverter() {
		return this.converter;
	}

	/**
	 * Configure the TaskScheduler to use for receipt tracking.
	 */
	public void setTaskScheduler(TaskScheduler taskScheduler) {
		this.taskScheduler = taskScheduler;
	}

	/**
	 * Return the configured TaskScheduler to use for receipt tracking.
	 */
	public TaskScheduler getTaskScheduler() {
		return this.taskScheduler;
	}

	/**
	 * Configure the time in milliseconds before a receipt expires.
	 * <p>By default set to 15,000 (15 seconds).
	 */
	public void setReceiptTimeLimit(long receiptTimeLimit) {
		Assert.isTrue(receiptTimeLimit > 0, "Receipt time limit must be larger than zero");
		this.receiptTimeLimit = receiptTimeLimit;
	}

	/**
	 * Return the configured time limit before a receipt expires.
	 */
	public long getReceiptTimeLimit() {
		return this.receiptTimeLimit;
	}

	@Override
	public void setAutoReceipt(boolean autoReceiptEnabled) {
		this.autoReceiptEnabled = autoReceiptEnabled;
	}

	/**
	 * Whether receipt headers should be automatically added.
	 */
	public boolean isAutoReceiptEnabled() {
		return this.autoReceiptEnabled;
	}


	@Override
	public boolean isConnected() {
		return (this.connection != null);
	}

	@Override
	public Receiptable send(String destination, Object payload) {
		StompHeaders stompHeaders = new StompHeaders();
		stompHeaders.setDestination(destination);
		return send(stompHeaders, payload);
	}

	@Override
	public Receiptable send(StompHeaders stompHeaders, Object payload) {
		Assert.hasText(stompHeaders.getDestination(), "Destination header is required");

		String receiptId = checkOrAddReceipt(stompHeaders);
		Receiptable receiptable = new ReceiptHandler(receiptId);

		StompHeaderAccessor accessor = createHeaderAccessor(StompCommand.SEND);
		accessor.addNativeHeaders(stompHeaders);
		Message<byte[]> message = createMessage(accessor, payload);
		execute(message);

		return receiptable;
	}

	/**
	 * Return the receipt id from the given headers. If none is present and
	 * auto-receipt is enabled, generate one from a counter and set it on the
	 * headers.
	 * @return the receipt id, or {@code null} if none is present and
	 * auto-receipt is disabled
	 */
	private String checkOrAddReceipt(StompHeaders stompHeaders) {
		String receiptId = stompHeaders.getReceipt();
		if (isAutoReceiptEnabled() && receiptId == null) {
			receiptId = String.valueOf(DefaultStompSession.this.receiptIndex.getAndIncrement());
			stompHeaders.setReceipt(receiptId);
		}
		return receiptId;
	}

	/**
	 * Create a mutable header accessor for the given command, bound to this session id.
	 */
	private StompHeaderAccessor createHeaderAccessor(StompCommand command) {
		StompHeaderAccessor accessor = StompHeaderAccessor.create(command);
		accessor.setSessionId(this.sessionId);
		accessor.setLeaveMutable(true);
		return accessor;
	}

	/**
	 * Build the outgoing message.
	 * <p>A {@code null} payload becomes {@link #EMPTY_PAYLOAD}. A {@code byte[]}
	 * payload is used as-is. Any other payload is passed to the configured
	 * {@link MessageConverter}.
	 * @throws MessageConversionException if the converter returns {@code null}
	 */
	@SuppressWarnings("unchecked")
	private Message<byte[]> createMessage(StompHeaderAccessor accessor, Object payload) {
		accessor.updateSimpMessageHeadersFromStompHeaders();
		Message<byte[]> message;
		if (payload == null) {
			message = MessageBuilder.createMessage(EMPTY_PAYLOAD, accessor.getMessageHeaders());
		}
		else if (payload instanceof byte[]) {
			message = MessageBuilder.createMessage((byte[]) payload, accessor.getMessageHeaders());
		}
		else {
			message = (Message<byte[]>) getMessageConverter().toMessage(payload, accessor.getMessageHeaders());
			accessor.updateStompHeadersFromSimpMessageHeaders();
			if (message == null) {
				throw new MessageConversionException("Unable to convert payload with type='" +
						payload.getClass().getName() + "', contentType='" + accessor.getContentType() +
						"', converter=[" + getMessageConverter() + "]");
			}
		}
		return message;
	}

	/**
	 * Send the given message over the current connection and block until the
	 * send future completes.
	 * @throws IllegalStateException if there is no current connection
	 * @throws MessageDeliveryException if the send fails, or if the calling
	 * thread is interrupted while waiting; in the latter case the thread's
	 * interrupt status is restored before throwing
	 */
	private void execute(Message<byte[]> message) {
		if (logger.isTraceEnabled()) {
			StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
			// Guard the lookup: a missing accessor must not fail the send just because tracing is on
			if (accessor != null) {
				logger.trace("Sending " + accessor.getDetailedLogMessage(message.getPayload()));
			}
		}
		TcpConnection<byte[]> conn = this.connection;
		Assert.state(conn != null, "Connection closed");
		try {
			conn.send(message).get();
		}
		catch (ExecutionException ex) {
			throw new MessageDeliveryException(message, ex.getCause());
		}
		catch (InterruptedException ex) {
			// Restore the interrupt flag so code further up the stack can still observe it
			Thread.currentThread().interrupt();
			throw new MessageDeliveryException(message, ex);
		}
		catch (Throwable ex) {
			throw new MessageDeliveryException(message, ex);
		}
	}

	@Override
	public Subscription subscribe(String destination, StompFrameHandler handler) {
		StompHeaders stompHeaders = new StompHeaders();
		stompHeaders.setDestination(destination);
		return subscribe(stompHeaders, handler);
	}

	@Override
	public Subscription subscribe(StompHeaders stompHeaders, StompFrameHandler handler) {
		String destination = stompHeaders.getDestination();
		Assert.hasText(destination, "Destination header is required");
		Assert.notNull(handler, "StompFrameHandler must not be null");

		String subscriptionId = stompHeaders.getId();
		if (!StringUtils.hasText(subscriptionId)) {
			subscriptionId = String.valueOf(DefaultStompSession.this.subscriptionIndex.getAndIncrement());
			stompHeaders.setId(subscriptionId);
		}
		String receiptId = checkOrAddReceipt(stompHeaders);
		// The DefaultSubscription constructor registers itself in 'subscriptions'
		Subscription subscription = new DefaultSubscription(subscriptionId, destination, receiptId, handler);

		StompHeaderAccessor accessor = createHeaderAccessor(StompCommand.SUBSCRIBE);
		accessor.addNativeHeaders(stompHeaders);
		Message<byte[]> message = createMessage(accessor, EMPTY_PAYLOAD);
		execute(message);

		return subscription;
	}

	@Override
	public Receiptable acknowledge(String messageId, boolean consumed) {
		StompHeaders stompHeaders = new StompHeaders();
		// The header carrying the acknowledged message depends on the negotiated version
		if ("1.1".equals(this.version)) {
			stompHeaders.setMessageId(messageId);
		}
		else {
			stompHeaders.setId(messageId);
		}

		String receiptId = checkOrAddReceipt(stompHeaders);
		Receiptable receiptable = new ReceiptHandler(receiptId);

		StompCommand command = (consumed ? StompCommand.ACK : StompCommand.NACK);
		StompHeaderAccessor accessor = createHeaderAccessor(command);
		accessor.addNativeHeaders(stompHeaders);
		Message<byte[]> message = createMessage(accessor, null);
		execute(message);

		return receiptable;
	}

	/**
	 * Send an UNSUBSCRIBE frame for the given subscription id.
	 */
	private void unsubscribe(String id) {
		StompHeaderAccessor accessor = createHeaderAccessor(StompCommand.UNSUBSCRIBE);
		accessor.setSubscriptionId(id);
		Message<byte[]> message = createMessage(accessor, EMPTY_PAYLOAD);
		execute(message);
	}

	@Override
	public void disconnect() {
		// Mark the close as intentional so afterConnectionClosed does not report a failure
		this.closing = true;
		try {
			StompHeaderAccessor accessor = createHeaderAccessor(StompCommand.DISCONNECT);
			Message<byte[]> message = createMessage(accessor, EMPTY_PAYLOAD);
			execute(message);
		}
		finally {
			resetConnection();
		}
	}


	// TcpConnectionHandler

	@Override
	public void afterConnected(TcpConnection<byte[]> connection) {
		this.connection = connection;
		if (logger.isDebugEnabled()) {
			logger.debug("Connection established in session id=" + this.sessionId);
		}
		StompHeaderAccessor accessor = createHeaderAccessor(StompCommand.CONNECT);
		accessor.addNativeHeaders(this.connectHeaders);
		accessor.setAcceptVersion("1.1,1.2");
		Message<byte[]> message = createMessage(accessor, EMPTY_PAYLOAD);
		execute(message);
	}

	@Override
	public void afterConnectFailure(Throwable ex) {
		if (logger.isDebugEnabled()) {
			logger.debug("Failed to connect session id=" + this.sessionId, ex);
		}
		this.sessionFuture.setException(ex);
		this.sessionHandler.handleTransportError(this, ex);
	}

	/**
	 * Dispatch an incoming frame.
	 * <p>MESSAGE frames go to the handler of the matching subscription.
	 * RECEIPT frames complete the matching receipt handler. CONNECTED frames
	 * set up heartbeats, record the negotiated version and complete the session
	 * future. ERROR frames go to the session handler. Any exception raised
	 * during dispatch is passed to {@link StompSessionHandler#handleException}.
	 * @throws IllegalStateException if the message has no {@link StompHeaderAccessor}
	 */
	@Override
	public void handleMessage(Message<byte[]> message) {
		StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
		Assert.state(accessor != null, "No StompHeaderAccessor");
		accessor.setSessionId(this.sessionId);
		StompCommand command = accessor.getCommand();
		Map<String, List<String>> nativeHeaders = accessor.getNativeHeaders();
		StompHeaders stompHeaders = StompHeaders.readOnlyStompHeaders(nativeHeaders);
		boolean isHeartbeat = accessor.isHeartbeat();
		if (logger.isTraceEnabled()) {
			logger.trace("Received " + accessor.getDetailedLogMessage(message.getPayload()));
		}
		try {
			if (StompCommand.MESSAGE.equals(command)) {
				DefaultSubscription subscription = this.subscriptions.get(stompHeaders.getSubscription());
				if (subscription != null) {
					invokeHandler(subscription.getHandler(), message, stompHeaders);
				}
				else if (logger.isDebugEnabled()) {
					logger.debug("No handler for: " + accessor.getDetailedLogMessage(message.getPayload()) +
							". Perhaps just unsubscribed?");
				}
			}
			else {
				if (StompCommand.RECEIPT.equals(command)) {
					String receiptId = stompHeaders.getReceiptId();
					ReceiptHandler handler = this.receiptHandlers.get(receiptId);
					if (handler != null) {
						handler.handleReceiptReceived();
					}
					else if (logger.isDebugEnabled()) {
						logger.debug("No matching receipt: " + accessor.getDetailedLogMessage(message.getPayload()));
					}
				}
				else if (StompCommand.CONNECTED.equals(command)) {
					initHeartbeatTasks(stompHeaders);
					this.version = stompHeaders.getFirst("version");
					this.sessionFuture.set(this);
					this.sessionHandler.afterConnected(this, stompHeaders);
				}
				else if (StompCommand.ERROR.equals(command)) {
					invokeHandler(this.sessionHandler, message, stompHeaders);
				}
				else if (!isHeartbeat && logger.isTraceEnabled()) {
					logger.trace("Message not handled.");
				}
			}
		}
		catch (Throwable ex) {
			this.sessionHandler.handleException(this, command, stompHeaders, message.getPayload(), ex);
		}
	}

	/**
	 * Convert the payload to the type reported by the handler, then invoke it.
	 * An empty payload is passed as {@code null} without conversion.
	 * @throws MessageConversionException if the converter returns {@code null}
	 */
	private void invokeHandler(StompFrameHandler handler, Message<byte[]> message, StompHeaders stompHeaders) {
		if (message.getPayload().length == 0) {
			handler.handleFrame(stompHeaders, null);
			return;
		}
		Type type = handler.getPayloadType(stompHeaders);
		Class<?> payloadType = ResolvableType.forType(type).resolve();
		Object object = getMessageConverter().fromMessage(message, payloadType);
		if (object == null) {
			throw new MessageConversionException("No suitable converter, payloadType=" + payloadType +
					", handlerType=" + handler.getClass());
		}
		handler.handleFrame(stompHeaders, object);
	}

	/**
	 * Register write/read inactivity tasks on the connection from the heartbeat
	 * values in the CONNECT headers and the given CONNECTED headers.
	 * <p>{@code connect[0]}/{@code connected[1]} drive the outgoing heartbeat
	 * interval. {@code connect[1]}/{@code connected[0]} drive the read
	 * inactivity timeout, which is scaled by {@link #HEARTBEAT_MULTIPLIER}.
	 * Nothing is registered if either side did not supply heartbeat values.
	 * @throws IllegalStateException if a task has to be registered but the
	 * connection has already been reset
	 */
	private void initHeartbeatTasks(StompHeaders connectedHeaders) {
		long[] connect = this.connectHeaders.getHeartbeat();
		long[] connected = connectedHeaders.getHeartbeat();
		if (connect == null || connected == null) {
			return;
		}
		// Read the volatile field once: resetConnection() may null it concurrently
		TcpConnection<byte[]> conn = this.connection;
		if (connect[0] > 0 && connected[1] > 0) {
			long interval = Math.max(connect[0], connected[1]);
			Assert.state(conn != null, "No TcpConnection available");
			conn.onWriteInactivity(new WriteInactivityTask(), interval);
		}
		if (connect[1] > 0 && connected[0] > 0) {
			long interval = Math.max(connect[1], connected[0]) * HEARTBEAT_MULTIPLIER;
			Assert.state(conn != null, "No TcpConnection available");
			conn.onReadInactivity(new ReadInactivityTask(), interval);
		}
	}

	@Override
	public void handleFailure(Throwable ex) {
		try {
			this.sessionFuture.setException(ex);  // no-op if already set
			this.sessionHandler.handleTransportError(this, ex);
		}
		catch (Throwable ex2) {
			if (logger.isDebugEnabled()) {
				logger.debug("Uncaught failure while handling transport failure", ex2);
			}
		}
	}

	@Override
	public void afterConnectionClosed() {
		if (logger.isDebugEnabled()) {
			logger.debug("Connection closed in session id=" + this.sessionId);
		}
		// Only report a lost connection if the close was not initiated by this client
		if (!this.closing) {
			resetConnection();
			handleFailure(new ConnectionLostException("Connection closed"));
		}
	}

	/**
	 * Clear the current connection and close it, ignoring any error from close.
	 */
	private void resetConnection() {
		TcpConnection<?> conn = this.connection;
		this.connection = null;
		if (conn != null) {
			try {
				conn.close();
			}
			catch (Throwable ex) {
				// ignore: the connection is being discarded either way
			}
		}
	}


	/**
	 * Tracks a single receipt id.
	 * <p>When the id is non-null, the handler registers itself in
	 * {@code receiptHandlers} and schedules a timeout of
	 * {@link #getReceiptTimeLimit()} on the configured {@link TaskScheduler}.
	 * The first of "receipt received" or "timeout" wins. The matching
	 * callbacks run exactly once, and tasks added afterwards for the same
	 * outcome run immediately.
	 */
	private class ReceiptHandler implements Receiptable {

		private final String receiptId;

		private final List<Runnable> receiptCallbacks = new ArrayList<Runnable>(2);

		private final List<Runnable> receiptLostCallbacks = new ArrayList<Runnable>(2);

		private ScheduledFuture<?> future;

		/* null until decided; TRUE = receipt received, FALSE = receipt lost (guarded by 'this') */
		private Boolean result;

		public ReceiptHandler(String receiptId) {
			this.receiptId = receiptId;
			if (this.receiptId != null) {
				initReceiptHandling();
			}
		}

		private void initReceiptHandling() {
			Assert.notNull(getTaskScheduler(), "To track receipts, a TaskScheduler must be configured");
			DefaultStompSession.this.receiptHandlers.put(this.receiptId, this);
			Date startTime = new Date(System.currentTimeMillis() + getReceiptTimeLimit());
			this.future = getTaskScheduler().schedule(new Runnable() {
				@Override
				public void run() {
					handleReceiptNotReceived();
				}
			}, startTime);
		}

		@Override
		public String getReceiptId() {
			return this.receiptId;
		}

		@Override
		public void addReceiptTask(Runnable task) {
			addTask(task, true);
		}

		@Override
		public void addReceiptLostTask(Runnable task) {
			addTask(task, false);
		}

		private void addTask(Runnable task, boolean successTask) {
			Assert.notNull(this.receiptId,
					"To track receipts, set autoReceiptEnabled=true or add 'receiptId' header");
			synchronized (this) {
				if (this.result != null && this.result == successTask) {
					// Outcome already decided and matches: run right away
					invoke(Collections.singletonList(task));
				}
				else {
					if (successTask) {
						this.receiptCallbacks.add(task);
					}
					else {
						this.receiptLostCallbacks.add(task);
					}
				}
			}
		}

		/** Run each callback, ignoring any exception it throws. */
		private void invoke(List<Runnable> callbacks) {
			for (Runnable runnable : callbacks) {
				try {
					runnable.run();
				}
				catch (Throwable ex) {
					// ignore: one failing callback must not prevent the others
				}
			}
		}

		public void handleReceiptReceived() {
			handleInternal(true);
		}

		public void handleReceiptNotReceived() {
			handleInternal(false);
		}

		private void handleInternal(boolean result) {
			synchronized (this) {
				if (this.result != null) {
					return;
				}
				this.result = result;
				invoke(result ? this.receiptCallbacks : this.receiptLostCallbacks);
				DefaultStompSession.this.receiptHandlers.remove(this.receiptId);
				if (this.future != null) {
					this.future.cancel(true);
				}
			}
		}
	}


	/**
	 * A subscription that registers itself in {@code subscriptions} under its
	 * id on construction. It also tracks its SUBSCRIBE receipt through
	 * {@link ReceiptHandler}.
	 */
	private class DefaultSubscription extends ReceiptHandler implements Subscription {

		private final String id;

		private final String destination;

		private final StompFrameHandler handler;

		public DefaultSubscription(String id, String destination, String receiptId, StompFrameHandler handler) {
			super(receiptId);
			Assert.notNull(destination, "Destination must not be null");
			Assert.notNull(handler, "StompFrameHandler must not be null");
			this.id = id;
			this.destination = destination;
			this.handler = handler;
			DefaultStompSession.this.subscriptions.put(id, this);
		}

		@Override
		public String getSubscriptionId() {
			return this.id;
		}

		public String getDestination() {
			return this.destination;
		}

		public StompFrameHandler getHandler() {
			return this.handler;
		}

		@Override
		public void unsubscribe() {
			DefaultStompSession.this.subscriptions.remove(getSubscriptionId());
			DefaultStompSession.this.unsubscribe(getSubscriptionId());
		}

		@Override
		public String toString() {
			return "Subscription [id=" + getSubscriptionId() + ", destination='" + getDestination() +
					"', receiptId='" + getReceiptId() + "', handler=" + getHandler() + "]";
		}
	}


	/**
	 * Sends a heartbeat if a connection is present, and reports a failed send
	 * via {@link #handleFailure(Throwable)}.
	 */
	private class WriteInactivityTask implements Runnable {

		@Override
		public void run() {
			TcpConnection<byte[]> conn = connection;
			if (conn != null) {
				conn.send(HEARTBEAT).addCallback(
						new ListenableFutureCallback<Void>() {
							@Override
							public void onSuccess(Void result) {
							}
							@Override
							public void onFailure(Throwable ex) {
								handleFailure(ex);
							}
						});
			}
		}
	}


	/**
	 * Runs when the read inactivity timeout fires. It marks the close as
	 * intentional, closes the connection, and reports the failure to the
	 * session handler.
	 */
	private class ReadInactivityTask implements Runnable {

		@Override
		public void run() {
			closing = true;
			String error = "Server has gone quiet. Closing connection in session id=" + sessionId + ".";
			if (logger.isDebugEnabled()) {
				logger.debug(error);
			}
			resetConnection();
			handleFailure(new IllegalStateException(error));
		}
	}

}