001/* 002 * Copyright 2002-2020 the original author or authors. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * https://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.springframework.messaging.simp.stomp; 018 019import java.lang.reflect.Type; 020import java.util.ArrayList; 021import java.util.Collections; 022import java.util.Date; 023import java.util.List; 024import java.util.Map; 025import java.util.concurrent.ConcurrentHashMap; 026import java.util.concurrent.ExecutionException; 027import java.util.concurrent.ScheduledFuture; 028import java.util.concurrent.TimeUnit; 029import java.util.concurrent.atomic.AtomicInteger; 030 031import org.apache.commons.logging.Log; 032 033import org.springframework.core.ResolvableType; 034import org.springframework.lang.Nullable; 035import org.springframework.messaging.Message; 036import org.springframework.messaging.MessageDeliveryException; 037import org.springframework.messaging.converter.MessageConversionException; 038import org.springframework.messaging.converter.MessageConverter; 039import org.springframework.messaging.converter.SimpleMessageConverter; 040import org.springframework.messaging.simp.SimpLogging; 041import org.springframework.messaging.support.MessageBuilder; 042import org.springframework.messaging.support.MessageHeaderAccessor; 043import org.springframework.messaging.tcp.TcpConnection; 044import org.springframework.scheduling.TaskScheduler; 045import org.springframework.util.AlternativeJdkIdGenerator; 046import org.springframework.util.Assert; 047import org.springframework.util.IdGenerator; 048import org.springframework.util.StringUtils; 049import org.springframework.util.concurrent.ListenableFuture; 050import org.springframework.util.concurrent.ListenableFutureCallback; 051import org.springframework.util.concurrent.SettableListenableFuture; 052 053/** 054 * Default implementation of {@link ConnectionHandlingStompSession}. 055 * 056 * @author Rossen Stoyanchev 057 * @since 4.2 058 */ 059public class DefaultStompSession implements ConnectionHandlingStompSession { 060 061 private static final Log logger = SimpLogging.forLogName(DefaultStompSession.class); 062 063 private static final IdGenerator idGenerator = new AlternativeJdkIdGenerator(); 064 065 /** 066 * An empty payload. 067 */ 068 public static final byte[] EMPTY_PAYLOAD = new byte[0]; 069 070 /* STOMP spec: receiver SHOULD take into account an error margin */ 071 private static final long HEARTBEAT_MULTIPLIER = 3; 072 073 private static final Message<byte[]> HEARTBEAT; 074 075 static { 076 StompHeaderAccessor accessor = StompHeaderAccessor.createForHeartbeat(); 077 HEARTBEAT = MessageBuilder.createMessage(StompDecoder.HEARTBEAT_PAYLOAD, accessor.getMessageHeaders()); 078 } 079 080 081 private final String sessionId; 082 083 private final StompSessionHandler sessionHandler; 084 085 private final StompHeaders connectHeaders; 086 087 private final SettableListenableFuture<StompSession> sessionFuture = new SettableListenableFuture<>(); 088 089 private MessageConverter converter = new SimpleMessageConverter(); 090 091 @Nullable 092 private TaskScheduler taskScheduler; 093 094 private long receiptTimeLimit = TimeUnit.SECONDS.toMillis(15); 095 096 private volatile boolean autoReceiptEnabled; 097 098 099 @Nullable 100 private volatile TcpConnection<byte[]> connection; 101 102 @Nullable 103 private volatile String version; 104 105 private final AtomicInteger subscriptionIndex = new AtomicInteger(); 106 107 private final Map<String, DefaultSubscription> subscriptions = new ConcurrentHashMap<>(4); 108 109 private final AtomicInteger receiptIndex = new AtomicInteger(); 110 111 private final Map<String, ReceiptHandler> receiptHandlers = new ConcurrentHashMap<>(4); 112 113 /* Whether the client is willfully closing the connection */ 114 private volatile boolean closing; 115 116 117 /** 118 * Create a new session. 119 * @param sessionHandler the application handler for the session 120 * @param connectHeaders headers for the STOMP CONNECT frame 121 */ 122 public DefaultStompSession(StompSessionHandler sessionHandler, StompHeaders connectHeaders) { 123 Assert.notNull(sessionHandler, "StompSessionHandler must not be null"); 124 Assert.notNull(connectHeaders, "StompHeaders must not be null"); 125 this.sessionId = idGenerator.generateId().toString(); 126 this.sessionHandler = sessionHandler; 127 this.connectHeaders = connectHeaders; 128 } 129 130 131 @Override 132 public String getSessionId() { 133 return this.sessionId; 134 } 135 136 /** 137 * Return the configured session handler. 138 */ 139 public StompSessionHandler getSessionHandler() { 140 return this.sessionHandler; 141 } 142 143 @Override 144 public ListenableFuture<StompSession> getSessionFuture() { 145 return this.sessionFuture; 146 } 147 148 /** 149 * Set the {@link MessageConverter} to use to convert the payload of incoming 150 * and outgoing messages to and from {@code byte[]} based on object type, or 151 * expected object type, and the "content-type" header. 152 * <p>By default, {@link SimpleMessageConverter} is configured. 153 * @param messageConverter the message converter to use 154 */ 155 public void setMessageConverter(MessageConverter messageConverter) { 156 Assert.notNull(messageConverter, "MessageConverter must not be null"); 157 this.converter = messageConverter; 158 } 159 160 /** 161 * Return the configured {@link MessageConverter}. 162 */ 163 public MessageConverter getMessageConverter() { 164 return this.converter; 165 } 166 167 /** 168 * Configure the TaskScheduler to use for receipt tracking. 169 */ 170 public void setTaskScheduler(@Nullable TaskScheduler taskScheduler) { 171 this.taskScheduler = taskScheduler; 172 } 173 174 /** 175 * Return the configured TaskScheduler to use for receipt tracking. 176 */ 177 @Nullable 178 public TaskScheduler getTaskScheduler() { 179 return this.taskScheduler; 180 } 181 182 /** 183 * Configure the time in milliseconds before a receipt expires. 184 * <p>By default set to 15,000 (15 seconds). 185 */ 186 public void setReceiptTimeLimit(long receiptTimeLimit) { 187 Assert.isTrue(receiptTimeLimit > 0, "Receipt time limit must be larger than zero"); 188 this.receiptTimeLimit = receiptTimeLimit; 189 } 190 191 /** 192 * Return the configured time limit before a receipt expires. 193 */ 194 public long getReceiptTimeLimit() { 195 return this.receiptTimeLimit; 196 } 197 198 @Override 199 public void setAutoReceipt(boolean autoReceiptEnabled) { 200 this.autoReceiptEnabled = autoReceiptEnabled; 201 } 202 203 /** 204 * Whether receipt headers should be automatically added. 205 */ 206 public boolean isAutoReceiptEnabled() { 207 return this.autoReceiptEnabled; 208 } 209 210 211 @Override 212 public boolean isConnected() { 213 return (this.connection != null); 214 } 215 216 @Override 217 public Receiptable send(String destination, Object payload) { 218 StompHeaders headers = new StompHeaders(); 219 headers.setDestination(destination); 220 return send(headers, payload); 221 } 222 223 @Override 224 public Receiptable send(StompHeaders headers, Object payload) { 225 Assert.hasText(headers.getDestination(), "Destination header is required"); 226 227 String receiptId = checkOrAddReceipt(headers); 228 Receiptable receiptable = new ReceiptHandler(receiptId); 229 230 StompHeaderAccessor accessor = createHeaderAccessor(StompCommand.SEND); 231 accessor.addNativeHeaders(headers); 232 Message<byte[]> message = createMessage(accessor, payload); 233 execute(message); 234 235 return receiptable; 236 } 237 238 @Nullable 239 private String checkOrAddReceipt(StompHeaders headers) { 240 String receiptId = headers.getReceipt(); 241 if (isAutoReceiptEnabled() && receiptId == null) { 242 receiptId = String.valueOf(DefaultStompSession.this.receiptIndex.getAndIncrement()); 243 headers.setReceipt(receiptId); 244 } 245 return receiptId; 246 } 247 248 private StompHeaderAccessor createHeaderAccessor(StompCommand command) { 249 StompHeaderAccessor accessor = StompHeaderAccessor.create(command); 250 accessor.setSessionId(this.sessionId); 251 accessor.setLeaveMutable(true); 252 return accessor; 253 } 254 255 @SuppressWarnings("unchecked") 256 private Message<byte[]> createMessage(StompHeaderAccessor accessor, @Nullable Object payload) { 257 accessor.updateSimpMessageHeadersFromStompHeaders(); 258 Message<byte[]> message; 259 if (StringUtils.isEmpty(payload) || (payload instanceof byte[] && ((byte[]) payload).length == 0)) { 260 message = MessageBuilder.createMessage(EMPTY_PAYLOAD, accessor.getMessageHeaders()); 261 } 262 else { 263 message = (Message<byte[]>) getMessageConverter().toMessage(payload, accessor.getMessageHeaders()); 264 accessor.updateStompHeadersFromSimpMessageHeaders(); 265 if (message == null) { 266 throw new MessageConversionException("Unable to convert payload with type='" + 267 payload.getClass().getName() + "', contentType='" + accessor.getContentType() + 268 "', converter=[" + getMessageConverter() + "]"); 269 } 270 } 271 return message; 272 } 273 274 private void execute(Message<byte[]> message) { 275 if (logger.isTraceEnabled()) { 276 StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class); 277 if (accessor != null) { 278 logger.trace("Sending " + accessor.getDetailedLogMessage(message.getPayload())); 279 } 280 } 281 TcpConnection<byte[]> conn = this.connection; 282 Assert.state(conn != null, "Connection closed"); 283 try { 284 conn.send(message).get(); 285 } 286 catch (ExecutionException ex) { 287 throw new MessageDeliveryException(message, ex.getCause()); 288 } 289 catch (Throwable ex) { 290 throw new MessageDeliveryException(message, ex); 291 } 292 } 293 294 @Override 295 public Subscription subscribe(String destination, StompFrameHandler handler) { 296 StompHeaders headers = new StompHeaders(); 297 headers.setDestination(destination); 298 return subscribe(headers, handler); 299 } 300 301 @Override 302 public Subscription subscribe(StompHeaders headers, StompFrameHandler handler) { 303 Assert.hasText(headers.getDestination(), "Destination header is required"); 304 Assert.notNull(handler, "StompFrameHandler must not be null"); 305 306 String subscriptionId = headers.getId(); 307 if (!StringUtils.hasText(subscriptionId)) { 308 subscriptionId = String.valueOf(DefaultStompSession.this.subscriptionIndex.getAndIncrement()); 309 headers.setId(subscriptionId); 310 } 311 checkOrAddReceipt(headers); 312 Subscription subscription = new DefaultSubscription(headers, handler); 313 314 StompHeaderAccessor accessor = createHeaderAccessor(StompCommand.SUBSCRIBE); 315 accessor.addNativeHeaders(headers); 316 Message<byte[]> message = createMessage(accessor, EMPTY_PAYLOAD); 317 execute(message); 318 319 return subscription; 320 } 321 322 @Override 323 public Receiptable acknowledge(String messageId, boolean consumed) { 324 StompHeaders headers = new StompHeaders(); 325 if ("1.1".equals(this.version)) { 326 headers.setMessageId(messageId); 327 } 328 else { 329 headers.setId(messageId); 330 } 331 return acknowledge(headers, consumed); 332 } 333 334 @Override 335 public Receiptable acknowledge(StompHeaders headers, boolean consumed) { 336 String receiptId = checkOrAddReceipt(headers); 337 Receiptable receiptable = new ReceiptHandler(receiptId); 338 339 StompCommand command = (consumed ? StompCommand.ACK : StompCommand.NACK); 340 StompHeaderAccessor accessor = createHeaderAccessor(command); 341 accessor.addNativeHeaders(headers); 342 Message<byte[]> message = createMessage(accessor, null); 343 execute(message); 344 345 return receiptable; 346 } 347 348 private void unsubscribe(String id, @Nullable StompHeaders headers) { 349 StompHeaderAccessor accessor = createHeaderAccessor(StompCommand.UNSUBSCRIBE); 350 if (headers != null) { 351 accessor.addNativeHeaders(headers); 352 } 353 accessor.setSubscriptionId(id); 354 Message<byte[]> message = createMessage(accessor, EMPTY_PAYLOAD); 355 execute(message); 356 } 357 358 @Override 359 public void disconnect() { 360 disconnect(null); 361 } 362 363 @Override 364 public void disconnect(@Nullable StompHeaders headers) { 365 this.closing = true; 366 try { 367 StompHeaderAccessor accessor = createHeaderAccessor(StompCommand.DISCONNECT); 368 if (headers != null) { 369 accessor.addNativeHeaders(headers); 370 } 371 Message<byte[]> message = createMessage(accessor, EMPTY_PAYLOAD); 372 execute(message); 373 } 374 finally { 375 resetConnection(); 376 } 377 } 378 379 380 // TcpConnectionHandler 381 382 @Override 383 public void afterConnected(TcpConnection<byte[]> connection) { 384 this.connection = connection; 385 if (logger.isDebugEnabled()) { 386 logger.debug("Connection established in session id=" + this.sessionId); 387 } 388 StompHeaderAccessor accessor = createHeaderAccessor(StompCommand.CONNECT); 389 accessor.addNativeHeaders(this.connectHeaders); 390 if (this.connectHeaders.getAcceptVersion() == null) { 391 accessor.setAcceptVersion("1.1,1.2"); 392 } 393 Message<byte[]> message = createMessage(accessor, EMPTY_PAYLOAD); 394 execute(message); 395 } 396 397 @Override 398 public void afterConnectFailure(Throwable ex) { 399 if (logger.isDebugEnabled()) { 400 logger.debug("Failed to connect session id=" + this.sessionId, ex); 401 } 402 this.sessionFuture.setException(ex); 403 this.sessionHandler.handleTransportError(this, ex); 404 } 405 406 @Override 407 public void handleMessage(Message<byte[]> message) { 408 StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class); 409 Assert.state(accessor != null, "No StompHeaderAccessor"); 410 411 accessor.setSessionId(this.sessionId); 412 StompCommand command = accessor.getCommand(); 413 Map<String, List<String>> nativeHeaders = accessor.getNativeHeaders(); 414 StompHeaders headers = StompHeaders.readOnlyStompHeaders(nativeHeaders); 415 boolean isHeartbeat = accessor.isHeartbeat(); 416 if (logger.isTraceEnabled()) { 417 logger.trace("Received " + accessor.getDetailedLogMessage(message.getPayload())); 418 } 419 420 try { 421 if (StompCommand.MESSAGE.equals(command)) { 422 DefaultSubscription subscription = this.subscriptions.get(headers.getSubscription()); 423 if (subscription != null) { 424 invokeHandler(subscription.getHandler(), message, headers); 425 } 426 else if (logger.isDebugEnabled()) { 427 logger.debug("No handler for: " + accessor.getDetailedLogMessage(message.getPayload()) + 428 ". Perhaps just unsubscribed?"); 429 } 430 } 431 else { 432 if (StompCommand.RECEIPT.equals(command)) { 433 String receiptId = headers.getReceiptId(); 434 ReceiptHandler handler = this.receiptHandlers.get(receiptId); 435 if (handler != null) { 436 handler.handleReceiptReceived(); 437 } 438 else if (logger.isDebugEnabled()) { 439 logger.debug("No matching receipt: " + accessor.getDetailedLogMessage(message.getPayload())); 440 } 441 } 442 else if (StompCommand.CONNECTED.equals(command)) { 443 initHeartbeatTasks(headers); 444 this.version = headers.getFirst("version"); 445 this.sessionFuture.set(this); 446 this.sessionHandler.afterConnected(this, headers); 447 } 448 else if (StompCommand.ERROR.equals(command)) { 449 invokeHandler(this.sessionHandler, message, headers); 450 } 451 else if (!isHeartbeat && logger.isTraceEnabled()) { 452 logger.trace("Message not handled."); 453 } 454 } 455 } 456 catch (Throwable ex) { 457 this.sessionHandler.handleException(this, command, headers, message.getPayload(), ex); 458 } 459 } 460 461 private void invokeHandler(StompFrameHandler handler, Message<byte[]> message, StompHeaders headers) { 462 if (message.getPayload().length == 0) { 463 handler.handleFrame(headers, null); 464 return; 465 } 466 Type payloadType = handler.getPayloadType(headers); 467 Class<?> resolvedType = ResolvableType.forType(payloadType).resolve(); 468 if (resolvedType == null) { 469 throw new MessageConversionException("Unresolvable payload type [" + payloadType + 470 "] from handler type [" + handler.getClass() + "]"); 471 } 472 Object object = getMessageConverter().fromMessage(message, resolvedType); 473 if (object == null) { 474 throw new MessageConversionException("No suitable converter for payload type [" + payloadType + 475 "] from handler type [" + handler.getClass() + "]"); 476 } 477 handler.handleFrame(headers, object); 478 } 479 480 private void initHeartbeatTasks(StompHeaders connectedHeaders) { 481 long[] connect = this.connectHeaders.getHeartbeat(); 482 long[] connected = connectedHeaders.getHeartbeat(); 483 if (connect == null || connected == null) { 484 return; 485 } 486 TcpConnection<byte[]> con = this.connection; 487 Assert.state(con != null, "No TcpConnection available"); 488 if (connect[0] > 0 && connected[1] > 0) { 489 long interval = Math.max(connect[0], connected[1]); 490 con.onWriteInactivity(new WriteInactivityTask(), interval); 491 } 492 if (connect[1] > 0 && connected[0] > 0) { 493 long interval = Math.max(connect[1], connected[0]) * HEARTBEAT_MULTIPLIER; 494 con.onReadInactivity(new ReadInactivityTask(), interval); 495 } 496 } 497 498 @Override 499 public void handleFailure(Throwable ex) { 500 try { 501 this.sessionFuture.setException(ex); // no-op if already set 502 this.sessionHandler.handleTransportError(this, ex); 503 } 504 catch (Throwable ex2) { 505 if (logger.isDebugEnabled()) { 506 logger.debug("Uncaught failure while handling transport failure", ex2); 507 } 508 } 509 } 510 511 @Override 512 public void afterConnectionClosed() { 513 if (logger.isDebugEnabled()) { 514 logger.debug("Connection closed in session id=" + this.sessionId); 515 } 516 if (!this.closing) { 517 resetConnection(); 518 handleFailure(new ConnectionLostException("Connection closed")); 519 } 520 } 521 522 private void resetConnection() { 523 TcpConnection<?> conn = this.connection; 524 this.connection = null; 525 if (conn != null) { 526 try { 527 conn.close(); 528 } 529 catch (Throwable ex) { 530 // ignore 531 } 532 } 533 } 534 535 536 private class ReceiptHandler implements Receiptable { 537 538 @Nullable 539 private final String receiptId; 540 541 private final List<Runnable> receiptCallbacks = new ArrayList<>(2); 542 543 private final List<Runnable> receiptLostCallbacks = new ArrayList<>(2); 544 545 @Nullable 546 private ScheduledFuture<?> future; 547 548 @Nullable 549 private Boolean result; 550 551 public ReceiptHandler(@Nullable String receiptId) { 552 this.receiptId = receiptId; 553 if (receiptId != null) { 554 initReceiptHandling(); 555 } 556 } 557 558 private void initReceiptHandling() { 559 Assert.notNull(getTaskScheduler(), "To track receipts, a TaskScheduler must be configured"); 560 DefaultStompSession.this.receiptHandlers.put(this.receiptId, this); 561 Date startTime = new Date(System.currentTimeMillis() + getReceiptTimeLimit()); 562 this.future = getTaskScheduler().schedule(this::handleReceiptNotReceived, startTime); 563 } 564 565 @Override 566 @Nullable 567 public String getReceiptId() { 568 return this.receiptId; 569 } 570 571 @Override 572 public void addReceiptTask(Runnable task) { 573 addTask(task, true); 574 } 575 576 @Override 577 public void addReceiptLostTask(Runnable task) { 578 addTask(task, false); 579 } 580 581 private void addTask(Runnable task, boolean successTask) { 582 Assert.notNull(this.receiptId, 583 "To track receipts, set autoReceiptEnabled=true or add 'receiptId' header"); 584 synchronized (this) { 585 if (this.result != null && this.result == successTask) { 586 invoke(Collections.singletonList(task)); 587 } 588 else { 589 if (successTask) { 590 this.receiptCallbacks.add(task); 591 } 592 else { 593 this.receiptLostCallbacks.add(task); 594 } 595 } 596 } 597 } 598 599 private void invoke(List<Runnable> callbacks) { 600 for (Runnable runnable : callbacks) { 601 try { 602 runnable.run(); 603 } 604 catch (Throwable ex) { 605 // ignore 606 } 607 } 608 } 609 610 public void handleReceiptReceived() { 611 handleInternal(true); 612 } 613 614 public void handleReceiptNotReceived() { 615 handleInternal(false); 616 } 617 618 private void handleInternal(boolean result) { 619 synchronized (this) { 620 if (this.result != null) { 621 return; 622 } 623 this.result = result; 624 invoke(result ? this.receiptCallbacks : this.receiptLostCallbacks); 625 DefaultStompSession.this.receiptHandlers.remove(this.receiptId); 626 if (this.future != null) { 627 this.future.cancel(true); 628 } 629 } 630 } 631 } 632 633 634 private class DefaultSubscription extends ReceiptHandler implements Subscription { 635 636 private final StompHeaders headers; 637 638 private final StompFrameHandler handler; 639 640 public DefaultSubscription(StompHeaders headers, StompFrameHandler handler) { 641 super(headers.getReceipt()); 642 Assert.notNull(headers.getDestination(), "Destination must not be null"); 643 Assert.notNull(handler, "StompFrameHandler must not be null"); 644 this.headers = headers; 645 this.handler = handler; 646 DefaultStompSession.this.subscriptions.put(headers.getId(), this); 647 } 648 649 @Override 650 @Nullable 651 public String getSubscriptionId() { 652 return this.headers.getId(); 653 } 654 655 @Override 656 public StompHeaders getSubscriptionHeaders() { 657 return this.headers; 658 } 659 660 public StompFrameHandler getHandler() { 661 return this.handler; 662 } 663 664 @Override 665 public void unsubscribe() { 666 unsubscribe(null); 667 } 668 669 @Override 670 public void unsubscribe(@Nullable StompHeaders headers) { 671 String id = this.headers.getId(); 672 if (id != null) { 673 DefaultStompSession.this.subscriptions.remove(id); 674 DefaultStompSession.this.unsubscribe(id, headers); 675 } 676 } 677 678 @Override 679 public String toString() { 680 return "Subscription [id=" + getSubscriptionId() + 681 ", destination='" + this.headers.getDestination() + 682 "', receiptId='" + getReceiptId() + "', handler=" + getHandler() + "]"; 683 } 684 } 685 686 687 private class WriteInactivityTask implements Runnable { 688 689 @Override 690 public void run() { 691 TcpConnection<byte[]> conn = connection; 692 if (conn != null) { 693 conn.send(HEARTBEAT).addCallback( 694 new ListenableFutureCallback<Void>() { 695 @Override 696 public void onSuccess(@Nullable Void result) { 697 } 698 @Override 699 public void onFailure(Throwable ex) { 700 handleFailure(ex); 701 } 702 }); 703 } 704 } 705 } 706 707 708 private class ReadInactivityTask implements Runnable { 709 710 @Override 711 public void run() { 712 closing = true; 713 String error = "Server has gone quiet. Closing connection in session id=" + sessionId + "."; 714 if (logger.isDebugEnabled()) { 715 logger.debug(error); 716 } 717 resetConnection(); 718 handleFailure(new IllegalStateException(error)); 719 } 720 } 721 722}