001/* 002 * Copyright 2002-2019 the original author or authors. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * https://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.springframework.messaging.simp.broker; 018 019import java.util.Collection; 020import java.util.HashSet; 021import java.util.LinkedHashMap; 022import java.util.List; 023import java.util.Map; 024import java.util.Set; 025import java.util.concurrent.ConcurrentHashMap; 026import java.util.concurrent.ConcurrentMap; 027import java.util.concurrent.CopyOnWriteArraySet; 028 029import org.springframework.expression.EvaluationContext; 030import org.springframework.expression.Expression; 031import org.springframework.expression.ExpressionParser; 032import org.springframework.expression.PropertyAccessor; 033import org.springframework.expression.TypedValue; 034import org.springframework.expression.spel.SpelEvaluationException; 035import org.springframework.expression.spel.standard.SpelExpressionParser; 036import org.springframework.expression.spel.support.SimpleEvaluationContext; 037import org.springframework.lang.Nullable; 038import org.springframework.messaging.Message; 039import org.springframework.messaging.MessageHeaders; 040import org.springframework.messaging.simp.SimpMessageHeaderAccessor; 041import org.springframework.messaging.support.MessageHeaderAccessor; 042import org.springframework.util.AntPathMatcher; 043import org.springframework.util.Assert; 044import org.springframework.util.LinkedMultiValueMap; 045import org.springframework.util.MultiValueMap; 046import org.springframework.util.PathMatcher; 047import org.springframework.util.StringUtils; 048 049/** 050 * Implementation of {@link SubscriptionRegistry} that stores subscriptions 051 * in memory and uses a {@link org.springframework.util.PathMatcher PathMatcher} 052 * for matching destinations. 053 * 054 * <p>As of 4.2, this class supports a {@link #setSelectorHeaderName selector} 055 * header on subscription messages with Spring EL expressions evaluated against 056 * the headers to filter out messages in addition to destination matching. 057 * 058 * @author Rossen Stoyanchev 059 * @author Sebastien Deleuze 060 * @author Juergen Hoeller 061 * @since 4.0 062 */ 063public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry { 064 065 /** Default maximum number of entries for the destination cache: 1024. */ 066 public static final int DEFAULT_CACHE_LIMIT = 1024; 067 068 /** Static evaluation context to reuse. */ 069 private static final EvaluationContext messageEvalContext = 070 SimpleEvaluationContext.forPropertyAccessors(new SimpMessageHeaderPropertyAccessor()).build(); 071 072 073 private PathMatcher pathMatcher = new AntPathMatcher(); 074 075 private volatile int cacheLimit = DEFAULT_CACHE_LIMIT; 076 077 @Nullable 078 private String selectorHeaderName = "selector"; 079 080 private volatile boolean selectorHeaderInUse = false; 081 082 private final ExpressionParser expressionParser = new SpelExpressionParser(); 083 084 private final DestinationCache destinationCache = new DestinationCache(); 085 086 private final SessionSubscriptionRegistry subscriptionRegistry = new SessionSubscriptionRegistry(); 087 088 089 /** 090 * Specify the {@link PathMatcher} to use. 091 */ 092 public void setPathMatcher(PathMatcher pathMatcher) { 093 this.pathMatcher = pathMatcher; 094 } 095 096 /** 097 * Return the configured {@link PathMatcher}. 098 */ 099 public PathMatcher getPathMatcher() { 100 return this.pathMatcher; 101 } 102 103 /** 104 * Specify the maximum number of entries for the resolved destination cache. 105 * Default is 1024. 106 */ 107 public void setCacheLimit(int cacheLimit) { 108 this.cacheLimit = cacheLimit; 109 } 110 111 /** 112 * Return the maximum number of entries for the resolved destination cache. 113 */ 114 public int getCacheLimit() { 115 return this.cacheLimit; 116 } 117 118 /** 119 * Configure the name of a header that a subscription message can have for 120 * the purpose of filtering messages matched to the subscription. The header 121 * value is expected to be a Spring EL boolean expression to be applied to 122 * the headers of messages matched to the subscription. 123 * <p>For example: 124 * <pre> 125 * headers.foo == 'bar' 126 * </pre> 127 * <p>By default this is set to "selector". You can set it to a different 128 * name, or to {@code null} to turn off support for a selector header. 129 * @param selectorHeaderName the name to use for a selector header 130 * @since 4.2 131 */ 132 public void setSelectorHeaderName(@Nullable String selectorHeaderName) { 133 this.selectorHeaderName = (StringUtils.hasText(selectorHeaderName) ? selectorHeaderName : null); 134 } 135 136 /** 137 * Return the name for the selector header name. 138 * @since 4.2 139 */ 140 @Nullable 141 public String getSelectorHeaderName() { 142 return this.selectorHeaderName; 143 } 144 145 146 @Override 147 protected void addSubscriptionInternal( 148 String sessionId, String subsId, String destination, Message<?> message) { 149 150 Expression expression = getSelectorExpression(message.getHeaders()); 151 this.subscriptionRegistry.addSubscription(sessionId, subsId, destination, expression); 152 this.destinationCache.updateAfterNewSubscription(destination, sessionId, subsId); 153 } 154 155 @Nullable 156 private Expression getSelectorExpression(MessageHeaders headers) { 157 Expression expression = null; 158 if (getSelectorHeaderName() != null) { 159 String selector = SimpMessageHeaderAccessor.getFirstNativeHeader(getSelectorHeaderName(), headers); 160 if (selector != null) { 161 try { 162 expression = this.expressionParser.parseExpression(selector); 163 this.selectorHeaderInUse = true; 164 if (logger.isTraceEnabled()) { 165 logger.trace("Subscription selector: [" + selector + "]"); 166 } 167 } 168 catch (Throwable ex) { 169 if (logger.isDebugEnabled()) { 170 logger.debug("Failed to parse selector: " + selector, ex); 171 } 172 } 173 } 174 } 175 return expression; 176 } 177 178 @Override 179 protected void removeSubscriptionInternal(String sessionId, String subsId, Message<?> message) { 180 SessionSubscriptionInfo info = this.subscriptionRegistry.getSubscriptions(sessionId); 181 if (info != null) { 182 String destination = info.removeSubscription(subsId); 183 if (destination != null) { 184 this.destinationCache.updateAfterRemovedSubscription(sessionId, subsId); 185 } 186 } 187 } 188 189 @Override 190 public void unregisterAllSubscriptions(String sessionId) { 191 SessionSubscriptionInfo info = this.subscriptionRegistry.removeSubscriptions(sessionId); 192 if (info != null) { 193 this.destinationCache.updateAfterRemovedSession(info); 194 } 195 } 196 197 @Override 198 protected MultiValueMap<String, String> findSubscriptionsInternal(String destination, Message<?> message) { 199 MultiValueMap<String, String> result = this.destinationCache.getSubscriptions(destination, message); 200 return filterSubscriptions(result, message); 201 } 202 203 private MultiValueMap<String, String> filterSubscriptions( 204 MultiValueMap<String, String> allMatches, Message<?> message) { 205 206 if (!this.selectorHeaderInUse) { 207 return allMatches; 208 } 209 MultiValueMap<String, String> result = new LinkedMultiValueMap<>(allMatches.size()); 210 allMatches.forEach((sessionId, subIds) -> { 211 for (String subId : subIds) { 212 SessionSubscriptionInfo info = this.subscriptionRegistry.getSubscriptions(sessionId); 213 if (info == null) { 214 continue; 215 } 216 Subscription sub = info.getSubscription(subId); 217 if (sub == null) { 218 continue; 219 } 220 Expression expression = sub.getSelectorExpression(); 221 if (expression == null) { 222 result.add(sessionId, subId); 223 continue; 224 } 225 try { 226 if (Boolean.TRUE.equals(expression.getValue(messageEvalContext, message, Boolean.class))) { 227 result.add(sessionId, subId); 228 } 229 } 230 catch (SpelEvaluationException ex) { 231 if (logger.isDebugEnabled()) { 232 logger.debug("Failed to evaluate selector: " + ex.getMessage()); 233 } 234 } 235 catch (Throwable ex) { 236 logger.debug("Failed to evaluate selector", ex); 237 } 238 } 239 }); 240 return result; 241 } 242 243 @Override 244 public String toString() { 245 return "DefaultSubscriptionRegistry[" + this.destinationCache + ", " + this.subscriptionRegistry + "]"; 246 } 247 248 249 /** 250 * A cache for destinations previously resolved via 251 * {@link DefaultSubscriptionRegistry#findSubscriptionsInternal(String, Message)}. 252 */ 253 private class DestinationCache { 254 255 /** Map from destination to {@code <sessionId, subscriptionId>} for fast look-ups. */ 256 private final Map<String, LinkedMultiValueMap<String, String>> accessCache = 257 new ConcurrentHashMap<>(DEFAULT_CACHE_LIMIT); 258 259 /** Map from destination to {@code <sessionId, subscriptionId>} with locking. */ 260 @SuppressWarnings("serial") 261 private final Map<String, LinkedMultiValueMap<String, String>> updateCache = 262 new LinkedHashMap<String, LinkedMultiValueMap<String, String>>(DEFAULT_CACHE_LIMIT, 0.75f, true) { 263 @Override 264 protected boolean removeEldestEntry(Map.Entry<String, LinkedMultiValueMap<String, String>> eldest) { 265 if (size() > getCacheLimit()) { 266 accessCache.remove(eldest.getKey()); 267 return true; 268 } 269 else { 270 return false; 271 } 272 } 273 }; 274 275 276 public LinkedMultiValueMap<String, String> getSubscriptions(String destination, Message<?> message) { 277 LinkedMultiValueMap<String, String> result = this.accessCache.get(destination); 278 if (result == null) { 279 synchronized (this.updateCache) { 280 result = new LinkedMultiValueMap<>(); 281 for (SessionSubscriptionInfo info : subscriptionRegistry.getAllSubscriptions()) { 282 for (String destinationPattern : info.getDestinations()) { 283 if (getPathMatcher().match(destinationPattern, destination)) { 284 for (Subscription sub : info.getSubscriptions(destinationPattern)) { 285 result.add(info.sessionId, sub.getId()); 286 } 287 } 288 } 289 } 290 if (!result.isEmpty()) { 291 this.updateCache.put(destination, result.deepCopy()); 292 this.accessCache.put(destination, result); 293 } 294 } 295 } 296 return result; 297 } 298 299 public void updateAfterNewSubscription(String destination, String sessionId, String subsId) { 300 synchronized (this.updateCache) { 301 this.updateCache.forEach((cachedDestination, subscriptions) -> { 302 if (getPathMatcher().match(destination, cachedDestination)) { 303 // Subscription id's may also be populated via getSubscriptions() 304 List<String> subsForSession = subscriptions.get(sessionId); 305 if (subsForSession == null || !subsForSession.contains(subsId)) { 306 subscriptions.add(sessionId, subsId); 307 this.accessCache.put(cachedDestination, subscriptions.deepCopy()); 308 } 309 } 310 }); 311 } 312 } 313 314 public void updateAfterRemovedSubscription(String sessionId, String subsId) { 315 synchronized (this.updateCache) { 316 Set<String> destinationsToRemove = new HashSet<>(); 317 this.updateCache.forEach((destination, sessionMap) -> { 318 List<String> subscriptions = sessionMap.get(sessionId); 319 if (subscriptions != null) { 320 subscriptions.remove(subsId); 321 if (subscriptions.isEmpty()) { 322 sessionMap.remove(sessionId); 323 } 324 if (sessionMap.isEmpty()) { 325 destinationsToRemove.add(destination); 326 } 327 else { 328 this.accessCache.put(destination, sessionMap.deepCopy()); 329 } 330 } 331 }); 332 for (String destination : destinationsToRemove) { 333 this.updateCache.remove(destination); 334 this.accessCache.remove(destination); 335 } 336 } 337 } 338 339 public void updateAfterRemovedSession(SessionSubscriptionInfo info) { 340 synchronized (this.updateCache) { 341 Set<String> destinationsToRemove = new HashSet<>(); 342 this.updateCache.forEach((destination, sessionMap) -> { 343 if (sessionMap.remove(info.getSessionId()) != null) { 344 if (sessionMap.isEmpty()) { 345 destinationsToRemove.add(destination); 346 } 347 else { 348 this.accessCache.put(destination, sessionMap.deepCopy()); 349 } 350 } 351 }); 352 for (String destination : destinationsToRemove) { 353 this.updateCache.remove(destination); 354 this.accessCache.remove(destination); 355 } 356 } 357 } 358 359 @Override 360 public String toString() { 361 return "cache[" + this.accessCache.size() + " destination(s)]"; 362 } 363 } 364 365 366 /** 367 * Provide access to session subscriptions by sessionId. 368 */ 369 private static class SessionSubscriptionRegistry { 370 371 // sessionId -> SessionSubscriptionInfo 372 private final ConcurrentMap<String, SessionSubscriptionInfo> sessions = new ConcurrentHashMap<>(); 373 374 @Nullable 375 public SessionSubscriptionInfo getSubscriptions(String sessionId) { 376 return this.sessions.get(sessionId); 377 } 378 379 public Collection<SessionSubscriptionInfo> getAllSubscriptions() { 380 return this.sessions.values(); 381 } 382 383 public SessionSubscriptionInfo addSubscription(String sessionId, String subscriptionId, 384 String destination, @Nullable Expression selectorExpression) { 385 386 SessionSubscriptionInfo info = this.sessions.get(sessionId); 387 if (info == null) { 388 info = new SessionSubscriptionInfo(sessionId); 389 SessionSubscriptionInfo value = this.sessions.putIfAbsent(sessionId, info); 390 if (value != null) { 391 info = value; 392 } 393 } 394 info.addSubscription(destination, subscriptionId, selectorExpression); 395 return info; 396 } 397 398 @Nullable 399 public SessionSubscriptionInfo removeSubscriptions(String sessionId) { 400 return this.sessions.remove(sessionId); 401 } 402 403 @Override 404 public String toString() { 405 return "registry[" + this.sessions.size() + " sessions]"; 406 } 407 } 408 409 410 /** 411 * Hold subscriptions for a session. 412 */ 413 private static class SessionSubscriptionInfo { 414 415 private final String sessionId; 416 417 // destination -> subscriptions 418 private final Map<String, Set<Subscription>> destinationLookup = new ConcurrentHashMap<>(4); 419 420 public SessionSubscriptionInfo(String sessionId) { 421 Assert.notNull(sessionId, "'sessionId' must not be null"); 422 this.sessionId = sessionId; 423 } 424 425 public String getSessionId() { 426 return this.sessionId; 427 } 428 429 public Set<String> getDestinations() { 430 return this.destinationLookup.keySet(); 431 } 432 433 public Set<Subscription> getSubscriptions(String destination) { 434 return this.destinationLookup.get(destination); 435 } 436 437 @Nullable 438 public Subscription getSubscription(String subscriptionId) { 439 for (Map.Entry<String, Set<DefaultSubscriptionRegistry.Subscription>> destinationEntry : 440 this.destinationLookup.entrySet()) { 441 for (Subscription sub : destinationEntry.getValue()) { 442 if (sub.getId().equalsIgnoreCase(subscriptionId)) { 443 return sub; 444 } 445 } 446 } 447 return null; 448 } 449 450 public void addSubscription(String destination, String subscriptionId, @Nullable Expression selectorExpression) { 451 Set<Subscription> subs = this.destinationLookup.get(destination); 452 if (subs == null) { 453 synchronized (this.destinationLookup) { 454 subs = this.destinationLookup.get(destination); 455 if (subs == null) { 456 subs = new CopyOnWriteArraySet<>(); 457 this.destinationLookup.put(destination, subs); 458 } 459 } 460 } 461 subs.add(new Subscription(subscriptionId, selectorExpression)); 462 } 463 464 @Nullable 465 public String removeSubscription(String subscriptionId) { 466 for (Map.Entry<String, Set<DefaultSubscriptionRegistry.Subscription>> destinationEntry : 467 this.destinationLookup.entrySet()) { 468 Set<Subscription> subs = destinationEntry.getValue(); 469 if (subs != null) { 470 for (Subscription sub : subs) { 471 if (sub.getId().equals(subscriptionId) && subs.remove(sub)) { 472 synchronized (this.destinationLookup) { 473 if (subs.isEmpty()) { 474 this.destinationLookup.remove(destinationEntry.getKey()); 475 } 476 } 477 return destinationEntry.getKey(); 478 } 479 } 480 } 481 } 482 return null; 483 } 484 485 @Override 486 public String toString() { 487 return "[sessionId=" + this.sessionId + ", subscriptions=" + this.destinationLookup + "]"; 488 } 489 } 490 491 492 private static final class Subscription { 493 494 private final String id; 495 496 @Nullable 497 private final Expression selectorExpression; 498 499 public Subscription(String id, @Nullable Expression selector) { 500 Assert.notNull(id, "Subscription id must not be null"); 501 this.id = id; 502 this.selectorExpression = selector; 503 } 504 505 public String getId() { 506 return this.id; 507 } 508 509 @Nullable 510 public Expression getSelectorExpression() { 511 return this.selectorExpression; 512 } 513 514 @Override 515 public boolean equals(@Nullable Object other) { 516 return (this == other || (other instanceof Subscription && this.id.equals(((Subscription) other).id))); 517 } 518 519 @Override 520 public int hashCode() { 521 return this.id.hashCode(); 522 } 523 524 @Override 525 public String toString() { 526 return "subscription(id=" + this.id + ")"; 527 } 528 } 529 530 531 private static class SimpMessageHeaderPropertyAccessor implements PropertyAccessor { 532 533 @Override 534 public Class<?>[] getSpecificTargetClasses() { 535 return new Class<?>[] {Message.class, MessageHeaders.class}; 536 } 537 538 @Override 539 public boolean canRead(EvaluationContext context, @Nullable Object target, String name) { 540 return true; 541 } 542 543 @Override 544 @SuppressWarnings("rawtypes") 545 public TypedValue read(EvaluationContext context, @Nullable Object target, String name) { 546 Object value; 547 if (target instanceof Message) { 548 value = name.equals("headers") ? ((Message) target).getHeaders() : null; 549 } 550 else if (target instanceof MessageHeaders) { 551 MessageHeaders headers = (MessageHeaders) target; 552 SimpMessageHeaderAccessor accessor = 553 MessageHeaderAccessor.getAccessor(headers, SimpMessageHeaderAccessor.class); 554 Assert.state(accessor != null, "No SimpMessageHeaderAccessor"); 555 if ("destination".equalsIgnoreCase(name)) { 556 value = accessor.getDestination(); 557 } 558 else { 559 value = accessor.getFirstNativeHeader(name); 560 if (value == null) { 561 value = headers.get(name); 562 } 563 } 564 } 565 else { 566 // Should never happen... 567 throw new IllegalStateException("Expected Message or MessageHeaders."); 568 } 569 return new TypedValue(value); 570 } 571 572 @Override 573 public boolean canWrite(EvaluationContext context, @Nullable Object target, String name) { 574 return false; 575 } 576 577 @Override 578 public void write(EvaluationContext context, @Nullable Object target, String name, @Nullable Object value) { 579 } 580 } 581 582}