001/*
002 * Copyright 2002-2019 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.messaging.simp.broker;
018
019import java.security.Principal;
020import java.util.Arrays;
021import java.util.Collection;
022import java.util.Map;
023import java.util.concurrent.ConcurrentHashMap;
024import java.util.concurrent.ScheduledFuture;
025
026import org.springframework.lang.Nullable;
027import org.springframework.messaging.Message;
028import org.springframework.messaging.MessageChannel;
029import org.springframework.messaging.MessageHeaders;
030import org.springframework.messaging.SubscribableChannel;
031import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
032import org.springframework.messaging.simp.SimpMessageType;
033import org.springframework.messaging.support.MessageBuilder;
034import org.springframework.messaging.support.MessageHeaderAccessor;
035import org.springframework.messaging.support.MessageHeaderInitializer;
036import org.springframework.scheduling.TaskScheduler;
037import org.springframework.util.Assert;
038import org.springframework.util.MultiValueMap;
039import org.springframework.util.PathMatcher;
040
041/**
042 * A "simple" message broker that recognizes the message types defined in
043 * {@link SimpMessageType}, keeps track of subscriptions with the help of a
044 * {@link SubscriptionRegistry} and sends messages to subscribers.
045 *
046 * @author Rossen Stoyanchev
047 * @author Juergen Hoeller
048 * @since 4.0
049 */
050public class SimpleBrokerMessageHandler extends AbstractBrokerMessageHandler {
051
052        private static final byte[] EMPTY_PAYLOAD = new byte[0];
053
054
055        @Nullable
056        private PathMatcher pathMatcher;
057
058        @Nullable
059        private Integer cacheLimit;
060
061        @Nullable
062        private String selectorHeaderName = "selector";
063
064        @Nullable
065        private TaskScheduler taskScheduler;
066
067        @Nullable
068        private long[] heartbeatValue;
069
070        @Nullable
071        private MessageHeaderInitializer headerInitializer;
072
073
074        private SubscriptionRegistry subscriptionRegistry;
075
076        private final Map<String, SessionInfo> sessions = new ConcurrentHashMap<>();
077
078        @Nullable
079        private ScheduledFuture<?> heartbeatFuture;
080
081
082        /**
083         * Create a SimpleBrokerMessageHandler instance with the given message channels
084         * and destination prefixes.
085         * @param clientInboundChannel the channel for receiving messages from clients (e.g. WebSocket clients)
086         * @param clientOutboundChannel the channel for sending messages to clients (e.g. WebSocket clients)
087         * @param brokerChannel the channel for the application to send messages to the broker
088         * @param destinationPrefixes prefixes to use to filter out messages
089         */
090        public SimpleBrokerMessageHandler(SubscribableChannel clientInboundChannel, MessageChannel clientOutboundChannel,
091                        SubscribableChannel brokerChannel, Collection<String> destinationPrefixes) {
092
093                super(clientInboundChannel, clientOutboundChannel, brokerChannel, destinationPrefixes);
094                this.subscriptionRegistry = new DefaultSubscriptionRegistry();
095        }
096
097
098        /**
099         * Configure a custom SubscriptionRegistry to use for storing subscriptions.
100         * <p><strong>Note</strong> that when a custom PathMatcher is configured via
101         * {@link #setPathMatcher}, if the custom registry is not an instance of
102         * {@link DefaultSubscriptionRegistry}, the provided PathMatcher is not used
103         * and must be configured directly on the custom registry.
104         */
105        public void setSubscriptionRegistry(SubscriptionRegistry subscriptionRegistry) {
106                Assert.notNull(subscriptionRegistry, "SubscriptionRegistry must not be null");
107                this.subscriptionRegistry = subscriptionRegistry;
108                initPathMatcherToUse();
109                initCacheLimitToUse();
110                initSelectorHeaderNameToUse();
111        }
112
113        public SubscriptionRegistry getSubscriptionRegistry() {
114                return this.subscriptionRegistry;
115        }
116
117        /**
118         * When configured, the given PathMatcher is passed down to the underlying
119         * SubscriptionRegistry to use for matching destination to subscriptions.
120         * <p>Default is a standard {@link org.springframework.util.AntPathMatcher}.
121         * @since 4.1
122         * @see #setSubscriptionRegistry
123         * @see DefaultSubscriptionRegistry#setPathMatcher
124         * @see org.springframework.util.AntPathMatcher
125         */
126        public void setPathMatcher(@Nullable PathMatcher pathMatcher) {
127                this.pathMatcher = pathMatcher;
128                initPathMatcherToUse();
129        }
130
131        private void initPathMatcherToUse() {
132                if (this.pathMatcher != null && this.subscriptionRegistry instanceof DefaultSubscriptionRegistry) {
133                        ((DefaultSubscriptionRegistry) this.subscriptionRegistry).setPathMatcher(this.pathMatcher);
134                }
135        }
136
137        /**
138         * When configured, the specified cache limit is passed down to the
139         * underlying SubscriptionRegistry, overriding any default there.
140         * <p>With a standard {@link DefaultSubscriptionRegistry}, the default
141         * cache limit is 1024.
142         * @since 4.3.2
143         * @see #setSubscriptionRegistry
144         * @see DefaultSubscriptionRegistry#setCacheLimit
145         * @see DefaultSubscriptionRegistry#DEFAULT_CACHE_LIMIT
146         */
147        public void setCacheLimit(@Nullable Integer cacheLimit) {
148                this.cacheLimit = cacheLimit;
149                initCacheLimitToUse();
150        }
151
152        private void initCacheLimitToUse() {
153                if (this.cacheLimit != null && this.subscriptionRegistry instanceof DefaultSubscriptionRegistry) {
154                        ((DefaultSubscriptionRegistry) this.subscriptionRegistry).setCacheLimit(this.cacheLimit);
155                }
156        }
157
158        /**
159         * Configure the name of a header that a subscription message can have for
160         * the purpose of filtering messages matched to the subscription. The header
161         * value is expected to be a Spring EL boolean expression to be applied to
162         * the headers of messages matched to the subscription.
163         * <p>For example:
164         * <pre>
165         * headers.foo == 'bar'
166         * </pre>
167         * <p>By default this is set to "selector". You can set it to a different
168         * name, or to {@code null} to turn off support for a selector header.
169         * @param selectorHeaderName the name to use for a selector header
170         * @since 4.3.17
171         * @see #setSubscriptionRegistry
172         * @see DefaultSubscriptionRegistry#setSelectorHeaderName(String)
173         */
174        public void setSelectorHeaderName(@Nullable String selectorHeaderName) {
175                this.selectorHeaderName = selectorHeaderName;
176                initSelectorHeaderNameToUse();
177        }
178
179        private void initSelectorHeaderNameToUse() {
180                if (this.subscriptionRegistry instanceof DefaultSubscriptionRegistry) {
181                        ((DefaultSubscriptionRegistry) this.subscriptionRegistry).setSelectorHeaderName(this.selectorHeaderName);
182                }
183        }
184
185        /**
186         * Configure the {@link org.springframework.scheduling.TaskScheduler} to
187         * use for providing heartbeat support. Setting this property also sets the
188         * {@link #setHeartbeatValue heartbeatValue} to "10000, 10000".
189         * <p>By default this is not set.
190         * @since 4.2
191         */
192        public void setTaskScheduler(@Nullable TaskScheduler taskScheduler) {
193                this.taskScheduler = taskScheduler;
194                if (taskScheduler != null && this.heartbeatValue == null) {
195                        this.heartbeatValue = new long[] {10000, 10000};
196                }
197        }
198
199        /**
200         * Return the configured TaskScheduler.
201         * @since 4.2
202         */
203        @Nullable
204        public TaskScheduler getTaskScheduler() {
205                return this.taskScheduler;
206        }
207
208        /**
209         * Configure the value for the heart-beat settings. The first number
210         * represents how often the server will write or send a heartbeat.
211         * The second is how often the client should write. 0 means no heartbeats.
212         * <p>By default this is set to "0, 0" unless the {@link #setTaskScheduler
213         * taskScheduler} in which case the default becomes "10000,10000"
214         * (in milliseconds).
215         * @since 4.2
216         */
217        public void setHeartbeatValue(@Nullable long[] heartbeat) {
218                if (heartbeat != null && (heartbeat.length != 2 || heartbeat[0] < 0 || heartbeat[1] < 0)) {
219                        throw new IllegalArgumentException("Invalid heart-beat: " + Arrays.toString(heartbeat));
220                }
221                this.heartbeatValue = heartbeat;
222        }
223
224        /**
225         * The configured value for the heart-beat settings.
226         * @since 4.2
227         */
228        @Nullable
229        public long[] getHeartbeatValue() {
230                return this.heartbeatValue;
231        }
232
233        /**
234         * Configure a {@link MessageHeaderInitializer} to apply to the headers
235         * of all messages sent to the client outbound channel.
236         * <p>By default this property is not set.
237         * @since 4.1
238         */
239        public void setHeaderInitializer(@Nullable MessageHeaderInitializer headerInitializer) {
240                this.headerInitializer = headerInitializer;
241        }
242
243        /**
244         * Return the configured header initializer.
245         * @since 4.1
246         */
247        @Nullable
248        public MessageHeaderInitializer getHeaderInitializer() {
249                return this.headerInitializer;
250        }
251
252
253        @Override
254        public void startInternal() {
255                publishBrokerAvailableEvent();
256                if (this.taskScheduler != null) {
257                        long interval = initHeartbeatTaskDelay();
258                        if (interval > 0) {
259                                this.heartbeatFuture = this.taskScheduler.scheduleWithFixedDelay(new HeartbeatTask(), interval);
260                        }
261                }
262                else {
263                        Assert.isTrue(getHeartbeatValue() == null ||
264                                        (getHeartbeatValue()[0] == 0 && getHeartbeatValue()[1] == 0),
265                                        "Heartbeat values configured but no TaskScheduler provided");
266                }
267        }
268
269        private long initHeartbeatTaskDelay() {
270                if (getHeartbeatValue() == null) {
271                        return 0;
272                }
273                else if (getHeartbeatValue()[0] > 0 && getHeartbeatValue()[1] > 0) {
274                        return Math.min(getHeartbeatValue()[0], getHeartbeatValue()[1]);
275                }
276                else {
277                        return (getHeartbeatValue()[0] > 0 ? getHeartbeatValue()[0] : getHeartbeatValue()[1]);
278                }
279        }
280
281        @Override
282        public void stopInternal() {
283                publishBrokerUnavailableEvent();
284                if (this.heartbeatFuture != null) {
285                        this.heartbeatFuture.cancel(true);
286                }
287        }
288
289        @Override
290        protected void handleMessageInternal(Message<?> message) {
291                MessageHeaders headers = message.getHeaders();
292                String destination = SimpMessageHeaderAccessor.getDestination(headers);
293                String sessionId = SimpMessageHeaderAccessor.getSessionId(headers);
294
295                updateSessionReadTime(sessionId);
296
297                if (!checkDestinationPrefix(destination)) {
298                        return;
299                }
300
301                SimpMessageType messageType = SimpMessageHeaderAccessor.getMessageType(headers);
302                if (SimpMessageType.MESSAGE.equals(messageType)) {
303                        logMessage(message);
304                        sendMessageToSubscribers(destination, message);
305                }
306                else if (SimpMessageType.CONNECT.equals(messageType)) {
307                        logMessage(message);
308                        if (sessionId != null) {
309                                long[] heartbeatIn = SimpMessageHeaderAccessor.getHeartbeat(headers);
310                                long[] heartbeatOut = getHeartbeatValue();
311                                Principal user = SimpMessageHeaderAccessor.getUser(headers);
312                                MessageChannel outChannel = getClientOutboundChannelForSession(sessionId);
313                                this.sessions.put(sessionId, new SessionInfo(sessionId, user, outChannel, heartbeatIn, heartbeatOut));
314                                SimpMessageHeaderAccessor connectAck = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT_ACK);
315                                initHeaders(connectAck);
316                                connectAck.setSessionId(sessionId);
317                                if (user != null) {
318                                        connectAck.setUser(user);
319                                }
320                                connectAck.setHeader(SimpMessageHeaderAccessor.CONNECT_MESSAGE_HEADER, message);
321                                connectAck.setHeader(SimpMessageHeaderAccessor.HEART_BEAT_HEADER, heartbeatOut);
322                                Message<byte[]> messageOut = MessageBuilder.createMessage(EMPTY_PAYLOAD, connectAck.getMessageHeaders());
323                                getClientOutboundChannel().send(messageOut);
324                        }
325                }
326                else if (SimpMessageType.DISCONNECT.equals(messageType)) {
327                        logMessage(message);
328                        if (sessionId != null) {
329                                Principal user = SimpMessageHeaderAccessor.getUser(headers);
330                                handleDisconnect(sessionId, user, message);
331                        }
332                }
333                else if (SimpMessageType.SUBSCRIBE.equals(messageType)) {
334                        logMessage(message);
335                        this.subscriptionRegistry.registerSubscription(message);
336                }
337                else if (SimpMessageType.UNSUBSCRIBE.equals(messageType)) {
338                        logMessage(message);
339                        this.subscriptionRegistry.unregisterSubscription(message);
340                }
341        }
342
343        private void updateSessionReadTime(@Nullable String sessionId) {
344                if (sessionId != null) {
345                        SessionInfo info = this.sessions.get(sessionId);
346                        if (info != null) {
347                                info.setLastReadTime(System.currentTimeMillis());
348                        }
349                }
350        }
351
352        private void logMessage(Message<?> message) {
353                if (logger.isDebugEnabled()) {
354                        SimpMessageHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, SimpMessageHeaderAccessor.class);
355                        accessor = (accessor != null ? accessor : SimpMessageHeaderAccessor.wrap(message));
356                        logger.debug("Processing " + accessor.getShortLogMessage(message.getPayload()));
357                }
358        }
359
360        private void initHeaders(SimpMessageHeaderAccessor accessor) {
361                if (getHeaderInitializer() != null) {
362                        getHeaderInitializer().initHeaders(accessor);
363                }
364        }
365
366        private void handleDisconnect(String sessionId, @Nullable Principal user, @Nullable Message<?> origMessage) {
367                this.sessions.remove(sessionId);
368                this.subscriptionRegistry.unregisterAllSubscriptions(sessionId);
369                SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create(SimpMessageType.DISCONNECT_ACK);
370                accessor.setSessionId(sessionId);
371                if (user != null) {
372                        accessor.setUser(user);
373                }
374                if (origMessage != null) {
375                        accessor.setHeader(SimpMessageHeaderAccessor.DISCONNECT_MESSAGE_HEADER, origMessage);
376                }
377                initHeaders(accessor);
378                Message<byte[]> message = MessageBuilder.createMessage(EMPTY_PAYLOAD, accessor.getMessageHeaders());
379                getClientOutboundChannel().send(message);
380        }
381
382        protected void sendMessageToSubscribers(@Nullable String destination, Message<?> message) {
383                MultiValueMap<String,String> subscriptions = this.subscriptionRegistry.findSubscriptions(message);
384                if (!subscriptions.isEmpty() && logger.isDebugEnabled()) {
385                        logger.debug("Broadcasting to " + subscriptions.size() + " sessions.");
386                }
387                long now = System.currentTimeMillis();
388                subscriptions.forEach((sessionId, subscriptionIds) -> {
389                        for (String subscriptionId : subscriptionIds) {
390                                SimpMessageHeaderAccessor headerAccessor = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE);
391                                initHeaders(headerAccessor);
392                                headerAccessor.setSessionId(sessionId);
393                                headerAccessor.setSubscriptionId(subscriptionId);
394                                headerAccessor.copyHeadersIfAbsent(message.getHeaders());
395                                headerAccessor.setLeaveMutable(true);
396                                Object payload = message.getPayload();
397                                Message<?> reply = MessageBuilder.createMessage(payload, headerAccessor.getMessageHeaders());
398                                SessionInfo info = this.sessions.get(sessionId);
399                                if (info != null) {
400                                        try {
401                                                info.getClientOutboundChannel().send(reply);
402                                        }
403                                        catch (Throwable ex) {
404                                                if (logger.isErrorEnabled()) {
405                                                        logger.error("Failed to send " + message, ex);
406                                                }
407                                        }
408                                        finally {
409                                                info.setLastWriteTime(now);
410                                        }
411                                }
412                        }
413                });
414        }
415
416        @Override
417        public String toString() {
418                return "SimpleBrokerMessageHandler [" + this.subscriptionRegistry + "]";
419        }
420
421
422        private static class SessionInfo {
423
424                /* STOMP spec: receiver SHOULD take into account an error margin */
425                private static final long HEARTBEAT_MULTIPLIER = 3;
426
427                private final String sessionId;
428
429                @Nullable
430                private final Principal user;
431
432                private final MessageChannel clientOutboundChannel;
433
434                private final long readInterval;
435
436                private final long writeInterval;
437
438                private volatile long lastReadTime;
439
440                private volatile long lastWriteTime;
441
442
443                public SessionInfo(String sessionId, @Nullable Principal user, MessageChannel outboundChannel,
444                                @Nullable long[] clientHeartbeat, @Nullable long[] serverHeartbeat) {
445
446                        this.sessionId = sessionId;
447                        this.user = user;
448                        this.clientOutboundChannel = outboundChannel;
449                        if (clientHeartbeat != null && serverHeartbeat != null) {
450                                this.readInterval = (clientHeartbeat[0] > 0 && serverHeartbeat[1] > 0 ?
451                                                Math.max(clientHeartbeat[0], serverHeartbeat[1]) * HEARTBEAT_MULTIPLIER : 0);
452                                this.writeInterval = (clientHeartbeat[1] > 0 && serverHeartbeat[0] > 0 ?
453                                                Math.max(clientHeartbeat[1], serverHeartbeat[0]) : 0);
454                        }
455                        else {
456                                this.readInterval = 0;
457                                this.writeInterval = 0;
458                        }
459                        this.lastReadTime = this.lastWriteTime = System.currentTimeMillis();
460                }
461
462                public String getSessionId() {
463                        return this.sessionId;
464                }
465
466                @Nullable
467                public Principal getUser() {
468                        return this.user;
469                }
470
471                public MessageChannel getClientOutboundChannel() {
472                        return this.clientOutboundChannel;
473                }
474
475                public long getReadInterval() {
476                        return this.readInterval;
477                }
478
479                public long getWriteInterval() {
480                        return this.writeInterval;
481                }
482
483                public long getLastReadTime() {
484                        return this.lastReadTime;
485                }
486
487                public void setLastReadTime(long lastReadTime) {
488                        this.lastReadTime = lastReadTime;
489                }
490
491                public long getLastWriteTime() {
492                        return this.lastWriteTime;
493                }
494
495                public void setLastWriteTime(long lastWriteTime) {
496                        this.lastWriteTime = lastWriteTime;
497                }
498        }
499
500
501        private class HeartbeatTask implements Runnable {
502
503                @Override
504                public void run() {
505                        long now = System.currentTimeMillis();
506                        for (SessionInfo info : sessions.values()) {
507                                if (info.getReadInterval() > 0 && (now - info.getLastReadTime()) > info.getReadInterval()) {
508                                        handleDisconnect(info.getSessionId(), info.getUser(), null);
509                                }
510                                if (info.getWriteInterval() > 0 && (now - info.getLastWriteTime()) > info.getWriteInterval()) {
511                                        SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create(SimpMessageType.HEARTBEAT);
512                                        accessor.setSessionId(info.getSessionId());
513                                        Principal user = info.getUser();
514                                        if (user != null) {
515                                                accessor.setUser(user);
516                                        }
517                                        initHeaders(accessor);
518                                        accessor.setLeaveMutable(true);
519                                        MessageHeaders headers = accessor.getMessageHeaders();
520                                        info.getClientOutboundChannel().send(MessageBuilder.createMessage(EMPTY_PAYLOAD, headers));
521                                }
522                        }
523                }
524        }
525
526}