001/*
002 * Copyright 2002-2019 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.messaging.simp.broker;
018
019import java.util.Collection;
020import java.util.HashSet;
021import java.util.LinkedHashMap;
022import java.util.List;
023import java.util.Map;
024import java.util.Set;
025import java.util.concurrent.ConcurrentHashMap;
026import java.util.concurrent.ConcurrentMap;
027import java.util.concurrent.CopyOnWriteArraySet;
028
029import org.springframework.expression.EvaluationContext;
030import org.springframework.expression.Expression;
031import org.springframework.expression.ExpressionParser;
032import org.springframework.expression.PropertyAccessor;
033import org.springframework.expression.TypedValue;
034import org.springframework.expression.spel.SpelEvaluationException;
035import org.springframework.expression.spel.standard.SpelExpressionParser;
036import org.springframework.expression.spel.support.SimpleEvaluationContext;
037import org.springframework.lang.Nullable;
038import org.springframework.messaging.Message;
039import org.springframework.messaging.MessageHeaders;
040import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
041import org.springframework.messaging.support.MessageHeaderAccessor;
042import org.springframework.util.AntPathMatcher;
043import org.springframework.util.Assert;
044import org.springframework.util.LinkedMultiValueMap;
045import org.springframework.util.MultiValueMap;
046import org.springframework.util.PathMatcher;
047import org.springframework.util.StringUtils;
048
049/**
050 * Implementation of {@link SubscriptionRegistry} that stores subscriptions
051 * in memory and uses a {@link org.springframework.util.PathMatcher PathMatcher}
052 * for matching destinations.
053 *
054 * <p>As of 4.2, this class supports a {@link #setSelectorHeaderName selector}
055 * header on subscription messages with Spring EL expressions evaluated against
056 * the headers to filter out messages in addition to destination matching.
057 *
058 * @author Rossen Stoyanchev
059 * @author Sebastien Deleuze
060 * @author Juergen Hoeller
061 * @since 4.0
062 */
063public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry {
064
065        /** Default maximum number of entries for the destination cache: 1024. */
066        public static final int DEFAULT_CACHE_LIMIT = 1024;
067
068        /** Static evaluation context to reuse. */
069        private static final EvaluationContext messageEvalContext =
070                        SimpleEvaluationContext.forPropertyAccessors(new SimpMessageHeaderPropertyAccessor()).build();
071
072
073        private PathMatcher pathMatcher = new AntPathMatcher();
074
075        private volatile int cacheLimit = DEFAULT_CACHE_LIMIT;
076
077        @Nullable
078        private String selectorHeaderName = "selector";
079
080        private volatile boolean selectorHeaderInUse = false;
081
082        private final ExpressionParser expressionParser = new SpelExpressionParser();
083
084        private final DestinationCache destinationCache = new DestinationCache();
085
086        private final SessionSubscriptionRegistry subscriptionRegistry = new SessionSubscriptionRegistry();
087
088
089        /**
090         * Specify the {@link PathMatcher} to use.
091         */
092        public void setPathMatcher(PathMatcher pathMatcher) {
093                this.pathMatcher = pathMatcher;
094        }
095
096        /**
097         * Return the configured {@link PathMatcher}.
098         */
099        public PathMatcher getPathMatcher() {
100                return this.pathMatcher;
101        }
102
103        /**
104         * Specify the maximum number of entries for the resolved destination cache.
105         * Default is 1024.
106         */
107        public void setCacheLimit(int cacheLimit) {
108                this.cacheLimit = cacheLimit;
109        }
110
111        /**
112         * Return the maximum number of entries for the resolved destination cache.
113         */
114        public int getCacheLimit() {
115                return this.cacheLimit;
116        }
117
118        /**
119         * Configure the name of a header that a subscription message can have for
120         * the purpose of filtering messages matched to the subscription. The header
121         * value is expected to be a Spring EL boolean expression to be applied to
122         * the headers of messages matched to the subscription.
123         * <p>For example:
124         * <pre>
125         * headers.foo == 'bar'
126         * </pre>
127         * <p>By default this is set to "selector". You can set it to a different
128         * name, or to {@code null} to turn off support for a selector header.
129         * @param selectorHeaderName the name to use for a selector header
130         * @since 4.2
131         */
132        public void setSelectorHeaderName(@Nullable String selectorHeaderName) {
133                this.selectorHeaderName = (StringUtils.hasText(selectorHeaderName) ? selectorHeaderName : null);
134        }
135
136        /**
137         * Return the name for the selector header name.
138         * @since 4.2
139         */
140        @Nullable
141        public String getSelectorHeaderName() {
142                return this.selectorHeaderName;
143        }
144
145
146        @Override
147        protected void addSubscriptionInternal(
148                        String sessionId, String subsId, String destination, Message<?> message) {
149
150                Expression expression = getSelectorExpression(message.getHeaders());
151                this.subscriptionRegistry.addSubscription(sessionId, subsId, destination, expression);
152                this.destinationCache.updateAfterNewSubscription(destination, sessionId, subsId);
153        }
154
155        @Nullable
156        private Expression getSelectorExpression(MessageHeaders headers) {
157                Expression expression = null;
158                if (getSelectorHeaderName() != null) {
159                        String selector = SimpMessageHeaderAccessor.getFirstNativeHeader(getSelectorHeaderName(), headers);
160                        if (selector != null) {
161                                try {
162                                        expression = this.expressionParser.parseExpression(selector);
163                                        this.selectorHeaderInUse = true;
164                                        if (logger.isTraceEnabled()) {
165                                                logger.trace("Subscription selector: [" + selector + "]");
166                                        }
167                                }
168                                catch (Throwable ex) {
169                                        if (logger.isDebugEnabled()) {
170                                                logger.debug("Failed to parse selector: " + selector, ex);
171                                        }
172                                }
173                        }
174                }
175                return expression;
176        }
177
178        @Override
179        protected void removeSubscriptionInternal(String sessionId, String subsId, Message<?> message) {
180                SessionSubscriptionInfo info = this.subscriptionRegistry.getSubscriptions(sessionId);
181                if (info != null) {
182                        String destination = info.removeSubscription(subsId);
183                        if (destination != null) {
184                                this.destinationCache.updateAfterRemovedSubscription(sessionId, subsId);
185                        }
186                }
187        }
188
189        @Override
190        public void unregisterAllSubscriptions(String sessionId) {
191                SessionSubscriptionInfo info = this.subscriptionRegistry.removeSubscriptions(sessionId);
192                if (info != null) {
193                        this.destinationCache.updateAfterRemovedSession(info);
194                }
195        }
196
197        @Override
198        protected MultiValueMap<String, String> findSubscriptionsInternal(String destination, Message<?> message) {
199                MultiValueMap<String, String> result = this.destinationCache.getSubscriptions(destination, message);
200                return filterSubscriptions(result, message);
201        }
202
203        private MultiValueMap<String, String> filterSubscriptions(
204                        MultiValueMap<String, String> allMatches, Message<?> message) {
205
206                if (!this.selectorHeaderInUse) {
207                        return allMatches;
208                }
209                MultiValueMap<String, String> result = new LinkedMultiValueMap<>(allMatches.size());
210                allMatches.forEach((sessionId, subIds) -> {
211                        for (String subId : subIds) {
212                                SessionSubscriptionInfo info = this.subscriptionRegistry.getSubscriptions(sessionId);
213                                if (info == null) {
214                                        continue;
215                                }
216                                Subscription sub = info.getSubscription(subId);
217                                if (sub == null) {
218                                        continue;
219                                }
220                                Expression expression = sub.getSelectorExpression();
221                                if (expression == null) {
222                                        result.add(sessionId, subId);
223                                        continue;
224                                }
225                                try {
226                                        if (Boolean.TRUE.equals(expression.getValue(messageEvalContext, message, Boolean.class))) {
227                                                result.add(sessionId, subId);
228                                        }
229                                }
230                                catch (SpelEvaluationException ex) {
231                                        if (logger.isDebugEnabled()) {
232                                                logger.debug("Failed to evaluate selector: " + ex.getMessage());
233                                        }
234                                }
235                                catch (Throwable ex) {
236                                        logger.debug("Failed to evaluate selector", ex);
237                                }
238                        }
239                });
240                return result;
241        }
242
243        @Override
244        public String toString() {
245                return "DefaultSubscriptionRegistry[" + this.destinationCache + ", " + this.subscriptionRegistry + "]";
246        }
247
248
249        /**
250         * A cache for destinations previously resolved via
251         * {@link DefaultSubscriptionRegistry#findSubscriptionsInternal(String, Message)}.
252         */
253        private class DestinationCache {
254
255                /** Map from destination to {@code <sessionId, subscriptionId>} for fast look-ups. */
256                private final Map<String, LinkedMultiValueMap<String, String>> accessCache =
257                                new ConcurrentHashMap<>(DEFAULT_CACHE_LIMIT);
258
259                /** Map from destination to {@code <sessionId, subscriptionId>} with locking. */
260                @SuppressWarnings("serial")
261                private final Map<String, LinkedMultiValueMap<String, String>> updateCache =
262                                new LinkedHashMap<String, LinkedMultiValueMap<String, String>>(DEFAULT_CACHE_LIMIT, 0.75f, true) {
263                                        @Override
264                                        protected boolean removeEldestEntry(Map.Entry<String, LinkedMultiValueMap<String, String>> eldest) {
265                                                if (size() > getCacheLimit()) {
266                                                        accessCache.remove(eldest.getKey());
267                                                        return true;
268                                                }
269                                                else {
270                                                        return false;
271                                                }
272                                        }
273                                };
274
275
276                public LinkedMultiValueMap<String, String> getSubscriptions(String destination, Message<?> message) {
277                        LinkedMultiValueMap<String, String> result = this.accessCache.get(destination);
278                        if (result == null) {
279                                synchronized (this.updateCache) {
280                                        result = new LinkedMultiValueMap<>();
281                                        for (SessionSubscriptionInfo info : subscriptionRegistry.getAllSubscriptions()) {
282                                                for (String destinationPattern : info.getDestinations()) {
283                                                        if (getPathMatcher().match(destinationPattern, destination)) {
284                                                                for (Subscription sub : info.getSubscriptions(destinationPattern)) {
285                                                                        result.add(info.sessionId, sub.getId());
286                                                                }
287                                                        }
288                                                }
289                                        }
290                                        if (!result.isEmpty()) {
291                                                this.updateCache.put(destination, result.deepCopy());
292                                                this.accessCache.put(destination, result);
293                                        }
294                                }
295                        }
296                        return result;
297                }
298
299                public void updateAfterNewSubscription(String destination, String sessionId, String subsId) {
300                        synchronized (this.updateCache) {
301                                this.updateCache.forEach((cachedDestination, subscriptions) -> {
302                                        if (getPathMatcher().match(destination, cachedDestination)) {
303                                                // Subscription id's may also be populated via getSubscriptions()
304                                                List<String> subsForSession = subscriptions.get(sessionId);
305                                                if (subsForSession == null || !subsForSession.contains(subsId)) {
306                                                        subscriptions.add(sessionId, subsId);
307                                                        this.accessCache.put(cachedDestination, subscriptions.deepCopy());
308                                                }
309                                        }
310                                });
311                        }
312                }
313
314                public void updateAfterRemovedSubscription(String sessionId, String subsId) {
315                        synchronized (this.updateCache) {
316                                Set<String> destinationsToRemove = new HashSet<>();
317                                this.updateCache.forEach((destination, sessionMap) -> {
318                                        List<String> subscriptions = sessionMap.get(sessionId);
319                                        if (subscriptions != null) {
320                                                subscriptions.remove(subsId);
321                                                if (subscriptions.isEmpty()) {
322                                                        sessionMap.remove(sessionId);
323                                                }
324                                                if (sessionMap.isEmpty()) {
325                                                        destinationsToRemove.add(destination);
326                                                }
327                                                else {
328                                                        this.accessCache.put(destination, sessionMap.deepCopy());
329                                                }
330                                        }
331                                });
332                                for (String destination : destinationsToRemove) {
333                                        this.updateCache.remove(destination);
334                                        this.accessCache.remove(destination);
335                                }
336                        }
337                }
338
339                public void updateAfterRemovedSession(SessionSubscriptionInfo info) {
340                        synchronized (this.updateCache) {
341                                Set<String> destinationsToRemove = new HashSet<>();
342                                this.updateCache.forEach((destination, sessionMap) -> {
343                                        if (sessionMap.remove(info.getSessionId()) != null) {
344                                                if (sessionMap.isEmpty()) {
345                                                        destinationsToRemove.add(destination);
346                                                }
347                                                else {
348                                                        this.accessCache.put(destination, sessionMap.deepCopy());
349                                                }
350                                        }
351                                });
352                                for (String destination : destinationsToRemove) {
353                                        this.updateCache.remove(destination);
354                                        this.accessCache.remove(destination);
355                                }
356                        }
357                }
358
359                @Override
360                public String toString() {
361                        return "cache[" + this.accessCache.size() + " destination(s)]";
362                }
363        }
364
365
366        /**
367         * Provide access to session subscriptions by sessionId.
368         */
369        private static class SessionSubscriptionRegistry {
370
371                // sessionId -> SessionSubscriptionInfo
372                private final ConcurrentMap<String, SessionSubscriptionInfo> sessions = new ConcurrentHashMap<>();
373
374                @Nullable
375                public SessionSubscriptionInfo getSubscriptions(String sessionId) {
376                        return this.sessions.get(sessionId);
377                }
378
379                public Collection<SessionSubscriptionInfo> getAllSubscriptions() {
380                        return this.sessions.values();
381                }
382
383                public SessionSubscriptionInfo addSubscription(String sessionId, String subscriptionId,
384                                String destination, @Nullable Expression selectorExpression) {
385
386                        SessionSubscriptionInfo info = this.sessions.get(sessionId);
387                        if (info == null) {
388                                info = new SessionSubscriptionInfo(sessionId);
389                                SessionSubscriptionInfo value = this.sessions.putIfAbsent(sessionId, info);
390                                if (value != null) {
391                                        info = value;
392                                }
393                        }
394                        info.addSubscription(destination, subscriptionId, selectorExpression);
395                        return info;
396                }
397
398                @Nullable
399                public SessionSubscriptionInfo removeSubscriptions(String sessionId) {
400                        return this.sessions.remove(sessionId);
401                }
402
403                @Override
404                public String toString() {
405                        return "registry[" + this.sessions.size() + " sessions]";
406                }
407        }
408
409
410        /**
411         * Hold subscriptions for a session.
412         */
413        private static class SessionSubscriptionInfo {
414
415                private final String sessionId;
416
417                // destination -> subscriptions
418                private final Map<String, Set<Subscription>> destinationLookup = new ConcurrentHashMap<>(4);
419
420                public SessionSubscriptionInfo(String sessionId) {
421                        Assert.notNull(sessionId, "'sessionId' must not be null");
422                        this.sessionId = sessionId;
423                }
424
425                public String getSessionId() {
426                        return this.sessionId;
427                }
428
429                public Set<String> getDestinations() {
430                        return this.destinationLookup.keySet();
431                }
432
433                public Set<Subscription> getSubscriptions(String destination) {
434                        return this.destinationLookup.get(destination);
435                }
436
437                @Nullable
438                public Subscription getSubscription(String subscriptionId) {
439                        for (Map.Entry<String, Set<DefaultSubscriptionRegistry.Subscription>> destinationEntry :
440                                        this.destinationLookup.entrySet()) {
441                                for (Subscription sub : destinationEntry.getValue()) {
442                                        if (sub.getId().equalsIgnoreCase(subscriptionId)) {
443                                                return sub;
444                                        }
445                                }
446                        }
447                        return null;
448                }
449
450                public void addSubscription(String destination, String subscriptionId, @Nullable Expression selectorExpression) {
451                        Set<Subscription> subs = this.destinationLookup.get(destination);
452                        if (subs == null) {
453                                synchronized (this.destinationLookup) {
454                                        subs = this.destinationLookup.get(destination);
455                                        if (subs == null) {
456                                                subs = new CopyOnWriteArraySet<>();
457                                                this.destinationLookup.put(destination, subs);
458                                        }
459                                }
460                        }
461                        subs.add(new Subscription(subscriptionId, selectorExpression));
462                }
463
464                @Nullable
465                public String removeSubscription(String subscriptionId) {
466                        for (Map.Entry<String, Set<DefaultSubscriptionRegistry.Subscription>> destinationEntry :
467                                        this.destinationLookup.entrySet()) {
468                                Set<Subscription> subs = destinationEntry.getValue();
469                                if (subs != null) {
470                                        for (Subscription sub : subs) {
471                                                if (sub.getId().equals(subscriptionId) && subs.remove(sub)) {
472                                                        synchronized (this.destinationLookup) {
473                                                                if (subs.isEmpty()) {
474                                                                        this.destinationLookup.remove(destinationEntry.getKey());
475                                                                }
476                                                        }
477                                                        return destinationEntry.getKey();
478                                                }
479                                        }
480                                }
481                        }
482                        return null;
483                }
484
485                @Override
486                public String toString() {
487                        return "[sessionId=" + this.sessionId + ", subscriptions=" + this.destinationLookup + "]";
488                }
489        }
490
491
492        private static final class Subscription {
493
494                private final String id;
495
496                @Nullable
497                private final Expression selectorExpression;
498
499                public Subscription(String id, @Nullable Expression selector) {
500                        Assert.notNull(id, "Subscription id must not be null");
501                        this.id = id;
502                        this.selectorExpression = selector;
503                }
504
505                public String getId() {
506                        return this.id;
507                }
508
509                @Nullable
510                public Expression getSelectorExpression() {
511                        return this.selectorExpression;
512                }
513
514                @Override
515                public boolean equals(@Nullable Object other) {
516                        return (this == other || (other instanceof Subscription && this.id.equals(((Subscription) other).id)));
517                }
518
519                @Override
520                public int hashCode() {
521                        return this.id.hashCode();
522                }
523
524                @Override
525                public String toString() {
526                        return "subscription(id=" + this.id + ")";
527                }
528        }
529
530
531        private static class SimpMessageHeaderPropertyAccessor implements PropertyAccessor {
532
533                @Override
534                public Class<?>[] getSpecificTargetClasses() {
535                        return new Class<?>[] {Message.class, MessageHeaders.class};
536                }
537
538                @Override
539                public boolean canRead(EvaluationContext context, @Nullable Object target, String name) {
540                        return true;
541                }
542
543                @Override
544                @SuppressWarnings("rawtypes")
545                public TypedValue read(EvaluationContext context, @Nullable Object target, String name) {
546                        Object value;
547                        if (target instanceof Message) {
548                                value = name.equals("headers") ? ((Message) target).getHeaders() : null;
549                        }
550                        else if (target instanceof MessageHeaders) {
551                                MessageHeaders headers = (MessageHeaders) target;
552                                SimpMessageHeaderAccessor accessor =
553                                                MessageHeaderAccessor.getAccessor(headers, SimpMessageHeaderAccessor.class);
554                                Assert.state(accessor != null, "No SimpMessageHeaderAccessor");
555                                if ("destination".equalsIgnoreCase(name)) {
556                                        value = accessor.getDestination();
557                                }
558                                else {
559                                        value = accessor.getFirstNativeHeader(name);
560                                        if (value == null) {
561                                                value = headers.get(name);
562                                        }
563                                }
564                        }
565                        else {
566                                // Should never happen...
567                                throw new IllegalStateException("Expected Message or MessageHeaders.");
568                        }
569                        return new TypedValue(value);
570                }
571
572                @Override
573                public boolean canWrite(EvaluationContext context, @Nullable Object target, String name) {
574                        return false;
575                }
576
577                @Override
578                public void write(EvaluationContext context, @Nullable Object target, String name, @Nullable Object value) {
579                }
580        }
581
582}