001/*
002 * Copyright 2002-2018 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.messaging.simp.broker;
018
019import java.security.Principal;
020import java.util.Arrays;
021import java.util.Collection;
022import java.util.List;
023import java.util.Map;
024import java.util.concurrent.ConcurrentHashMap;
025import java.util.concurrent.ScheduledFuture;
026
027import org.springframework.messaging.Message;
028import org.springframework.messaging.MessageChannel;
029import org.springframework.messaging.MessageHeaders;
030import org.springframework.messaging.SubscribableChannel;
031import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
032import org.springframework.messaging.simp.SimpMessageType;
033import org.springframework.messaging.support.MessageBuilder;
034import org.springframework.messaging.support.MessageHeaderAccessor;
035import org.springframework.messaging.support.MessageHeaderInitializer;
036import org.springframework.scheduling.TaskScheduler;
037import org.springframework.util.Assert;
038import org.springframework.util.MultiValueMap;
039import org.springframework.util.PathMatcher;
040
041/**
042 * A "simple" message broker that recognizes the message types defined in
043 * {@link SimpMessageType}, keeps track of subscriptions with the help of a
044 * {@link SubscriptionRegistry} and sends messages to subscribers.
045 *
046 * @author Rossen Stoyanchev
047 * @author Juergen Hoeller
048 * @since 4.0
049 */
050public class SimpleBrokerMessageHandler extends AbstractBrokerMessageHandler {
051
052        private static final byte[] EMPTY_PAYLOAD = new byte[0];
053
054
055        private PathMatcher pathMatcher;
056
057        private Integer cacheLimit;
058
059        private String selectorHeaderName = "selector";
060
061        private TaskScheduler taskScheduler;
062
063        private long[] heartbeatValue;
064
065        private MessageHeaderInitializer headerInitializer;
066
067
068        private SubscriptionRegistry subscriptionRegistry;
069
070        private final Map<String, SessionInfo> sessions = new ConcurrentHashMap<String, SessionInfo>();
071
072        private ScheduledFuture<?> heartbeatFuture;
073
074
075        /**
076         * Create a SimpleBrokerMessageHandler instance with the given message channels
077         * and destination prefixes.
078         * @param clientInboundChannel the channel for receiving messages from clients (e.g. WebSocket clients)
079         * @param clientOutboundChannel the channel for sending messages to clients (e.g. WebSocket clients)
080         * @param brokerChannel the channel for the application to send messages to the broker
081         * @param destinationPrefixes prefixes to use to filter out messages
082         */
083        public SimpleBrokerMessageHandler(SubscribableChannel clientInboundChannel, MessageChannel clientOutboundChannel,
084                        SubscribableChannel brokerChannel, Collection<String> destinationPrefixes) {
085
086                super(clientInboundChannel, clientOutboundChannel, brokerChannel, destinationPrefixes);
087                this.subscriptionRegistry = new DefaultSubscriptionRegistry();
088        }
089
090
091        /**
092         * Configure a custom SubscriptionRegistry to use for storing subscriptions.
093         * <p><strong>Note</strong> that when a custom PathMatcher is configured via
094         * {@link #setPathMatcher}, if the custom registry is not an instance of
095         * {@link DefaultSubscriptionRegistry}, the provided PathMatcher is not used
096         * and must be configured directly on the custom registry.
097         */
098        public void setSubscriptionRegistry(SubscriptionRegistry subscriptionRegistry) {
099                Assert.notNull(subscriptionRegistry, "SubscriptionRegistry must not be null");
100                this.subscriptionRegistry = subscriptionRegistry;
101                initPathMatcherToUse();
102                initCacheLimitToUse();
103                initSelectorHeaderNameToUse();
104        }
105
106        public SubscriptionRegistry getSubscriptionRegistry() {
107                return this.subscriptionRegistry;
108        }
109
110        /**
111         * Configure the name of a header that a subscription message can have for
112         * the purpose of filtering messages matched to the subscription. The header
113         * value is expected to be a Spring EL boolean expression to be applied to
114         * the headers of messages matched to the subscription.
115         * <p>For example:
116         * <pre>
117         * headers.foo == 'bar'
118         * </pre>
119         * <p>By default this is set to "selector". You can set it to a different
120         * name, or to {@code null} to turn off support for a selector header.
121         * @param selectorHeaderName the name to use for a selector header
122         * @since 4.3.17
123         * @see #setSubscriptionRegistry
124         * @see DefaultSubscriptionRegistry#setSelectorHeaderName(String)
125         */
126        public void setSelectorHeaderName(String selectorHeaderName) {
127                this.selectorHeaderName = selectorHeaderName;
128                initSelectorHeaderNameToUse();
129        }
130
131        private void initSelectorHeaderNameToUse() {
132                if (this.subscriptionRegistry instanceof DefaultSubscriptionRegistry) {
133                        ((DefaultSubscriptionRegistry) this.subscriptionRegistry).setSelectorHeaderName(this.selectorHeaderName);
134                }
135        }
136
137        /**
138         * When configured, the given PathMatcher is passed down to the underlying
139         * SubscriptionRegistry to use for matching destination to subscriptions.
140         * <p>Default is a standard {@link org.springframework.util.AntPathMatcher}.
141         * @since 4.1
142         * @see #setSubscriptionRegistry
143         * @see DefaultSubscriptionRegistry#setPathMatcher
144         * @see org.springframework.util.AntPathMatcher
145         */
146        public void setPathMatcher(PathMatcher pathMatcher) {
147                this.pathMatcher = pathMatcher;
148                initPathMatcherToUse();
149        }
150
151        private void initPathMatcherToUse() {
152                if (this.pathMatcher != null && this.subscriptionRegistry instanceof DefaultSubscriptionRegistry) {
153                        ((DefaultSubscriptionRegistry) this.subscriptionRegistry).setPathMatcher(this.pathMatcher);
154                }
155        }
156
157        /**
158         * When configured, the specified cache limit is passed down to the
159         * underlying SubscriptionRegistry, overriding any default there.
160         * <p>With a standard {@link DefaultSubscriptionRegistry}, the default
161         * cache limit is 1024.
162         * @since 4.3.2
163         * @see #setSubscriptionRegistry
164         * @see DefaultSubscriptionRegistry#setCacheLimit
165         * @see DefaultSubscriptionRegistry#DEFAULT_CACHE_LIMIT
166         */
167        public void setCacheLimit(Integer cacheLimit) {
168                this.cacheLimit = cacheLimit;
169                initCacheLimitToUse();
170        }
171
172        private void initCacheLimitToUse() {
173                if (this.cacheLimit != null && this.subscriptionRegistry instanceof DefaultSubscriptionRegistry) {
174                        ((DefaultSubscriptionRegistry) this.subscriptionRegistry).setCacheLimit(this.cacheLimit);
175                }
176        }
177
178        /**
179         * Configure the {@link org.springframework.scheduling.TaskScheduler} to
180         * use for providing heartbeat support. Setting this property also sets the
181         * {@link #setHeartbeatValue heartbeatValue} to "10000, 10000".
182         * <p>By default this is not set.
183         * @since 4.2
184         */
185        public void setTaskScheduler(TaskScheduler taskScheduler) {
186                Assert.notNull(taskScheduler, "TaskScheduler must not be null");
187                this.taskScheduler = taskScheduler;
188                if (this.heartbeatValue == null) {
189                        this.heartbeatValue = new long[] {10000, 10000};
190                }
191        }
192
193        /**
194         * Return the configured TaskScheduler.
195         * @since 4.2
196         */
197        public TaskScheduler getTaskScheduler() {
198                return this.taskScheduler;
199        }
200
201        /**
202         * Configure the value for the heart-beat settings. The first number
203         * represents how often the server will write or send a heartbeat.
204         * The second is how often the client should write. 0 means no heartbeats.
205         * <p>By default this is set to "0, 0" unless the {@link #setTaskScheduler
206         * taskScheduler} in which case the default becomes "10000,10000"
207         * (in milliseconds).
208         * @since 4.2
209         */
210        public void setHeartbeatValue(long[] heartbeat) {
211                if (heartbeat == null || heartbeat.length != 2 || heartbeat[0] < 0 || heartbeat[1] < 0) {
212                        throw new IllegalArgumentException("Invalid heart-beat: " + Arrays.toString(heartbeat));
213                }
214                this.heartbeatValue = heartbeat;
215        }
216
217        /**
218         * The configured value for the heart-beat settings.
219         * @since 4.2
220         */
221        public long[] getHeartbeatValue() {
222                return this.heartbeatValue;
223        }
224
225        /**
226         * Configure a {@link MessageHeaderInitializer} to apply to the headers
227         * of all messages sent to the client outbound channel.
228         * <p>By default this property is not set.
229         * @since 4.1
230         */
231        public void setHeaderInitializer(MessageHeaderInitializer headerInitializer) {
232                this.headerInitializer = headerInitializer;
233        }
234
235        /**
236         * Return the configured header initializer.
237         * @since 4.1
238         */
239        public MessageHeaderInitializer getHeaderInitializer() {
240                return this.headerInitializer;
241        }
242
243
244        @Override
245        public void startInternal() {
246                publishBrokerAvailableEvent();
247                if (getTaskScheduler() != null) {
248                        long interval = initHeartbeatTaskDelay();
249                        if (interval > 0) {
250                                this.heartbeatFuture = this.taskScheduler.scheduleWithFixedDelay(new HeartbeatTask(), interval);
251                        }
252                }
253                else {
254                        Assert.isTrue(getHeartbeatValue() == null ||
255                                        (getHeartbeatValue()[0] == 0 && getHeartbeatValue()[1] == 0),
256                                        "Heartbeat values configured but no TaskScheduler provided");
257                }
258        }
259
260        private long initHeartbeatTaskDelay() {
261                if (getHeartbeatValue() == null) {
262                        return 0;
263                }
264                else if (getHeartbeatValue()[0] > 0 && getHeartbeatValue()[1] > 0) {
265                        return Math.min(getHeartbeatValue()[0], getHeartbeatValue()[1]);
266                }
267                else {
268                        return (getHeartbeatValue()[0] > 0 ? getHeartbeatValue()[0] : getHeartbeatValue()[1]);
269                }
270        }
271
272        @Override
273        public void stopInternal() {
274                publishBrokerUnavailableEvent();
275                if (this.heartbeatFuture != null) {
276                        this.heartbeatFuture.cancel(true);
277                }
278        }
279
280        @Override
281        protected void handleMessageInternal(Message<?> message) {
282                MessageHeaders headers = message.getHeaders();
283                SimpMessageType messageType = SimpMessageHeaderAccessor.getMessageType(headers);
284                String destination = SimpMessageHeaderAccessor.getDestination(headers);
285                String sessionId = SimpMessageHeaderAccessor.getSessionId(headers);
286
287                updateSessionReadTime(sessionId);
288
289                if (!checkDestinationPrefix(destination)) {
290                        return;
291                }
292
293                if (SimpMessageType.MESSAGE.equals(messageType)) {
294                        logMessage(message);
295                        sendMessageToSubscribers(destination, message);
296                }
297                else if (SimpMessageType.CONNECT.equals(messageType)) {
298                        logMessage(message);
299                        long[] clientHeartbeat = SimpMessageHeaderAccessor.getHeartbeat(headers);
300                        long[] serverHeartbeat = getHeartbeatValue();
301                        Principal user = SimpMessageHeaderAccessor.getUser(headers);
302                        this.sessions.put(sessionId, new SessionInfo(sessionId, user, clientHeartbeat, serverHeartbeat));
303                        SimpMessageHeaderAccessor connectAck = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT_ACK);
304                        initHeaders(connectAck);
305                        connectAck.setSessionId(sessionId);
306                        connectAck.setUser(SimpMessageHeaderAccessor.getUser(headers));
307                        connectAck.setHeader(SimpMessageHeaderAccessor.CONNECT_MESSAGE_HEADER, message);
308                        connectAck.setHeader(SimpMessageHeaderAccessor.HEART_BEAT_HEADER, serverHeartbeat);
309                        Message<byte[]> messageOut = MessageBuilder.createMessage(EMPTY_PAYLOAD, connectAck.getMessageHeaders());
310                        getClientOutboundChannel().send(messageOut);
311                }
312                else if (SimpMessageType.DISCONNECT.equals(messageType)) {
313                        logMessage(message);
314                        Principal user = SimpMessageHeaderAccessor.getUser(headers);
315                        handleDisconnect(sessionId, user, message);
316                }
317                else if (SimpMessageType.SUBSCRIBE.equals(messageType)) {
318                        logMessage(message);
319                        this.subscriptionRegistry.registerSubscription(message);
320                }
321                else if (SimpMessageType.UNSUBSCRIBE.equals(messageType)) {
322                        logMessage(message);
323                        this.subscriptionRegistry.unregisterSubscription(message);
324                }
325        }
326
327        private void updateSessionReadTime(String sessionId) {
328                if (sessionId != null) {
329                        SessionInfo info = this.sessions.get(sessionId);
330                        if (info != null) {
331                                info.setLastReadTime(System.currentTimeMillis());
332                        }
333                }
334        }
335
336        private void logMessage(Message<?> message) {
337                if (logger.isDebugEnabled()) {
338                        SimpMessageHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, SimpMessageHeaderAccessor.class);
339                        accessor = (accessor != null ? accessor : SimpMessageHeaderAccessor.wrap(message));
340                        logger.debug("Processing " + accessor.getShortLogMessage(message.getPayload()));
341                }
342        }
343
344        private void initHeaders(SimpMessageHeaderAccessor accessor) {
345                if (getHeaderInitializer() != null) {
346                        getHeaderInitializer().initHeaders(accessor);
347                }
348        }
349
350        private void handleDisconnect(String sessionId, Principal user, Message<?> origMessage) {
351                this.sessions.remove(sessionId);
352                this.subscriptionRegistry.unregisterAllSubscriptions(sessionId);
353                SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create(SimpMessageType.DISCONNECT_ACK);
354                accessor.setSessionId(sessionId);
355                accessor.setUser(user);
356                if (origMessage != null) {
357                        accessor.setHeader(SimpMessageHeaderAccessor.DISCONNECT_MESSAGE_HEADER, origMessage);
358                }
359                initHeaders(accessor);
360                Message<byte[]> message = MessageBuilder.createMessage(EMPTY_PAYLOAD, accessor.getMessageHeaders());
361                getClientOutboundChannel().send(message);
362        }
363
364        protected void sendMessageToSubscribers(String destination, Message<?> message) {
365                MultiValueMap<String,String> subscriptions = this.subscriptionRegistry.findSubscriptions(message);
366                if (!subscriptions.isEmpty() && logger.isDebugEnabled()) {
367                        logger.debug("Broadcasting to " + subscriptions.size() + " sessions.");
368                }
369                long now = System.currentTimeMillis();
370                for (Map.Entry<String, List<String>> subscriptionEntry : subscriptions.entrySet()) {
371                        for (String subscriptionId : subscriptionEntry.getValue()) {
372                                SimpMessageHeaderAccessor headerAccessor = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE);
373                                initHeaders(headerAccessor);
374                                headerAccessor.setSessionId(subscriptionEntry.getKey());
375                                headerAccessor.setSubscriptionId(subscriptionId);
376                                headerAccessor.copyHeadersIfAbsent(message.getHeaders());
377                                Object payload = message.getPayload();
378                                Message<?> reply = MessageBuilder.createMessage(payload, headerAccessor.getMessageHeaders());
379                                try {
380                                        getClientOutboundChannel().send(reply);
381                                }
382                                catch (Throwable ex) {
383                                        if (logger.isErrorEnabled()) {
384                                                logger.error("Failed to send " + message, ex);
385                                        }
386                                }
387                                finally {
388                                        SessionInfo info = this.sessions.get(subscriptionEntry.getKey());
389                                        if (info != null) {
390                                                info.setLastWriteTime(now);
391                                        }
392                                }
393                        }
394                }
395        }
396
397        @Override
398        public String toString() {
399                return "SimpleBrokerMessageHandler [" + this.subscriptionRegistry + "]";
400        }
401
402
403        private static class SessionInfo {
404
405                /* STOMP spec: receiver SHOULD take into account an error margin */
406                private static final long HEARTBEAT_MULTIPLIER = 3;
407
408                private final String sessiondId;
409
410                private final Principal user;
411
412                private final long readInterval;
413
414                private final long writeInterval;
415
416                private volatile long lastReadTime;
417
418                private volatile long lastWriteTime;
419
420                public SessionInfo(String sessiondId, Principal user, long[] clientHeartbeat, long[] serverHeartbeat) {
421                        this.sessiondId = sessiondId;
422                        this.user = user;
423                        if (clientHeartbeat != null && serverHeartbeat != null) {
424                                this.readInterval = (clientHeartbeat[0] > 0 && serverHeartbeat[1] > 0 ?
425                                                Math.max(clientHeartbeat[0], serverHeartbeat[1]) * HEARTBEAT_MULTIPLIER : 0);
426                                this.writeInterval = (clientHeartbeat[1] > 0 && serverHeartbeat[0] > 0 ?
427                                                Math.max(clientHeartbeat[1], serverHeartbeat[0]) : 0);
428                        }
429                        else {
430                                this.readInterval = 0;
431                                this.writeInterval = 0;
432                        }
433                        this.lastReadTime = this.lastWriteTime = System.currentTimeMillis();
434                }
435
436                public String getSessiondId() {
437                        return this.sessiondId;
438                }
439
440                public Principal getUser() {
441                        return this.user;
442                }
443
444                public long getReadInterval() {
445                        return this.readInterval;
446                }
447
448                public long getWriteInterval() {
449                        return this.writeInterval;
450                }
451
452                public long getLastReadTime() {
453                        return this.lastReadTime;
454                }
455
456                public void setLastReadTime(long lastReadTime) {
457                        this.lastReadTime = lastReadTime;
458                }
459
460                public long getLastWriteTime() {
461                        return this.lastWriteTime;
462                }
463
464                public void setLastWriteTime(long lastWriteTime) {
465                        this.lastWriteTime = lastWriteTime;
466                }
467        }
468
469
470        private class HeartbeatTask implements Runnable {
471
472                @Override
473                public void run() {
474                        long now = System.currentTimeMillis();
475                        for (SessionInfo info : sessions.values()) {
476                                if (info.getReadInterval() > 0 && (now - info.getLastReadTime()) > info.getReadInterval()) {
477                                        handleDisconnect(info.getSessiondId(), info.getUser(), null);
478                                }
479                                if (info.getWriteInterval() > 0 && (now - info.getLastWriteTime()) > info.getWriteInterval()) {
480                                        SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create(SimpMessageType.HEARTBEAT);
481                                        accessor.setSessionId(info.getSessiondId());
482                                        accessor.setUser(info.getUser());
483                                        initHeaders(accessor);
484                                        MessageHeaders headers = accessor.getMessageHeaders();
485                                        getClientOutboundChannel().send(MessageBuilder.createMessage(EMPTY_PAYLOAD, headers));
486                                }
487                        }
488                }
489        }
490
491}