001/* 002 * Copyright 2002-2019 the original author or authors. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * https://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.springframework.messaging.simp.broker; 018 019import java.security.Principal; 020import java.util.Arrays; 021import java.util.Collection; 022import java.util.Map; 023import java.util.concurrent.ConcurrentHashMap; 024import java.util.concurrent.ScheduledFuture; 025 026import org.springframework.lang.Nullable; 027import org.springframework.messaging.Message; 028import org.springframework.messaging.MessageChannel; 029import org.springframework.messaging.MessageHeaders; 030import org.springframework.messaging.SubscribableChannel; 031import org.springframework.messaging.simp.SimpMessageHeaderAccessor; 032import org.springframework.messaging.simp.SimpMessageType; 033import org.springframework.messaging.support.MessageBuilder; 034import org.springframework.messaging.support.MessageHeaderAccessor; 035import org.springframework.messaging.support.MessageHeaderInitializer; 036import org.springframework.scheduling.TaskScheduler; 037import org.springframework.util.Assert; 038import org.springframework.util.MultiValueMap; 039import org.springframework.util.PathMatcher; 040 041/** 042 * A "simple" message broker that recognizes the message types defined in 043 * {@link SimpMessageType}, keeps track of subscriptions with the help of a 044 * {@link SubscriptionRegistry} and sends messages to subscribers. 045 * 046 * @author Rossen Stoyanchev 047 * @author Juergen Hoeller 048 * @since 4.0 049 */ 050public class SimpleBrokerMessageHandler extends AbstractBrokerMessageHandler { 051 052 private static final byte[] EMPTY_PAYLOAD = new byte[0]; 053 054 055 @Nullable 056 private PathMatcher pathMatcher; 057 058 @Nullable 059 private Integer cacheLimit; 060 061 @Nullable 062 private String selectorHeaderName = "selector"; 063 064 @Nullable 065 private TaskScheduler taskScheduler; 066 067 @Nullable 068 private long[] heartbeatValue; 069 070 @Nullable 071 private MessageHeaderInitializer headerInitializer; 072 073 074 private SubscriptionRegistry subscriptionRegistry; 075 076 private final Map<String, SessionInfo> sessions = new ConcurrentHashMap<>(); 077 078 @Nullable 079 private ScheduledFuture<?> heartbeatFuture; 080 081 082 /** 083 * Create a SimpleBrokerMessageHandler instance with the given message channels 084 * and destination prefixes. 085 * @param clientInboundChannel the channel for receiving messages from clients (e.g. WebSocket clients) 086 * @param clientOutboundChannel the channel for sending messages to clients (e.g. WebSocket clients) 087 * @param brokerChannel the channel for the application to send messages to the broker 088 * @param destinationPrefixes prefixes to use to filter out messages 089 */ 090 public SimpleBrokerMessageHandler(SubscribableChannel clientInboundChannel, MessageChannel clientOutboundChannel, 091 SubscribableChannel brokerChannel, Collection<String> destinationPrefixes) { 092 093 super(clientInboundChannel, clientOutboundChannel, brokerChannel, destinationPrefixes); 094 this.subscriptionRegistry = new DefaultSubscriptionRegistry(); 095 } 096 097 098 /** 099 * Configure a custom SubscriptionRegistry to use for storing subscriptions. 100 * <p><strong>Note</strong> that when a custom PathMatcher is configured via 101 * {@link #setPathMatcher}, if the custom registry is not an instance of 102 * {@link DefaultSubscriptionRegistry}, the provided PathMatcher is not used 103 * and must be configured directly on the custom registry. 104 */ 105 public void setSubscriptionRegistry(SubscriptionRegistry subscriptionRegistry) { 106 Assert.notNull(subscriptionRegistry, "SubscriptionRegistry must not be null"); 107 this.subscriptionRegistry = subscriptionRegistry; 108 initPathMatcherToUse(); 109 initCacheLimitToUse(); 110 initSelectorHeaderNameToUse(); 111 } 112 113 public SubscriptionRegistry getSubscriptionRegistry() { 114 return this.subscriptionRegistry; 115 } 116 117 /** 118 * When configured, the given PathMatcher is passed down to the underlying 119 * SubscriptionRegistry to use for matching destination to subscriptions. 120 * <p>Default is a standard {@link org.springframework.util.AntPathMatcher}. 121 * @since 4.1 122 * @see #setSubscriptionRegistry 123 * @see DefaultSubscriptionRegistry#setPathMatcher 124 * @see org.springframework.util.AntPathMatcher 125 */ 126 public void setPathMatcher(@Nullable PathMatcher pathMatcher) { 127 this.pathMatcher = pathMatcher; 128 initPathMatcherToUse(); 129 } 130 131 private void initPathMatcherToUse() { 132 if (this.pathMatcher != null && this.subscriptionRegistry instanceof DefaultSubscriptionRegistry) { 133 ((DefaultSubscriptionRegistry) this.subscriptionRegistry).setPathMatcher(this.pathMatcher); 134 } 135 } 136 137 /** 138 * When configured, the specified cache limit is passed down to the 139 * underlying SubscriptionRegistry, overriding any default there. 140 * <p>With a standard {@link DefaultSubscriptionRegistry}, the default 141 * cache limit is 1024. 142 * @since 4.3.2 143 * @see #setSubscriptionRegistry 144 * @see DefaultSubscriptionRegistry#setCacheLimit 145 * @see DefaultSubscriptionRegistry#DEFAULT_CACHE_LIMIT 146 */ 147 public void setCacheLimit(@Nullable Integer cacheLimit) { 148 this.cacheLimit = cacheLimit; 149 initCacheLimitToUse(); 150 } 151 152 private void initCacheLimitToUse() { 153 if (this.cacheLimit != null && this.subscriptionRegistry instanceof DefaultSubscriptionRegistry) { 154 ((DefaultSubscriptionRegistry) this.subscriptionRegistry).setCacheLimit(this.cacheLimit); 155 } 156 } 157 158 /** 159 * Configure the name of a header that a subscription message can have for 160 * the purpose of filtering messages matched to the subscription. The header 161 * value is expected to be a Spring EL boolean expression to be applied to 162 * the headers of messages matched to the subscription. 163 * <p>For example: 164 * <pre> 165 * headers.foo == 'bar' 166 * </pre> 167 * <p>By default this is set to "selector". You can set it to a different 168 * name, or to {@code null} to turn off support for a selector header. 169 * @param selectorHeaderName the name to use for a selector header 170 * @since 4.3.17 171 * @see #setSubscriptionRegistry 172 * @see DefaultSubscriptionRegistry#setSelectorHeaderName(String) 173 */ 174 public void setSelectorHeaderName(@Nullable String selectorHeaderName) { 175 this.selectorHeaderName = selectorHeaderName; 176 initSelectorHeaderNameToUse(); 177 } 178 179 private void initSelectorHeaderNameToUse() { 180 if (this.subscriptionRegistry instanceof DefaultSubscriptionRegistry) { 181 ((DefaultSubscriptionRegistry) this.subscriptionRegistry).setSelectorHeaderName(this.selectorHeaderName); 182 } 183 } 184 185 /** 186 * Configure the {@link org.springframework.scheduling.TaskScheduler} to 187 * use for providing heartbeat support. Setting this property also sets the 188 * {@link #setHeartbeatValue heartbeatValue} to "10000, 10000". 189 * <p>By default this is not set. 190 * @since 4.2 191 */ 192 public void setTaskScheduler(@Nullable TaskScheduler taskScheduler) { 193 this.taskScheduler = taskScheduler; 194 if (taskScheduler != null && this.heartbeatValue == null) { 195 this.heartbeatValue = new long[] {10000, 10000}; 196 } 197 } 198 199 /** 200 * Return the configured TaskScheduler. 201 * @since 4.2 202 */ 203 @Nullable 204 public TaskScheduler getTaskScheduler() { 205 return this.taskScheduler; 206 } 207 208 /** 209 * Configure the value for the heart-beat settings. The first number 210 * represents how often the server will write or send a heartbeat. 211 * The second is how often the client should write. 0 means no heartbeats. 212 * <p>By default this is set to "0, 0" unless the {@link #setTaskScheduler 213 * taskScheduler} in which case the default becomes "10000,10000" 214 * (in milliseconds). 215 * @since 4.2 216 */ 217 public void setHeartbeatValue(@Nullable long[] heartbeat) { 218 if (heartbeat != null && (heartbeat.length != 2 || heartbeat[0] < 0 || heartbeat[1] < 0)) { 219 throw new IllegalArgumentException("Invalid heart-beat: " + Arrays.toString(heartbeat)); 220 } 221 this.heartbeatValue = heartbeat; 222 } 223 224 /** 225 * The configured value for the heart-beat settings. 226 * @since 4.2 227 */ 228 @Nullable 229 public long[] getHeartbeatValue() { 230 return this.heartbeatValue; 231 } 232 233 /** 234 * Configure a {@link MessageHeaderInitializer} to apply to the headers 235 * of all messages sent to the client outbound channel. 236 * <p>By default this property is not set. 237 * @since 4.1 238 */ 239 public void setHeaderInitializer(@Nullable MessageHeaderInitializer headerInitializer) { 240 this.headerInitializer = headerInitializer; 241 } 242 243 /** 244 * Return the configured header initializer. 245 * @since 4.1 246 */ 247 @Nullable 248 public MessageHeaderInitializer getHeaderInitializer() { 249 return this.headerInitializer; 250 } 251 252 253 @Override 254 public void startInternal() { 255 publishBrokerAvailableEvent(); 256 if (this.taskScheduler != null) { 257 long interval = initHeartbeatTaskDelay(); 258 if (interval > 0) { 259 this.heartbeatFuture = this.taskScheduler.scheduleWithFixedDelay(new HeartbeatTask(), interval); 260 } 261 } 262 else { 263 Assert.isTrue(getHeartbeatValue() == null || 264 (getHeartbeatValue()[0] == 0 && getHeartbeatValue()[1] == 0), 265 "Heartbeat values configured but no TaskScheduler provided"); 266 } 267 } 268 269 private long initHeartbeatTaskDelay() { 270 if (getHeartbeatValue() == null) { 271 return 0; 272 } 273 else if (getHeartbeatValue()[0] > 0 && getHeartbeatValue()[1] > 0) { 274 return Math.min(getHeartbeatValue()[0], getHeartbeatValue()[1]); 275 } 276 else { 277 return (getHeartbeatValue()[0] > 0 ? getHeartbeatValue()[0] : getHeartbeatValue()[1]); 278 } 279 } 280 281 @Override 282 public void stopInternal() { 283 publishBrokerUnavailableEvent(); 284 if (this.heartbeatFuture != null) { 285 this.heartbeatFuture.cancel(true); 286 } 287 } 288 289 @Override 290 protected void handleMessageInternal(Message<?> message) { 291 MessageHeaders headers = message.getHeaders(); 292 String destination = SimpMessageHeaderAccessor.getDestination(headers); 293 String sessionId = SimpMessageHeaderAccessor.getSessionId(headers); 294 295 updateSessionReadTime(sessionId); 296 297 if (!checkDestinationPrefix(destination)) { 298 return; 299 } 300 301 SimpMessageType messageType = SimpMessageHeaderAccessor.getMessageType(headers); 302 if (SimpMessageType.MESSAGE.equals(messageType)) { 303 logMessage(message); 304 sendMessageToSubscribers(destination, message); 305 } 306 else if (SimpMessageType.CONNECT.equals(messageType)) { 307 logMessage(message); 308 if (sessionId != null) { 309 long[] heartbeatIn = SimpMessageHeaderAccessor.getHeartbeat(headers); 310 long[] heartbeatOut = getHeartbeatValue(); 311 Principal user = SimpMessageHeaderAccessor.getUser(headers); 312 MessageChannel outChannel = getClientOutboundChannelForSession(sessionId); 313 this.sessions.put(sessionId, new SessionInfo(sessionId, user, outChannel, heartbeatIn, heartbeatOut)); 314 SimpMessageHeaderAccessor connectAck = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT_ACK); 315 initHeaders(connectAck); 316 connectAck.setSessionId(sessionId); 317 if (user != null) { 318 connectAck.setUser(user); 319 } 320 connectAck.setHeader(SimpMessageHeaderAccessor.CONNECT_MESSAGE_HEADER, message); 321 connectAck.setHeader(SimpMessageHeaderAccessor.HEART_BEAT_HEADER, heartbeatOut); 322 Message<byte[]> messageOut = MessageBuilder.createMessage(EMPTY_PAYLOAD, connectAck.getMessageHeaders()); 323 getClientOutboundChannel().send(messageOut); 324 } 325 } 326 else if (SimpMessageType.DISCONNECT.equals(messageType)) { 327 logMessage(message); 328 if (sessionId != null) { 329 Principal user = SimpMessageHeaderAccessor.getUser(headers); 330 handleDisconnect(sessionId, user, message); 331 } 332 } 333 else if (SimpMessageType.SUBSCRIBE.equals(messageType)) { 334 logMessage(message); 335 this.subscriptionRegistry.registerSubscription(message); 336 } 337 else if (SimpMessageType.UNSUBSCRIBE.equals(messageType)) { 338 logMessage(message); 339 this.subscriptionRegistry.unregisterSubscription(message); 340 } 341 } 342 343 private void updateSessionReadTime(@Nullable String sessionId) { 344 if (sessionId != null) { 345 SessionInfo info = this.sessions.get(sessionId); 346 if (info != null) { 347 info.setLastReadTime(System.currentTimeMillis()); 348 } 349 } 350 } 351 352 private void logMessage(Message<?> message) { 353 if (logger.isDebugEnabled()) { 354 SimpMessageHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, SimpMessageHeaderAccessor.class); 355 accessor = (accessor != null ? accessor : SimpMessageHeaderAccessor.wrap(message)); 356 logger.debug("Processing " + accessor.getShortLogMessage(message.getPayload())); 357 } 358 } 359 360 private void initHeaders(SimpMessageHeaderAccessor accessor) { 361 if (getHeaderInitializer() != null) { 362 getHeaderInitializer().initHeaders(accessor); 363 } 364 } 365 366 private void handleDisconnect(String sessionId, @Nullable Principal user, @Nullable Message<?> origMessage) { 367 this.sessions.remove(sessionId); 368 this.subscriptionRegistry.unregisterAllSubscriptions(sessionId); 369 SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create(SimpMessageType.DISCONNECT_ACK); 370 accessor.setSessionId(sessionId); 371 if (user != null) { 372 accessor.setUser(user); 373 } 374 if (origMessage != null) { 375 accessor.setHeader(SimpMessageHeaderAccessor.DISCONNECT_MESSAGE_HEADER, origMessage); 376 } 377 initHeaders(accessor); 378 Message<byte[]> message = MessageBuilder.createMessage(EMPTY_PAYLOAD, accessor.getMessageHeaders()); 379 getClientOutboundChannel().send(message); 380 } 381 382 protected void sendMessageToSubscribers(@Nullable String destination, Message<?> message) { 383 MultiValueMap<String,String> subscriptions = this.subscriptionRegistry.findSubscriptions(message); 384 if (!subscriptions.isEmpty() && logger.isDebugEnabled()) { 385 logger.debug("Broadcasting to " + subscriptions.size() + " sessions."); 386 } 387 long now = System.currentTimeMillis(); 388 subscriptions.forEach((sessionId, subscriptionIds) -> { 389 for (String subscriptionId : subscriptionIds) { 390 SimpMessageHeaderAccessor headerAccessor = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE); 391 initHeaders(headerAccessor); 392 headerAccessor.setSessionId(sessionId); 393 headerAccessor.setSubscriptionId(subscriptionId); 394 headerAccessor.copyHeadersIfAbsent(message.getHeaders()); 395 headerAccessor.setLeaveMutable(true); 396 Object payload = message.getPayload(); 397 Message<?> reply = MessageBuilder.createMessage(payload, headerAccessor.getMessageHeaders()); 398 SessionInfo info = this.sessions.get(sessionId); 399 if (info != null) { 400 try { 401 info.getClientOutboundChannel().send(reply); 402 } 403 catch (Throwable ex) { 404 if (logger.isErrorEnabled()) { 405 logger.error("Failed to send " + message, ex); 406 } 407 } 408 finally { 409 info.setLastWriteTime(now); 410 } 411 } 412 } 413 }); 414 } 415 416 @Override 417 public String toString() { 418 return "SimpleBrokerMessageHandler [" + this.subscriptionRegistry + "]"; 419 } 420 421 422 private static class SessionInfo { 423 424 /* STOMP spec: receiver SHOULD take into account an error margin */ 425 private static final long HEARTBEAT_MULTIPLIER = 3; 426 427 private final String sessionId; 428 429 @Nullable 430 private final Principal user; 431 432 private final MessageChannel clientOutboundChannel; 433 434 private final long readInterval; 435 436 private final long writeInterval; 437 438 private volatile long lastReadTime; 439 440 private volatile long lastWriteTime; 441 442 443 public SessionInfo(String sessionId, @Nullable Principal user, MessageChannel outboundChannel, 444 @Nullable long[] clientHeartbeat, @Nullable long[] serverHeartbeat) { 445 446 this.sessionId = sessionId; 447 this.user = user; 448 this.clientOutboundChannel = outboundChannel; 449 if (clientHeartbeat != null && serverHeartbeat != null) { 450 this.readInterval = (clientHeartbeat[0] > 0 && serverHeartbeat[1] > 0 ? 451 Math.max(clientHeartbeat[0], serverHeartbeat[1]) * HEARTBEAT_MULTIPLIER : 0); 452 this.writeInterval = (clientHeartbeat[1] > 0 && serverHeartbeat[0] > 0 ? 453 Math.max(clientHeartbeat[1], serverHeartbeat[0]) : 0); 454 } 455 else { 456 this.readInterval = 0; 457 this.writeInterval = 0; 458 } 459 this.lastReadTime = this.lastWriteTime = System.currentTimeMillis(); 460 } 461 462 public String getSessionId() { 463 return this.sessionId; 464 } 465 466 @Nullable 467 public Principal getUser() { 468 return this.user; 469 } 470 471 public MessageChannel getClientOutboundChannel() { 472 return this.clientOutboundChannel; 473 } 474 475 public long getReadInterval() { 476 return this.readInterval; 477 } 478 479 public long getWriteInterval() { 480 return this.writeInterval; 481 } 482 483 public long getLastReadTime() { 484 return this.lastReadTime; 485 } 486 487 public void setLastReadTime(long lastReadTime) { 488 this.lastReadTime = lastReadTime; 489 } 490 491 public long getLastWriteTime() { 492 return this.lastWriteTime; 493 } 494 495 public void setLastWriteTime(long lastWriteTime) { 496 this.lastWriteTime = lastWriteTime; 497 } 498 } 499 500 501 private class HeartbeatTask implements Runnable { 502 503 @Override 504 public void run() { 505 long now = System.currentTimeMillis(); 506 for (SessionInfo info : sessions.values()) { 507 if (info.getReadInterval() > 0 && (now - info.getLastReadTime()) > info.getReadInterval()) { 508 handleDisconnect(info.getSessionId(), info.getUser(), null); 509 } 510 if (info.getWriteInterval() > 0 && (now - info.getLastWriteTime()) > info.getWriteInterval()) { 511 SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create(SimpMessageType.HEARTBEAT); 512 accessor.setSessionId(info.getSessionId()); 513 Principal user = info.getUser(); 514 if (user != null) { 515 accessor.setUser(user); 516 } 517 initHeaders(accessor); 518 accessor.setLeaveMutable(true); 519 MessageHeaders headers = accessor.getMessageHeaders(); 520 info.getClientOutboundChannel().send(MessageBuilder.createMessage(EMPTY_PAYLOAD, headers)); 521 } 522 } 523 } 524 } 525 526}