001/*
002 * Copyright 2002-2020 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.messaging.simp.stomp;
018
019import java.lang.reflect.Type;
020import java.util.ArrayList;
021import java.util.Collections;
022import java.util.Date;
023import java.util.List;
024import java.util.Map;
025import java.util.concurrent.ConcurrentHashMap;
026import java.util.concurrent.ExecutionException;
027import java.util.concurrent.ScheduledFuture;
028import java.util.concurrent.atomic.AtomicInteger;
029
030import org.apache.commons.logging.Log;
031import org.apache.commons.logging.LogFactory;
032
033import org.springframework.core.ResolvableType;
034import org.springframework.messaging.Message;
035import org.springframework.messaging.MessageDeliveryException;
036import org.springframework.messaging.converter.MessageConversionException;
037import org.springframework.messaging.converter.MessageConverter;
038import org.springframework.messaging.converter.SimpleMessageConverter;
039import org.springframework.messaging.support.MessageBuilder;
040import org.springframework.messaging.support.MessageHeaderAccessor;
041import org.springframework.messaging.tcp.TcpConnection;
042import org.springframework.scheduling.TaskScheduler;
043import org.springframework.util.AlternativeJdkIdGenerator;
044import org.springframework.util.Assert;
045import org.springframework.util.IdGenerator;
046import org.springframework.util.StringUtils;
047import org.springframework.util.concurrent.ListenableFuture;
048import org.springframework.util.concurrent.ListenableFutureCallback;
049import org.springframework.util.concurrent.SettableListenableFuture;
050
051/**
052 * Default implementation of {@link ConnectionHandlingStompSession}.
053 *
054 * @author Rossen Stoyanchev
055 * @since 4.2
056 */
057public class DefaultStompSession implements ConnectionHandlingStompSession {
058
059        private static final Log logger = LogFactory.getLog(DefaultStompSession.class);
060
061        private static final IdGenerator idGenerator = new AlternativeJdkIdGenerator();
062
063        public static final byte[] EMPTY_PAYLOAD = new byte[0];
064
065        /* STOMP spec: receiver SHOULD take into account an error margin */
066        private static final long HEARTBEAT_MULTIPLIER = 3;
067
068        private static final Message<byte[]> HEARTBEAT;
069
070        static {
071                StompHeaderAccessor accessor = StompHeaderAccessor.createForHeartbeat();
072                HEARTBEAT = MessageBuilder.createMessage(StompDecoder.HEARTBEAT_PAYLOAD, accessor.getMessageHeaders());
073        }
074
075
076        private final String sessionId;
077
078        private final StompSessionHandler sessionHandler;
079
080        private final StompHeaders connectHeaders;
081
082        private final SettableListenableFuture<StompSession> sessionFuture =
083                        new SettableListenableFuture<StompSession>();
084
085        private MessageConverter converter = new SimpleMessageConverter();
086
087        private TaskScheduler taskScheduler;
088
089        private long receiptTimeLimit = 15 * 1000;
090
091        private volatile boolean autoReceiptEnabled;
092
093
094        private volatile TcpConnection<byte[]> connection;
095
096        private volatile String version;
097
098        private final AtomicInteger subscriptionIndex = new AtomicInteger();
099
100        private final Map<String, DefaultSubscription> subscriptions =
101                        new ConcurrentHashMap<String, DefaultSubscription>(4);
102
103        private final AtomicInteger receiptIndex = new AtomicInteger();
104
105        private final Map<String, ReceiptHandler> receiptHandlers =
106                        new ConcurrentHashMap<String, ReceiptHandler>(4);
107
108        /* Whether the client is willfully closing the connection */
109        private volatile boolean closing = false;
110
111
112        /**
113         * Create a new session.
114         * @param sessionHandler the application handler for the session
115         * @param connectHeaders headers for the STOMP CONNECT frame
116         */
117        public DefaultStompSession(StompSessionHandler sessionHandler, StompHeaders connectHeaders) {
118                Assert.notNull(sessionHandler, "StompSessionHandler must not be null");
119                Assert.notNull(connectHeaders, "StompHeaders must not be null");
120                this.sessionId = idGenerator.generateId().toString();
121                this.sessionHandler = sessionHandler;
122                this.connectHeaders = connectHeaders;
123        }
124
125
126        @Override
127        public String getSessionId() {
128                return this.sessionId;
129        }
130
131        /**
132         * Return the configured session handler.
133         */
134        public StompSessionHandler getSessionHandler() {
135                return this.sessionHandler;
136        }
137
138        @Override
139        public ListenableFuture<StompSession> getSessionFuture() {
140                return this.sessionFuture;
141        }
142
143        /**
144         * Set the {@link MessageConverter} to use to convert the payload of incoming
145         * and outgoing messages to and from {@code byte[]} based on object type, or
146         * expected object type, and the "content-type" header.
147         * <p>By default, {@link SimpleMessageConverter} is configured.
148         * @param messageConverter the message converter to use
149         */
150        public void setMessageConverter(MessageConverter messageConverter) {
151                Assert.notNull(messageConverter, "MessageConverter must not be null");
152                this.converter = messageConverter;
153        }
154
155        /**
156         * Return the configured {@link MessageConverter}.
157         */
158        public MessageConverter getMessageConverter() {
159                return this.converter;
160        }
161
162        /**
163         * Configure the TaskScheduler to use for receipt tracking.
164         */
165        public void setTaskScheduler(TaskScheduler taskScheduler) {
166                this.taskScheduler = taskScheduler;
167        }
168
169        /**
170         * Return the configured TaskScheduler to use for receipt tracking.
171         */
172        public TaskScheduler getTaskScheduler() {
173                return this.taskScheduler;
174        }
175
176        /**
177         * Configure the time in milliseconds before a receipt expires.
178         * <p>By default set to 15,000 (15 seconds).
179         */
180        public void setReceiptTimeLimit(long receiptTimeLimit) {
181                Assert.isTrue(receiptTimeLimit > 0, "Receipt time limit must be larger than zero");
182                this.receiptTimeLimit = receiptTimeLimit;
183        }
184
185        /**
186         * Return the configured time limit before a receipt expires.
187         */
188        public long getReceiptTimeLimit() {
189                return this.receiptTimeLimit;
190        }
191
192        @Override
193        public void setAutoReceipt(boolean autoReceiptEnabled) {
194                this.autoReceiptEnabled = autoReceiptEnabled;
195        }
196
197        /**
198         * Whether receipt headers should be automatically added.
199         */
200        public boolean isAutoReceiptEnabled() {
201                return this.autoReceiptEnabled;
202        }
203
204
205        @Override
206        public boolean isConnected() {
207                return (this.connection != null);
208        }
209
210        @Override
211        public Receiptable send(String destination, Object payload) {
212                StompHeaders stompHeaders = new StompHeaders();
213                stompHeaders.setDestination(destination);
214                return send(stompHeaders, payload);
215        }
216
217        @Override
218        public Receiptable send(StompHeaders stompHeaders, Object payload) {
219                Assert.hasText(stompHeaders.getDestination(), "Destination header is required");
220
221                String receiptId = checkOrAddReceipt(stompHeaders);
222                Receiptable receiptable = new ReceiptHandler(receiptId);
223
224                StompHeaderAccessor accessor = createHeaderAccessor(StompCommand.SEND);
225                accessor.addNativeHeaders(stompHeaders);
226                Message<byte[]> message = createMessage(accessor, payload);
227                execute(message);
228
229                return receiptable;
230        }
231
232        private String checkOrAddReceipt(StompHeaders stompHeaders) {
233                String receiptId = stompHeaders.getReceipt();
234                if (isAutoReceiptEnabled() && receiptId == null) {
235                        receiptId = String.valueOf(DefaultStompSession.this.receiptIndex.getAndIncrement());
236                        stompHeaders.setReceipt(receiptId);
237                }
238                return receiptId;
239        }
240
241        private StompHeaderAccessor createHeaderAccessor(StompCommand command) {
242                StompHeaderAccessor accessor = StompHeaderAccessor.create(command);
243                accessor.setSessionId(this.sessionId);
244                accessor.setLeaveMutable(true);
245                return accessor;
246        }
247
248        @SuppressWarnings("unchecked")
249        private Message<byte[]> createMessage(StompHeaderAccessor accessor, Object payload) {
250                accessor.updateSimpMessageHeadersFromStompHeaders();
251                Message<byte[]> message;
252                if (payload == null) {
253                        message = MessageBuilder.createMessage(EMPTY_PAYLOAD, accessor.getMessageHeaders());
254                }
255                else if (payload instanceof byte[]) {
256                        message = MessageBuilder.createMessage((byte[]) payload, accessor.getMessageHeaders());
257                }
258                else {
259                        message = (Message<byte[]>) getMessageConverter().toMessage(payload, accessor.getMessageHeaders());
260                        accessor.updateStompHeadersFromSimpMessageHeaders();
261                        if (message == null) {
262                                throw new MessageConversionException("Unable to convert payload with type='" +
263                                                payload.getClass().getName() + "', contentType='" + accessor.getContentType() +
264                                                "', converter=[" + getMessageConverter() + "]");
265                        }
266                }
267                return message;
268        }
269
270        private void execute(Message<byte[]> message) {
271                if (logger.isTraceEnabled()) {
272                        StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
273                        logger.trace("Sending " + accessor.getDetailedLogMessage(message.getPayload()));
274                }
275                TcpConnection<byte[]> conn = this.connection;
276                Assert.state(conn != null, "Connection closed");
277                try {
278                        conn.send(message).get();
279                }
280                catch (ExecutionException ex) {
281                        throw new MessageDeliveryException(message, ex.getCause());
282                }
283                catch (Throwable ex) {
284                        throw new MessageDeliveryException(message, ex);
285                }
286        }
287
288        @Override
289        public Subscription subscribe(String destination, StompFrameHandler handler) {
290                StompHeaders stompHeaders = new StompHeaders();
291                stompHeaders.setDestination(destination);
292                return subscribe(stompHeaders, handler);
293        }
294
295        @Override
296        public Subscription subscribe(StompHeaders stompHeaders, StompFrameHandler handler) {
297                String destination = stompHeaders.getDestination();
298                Assert.hasText(destination, "Destination header is required");
299                Assert.notNull(handler, "StompFrameHandler must not be null");
300
301                String subscriptionId = stompHeaders.getId();
302                if (!StringUtils.hasText(subscriptionId)) {
303                        subscriptionId = String.valueOf(DefaultStompSession.this.subscriptionIndex.getAndIncrement());
304                        stompHeaders.setId(subscriptionId);
305                }
306                String receiptId = checkOrAddReceipt(stompHeaders);
307                Subscription subscription = new DefaultSubscription(subscriptionId, destination, receiptId, handler);
308
309                StompHeaderAccessor accessor = createHeaderAccessor(StompCommand.SUBSCRIBE);
310                accessor.addNativeHeaders(stompHeaders);
311                Message<byte[]> message = createMessage(accessor, EMPTY_PAYLOAD);
312                execute(message);
313
314                return subscription;
315        }
316
317        @Override
318        public Receiptable acknowledge(String messageId, boolean consumed) {
319                StompHeaders stompHeaders = new StompHeaders();
320                if ("1.1".equals(this.version)) {
321                        stompHeaders.setMessageId(messageId);
322                }
323                else {
324                        stompHeaders.setId(messageId);
325                }
326
327                String receiptId = checkOrAddReceipt(stompHeaders);
328                Receiptable receiptable = new ReceiptHandler(receiptId);
329
330                StompCommand command = (consumed ? StompCommand.ACK : StompCommand.NACK);
331                StompHeaderAccessor accessor = createHeaderAccessor(command);
332                accessor.addNativeHeaders(stompHeaders);
333                Message<byte[]> message = createMessage(accessor, null);
334                execute(message);
335
336                return receiptable;
337        }
338
339        private void unsubscribe(String id) {
340                StompHeaderAccessor accessor = createHeaderAccessor(StompCommand.UNSUBSCRIBE);
341                accessor.setSubscriptionId(id);
342                Message<byte[]> message = createMessage(accessor, EMPTY_PAYLOAD);
343                execute(message);
344        }
345
346        @Override
347        public void disconnect() {
348                this.closing = true;
349                try {
350                        StompHeaderAccessor accessor = createHeaderAccessor(StompCommand.DISCONNECT);
351                        Message<byte[]> message = createMessage(accessor, EMPTY_PAYLOAD);
352                        execute(message);
353                }
354                finally {
355                        resetConnection();
356                }
357        }
358
359
360        // TcpConnectionHandler
361
362        @Override
363        public void afterConnected(TcpConnection<byte[]> connection) {
364                this.connection = connection;
365                if (logger.isDebugEnabled()) {
366                        logger.debug("Connection established in session id=" + this.sessionId);
367                }
368                StompHeaderAccessor accessor = createHeaderAccessor(StompCommand.CONNECT);
369                accessor.addNativeHeaders(this.connectHeaders);
370                accessor.setAcceptVersion("1.1,1.2");
371                Message<byte[]> message = createMessage(accessor, EMPTY_PAYLOAD);
372                execute(message);
373        }
374
375        @Override
376        public void afterConnectFailure(Throwable ex) {
377                if (logger.isDebugEnabled()) {
378                        logger.debug("Failed to connect session id=" + this.sessionId, ex);
379                }
380                this.sessionFuture.setException(ex);
381                this.sessionHandler.handleTransportError(this, ex);
382        }
383
384        @Override
385        public void handleMessage(Message<byte[]> message) {
386                StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
387                accessor.setSessionId(this.sessionId);
388                StompCommand command = accessor.getCommand();
389                Map<String, List<String>> nativeHeaders = accessor.getNativeHeaders();
390                StompHeaders stompHeaders = StompHeaders.readOnlyStompHeaders(nativeHeaders);
391                boolean isHeartbeat = accessor.isHeartbeat();
392                if (logger.isTraceEnabled()) {
393                        logger.trace("Received " + accessor.getDetailedLogMessage(message.getPayload()));
394                }
395                try {
396                        if (StompCommand.MESSAGE.equals(command)) {
397                                DefaultSubscription subscription = this.subscriptions.get(stompHeaders.getSubscription());
398                                if (subscription != null) {
399                                        invokeHandler(subscription.getHandler(), message, stompHeaders);
400                                }
401                                else if (logger.isDebugEnabled()) {
402                                        logger.debug("No handler for: " + accessor.getDetailedLogMessage(message.getPayload()) +
403                                                        ". Perhaps just unsubscribed?");
404                                }
405                        }
406                        else {
407                                if (StompCommand.RECEIPT.equals(command)) {
408                                        String receiptId = stompHeaders.getReceiptId();
409                                        ReceiptHandler handler = this.receiptHandlers.get(receiptId);
410                                        if (handler != null) {
411                                                handler.handleReceiptReceived();
412                                        }
413                                        else if (logger.isDebugEnabled()) {
414                                                logger.debug("No matching receipt: " + accessor.getDetailedLogMessage(message.getPayload()));
415                                        }
416                                }
417                                else if (StompCommand.CONNECTED.equals(command)) {
418                                        initHeartbeatTasks(stompHeaders);
419                                        this.version = stompHeaders.getFirst("version");
420                                        this.sessionFuture.set(this);
421                                        this.sessionHandler.afterConnected(this, stompHeaders);
422                                }
423                                else if (StompCommand.ERROR.equals(command)) {
424                                        invokeHandler(this.sessionHandler, message, stompHeaders);
425                                }
426                                else if (!isHeartbeat && logger.isTraceEnabled()) {
427                                        logger.trace("Message not handled.");
428                                }
429                        }
430                }
431                catch (Throwable ex) {
432                        this.sessionHandler.handleException(this, command, stompHeaders, message.getPayload(), ex);
433                }
434        }
435
436        private void invokeHandler(StompFrameHandler handler, Message<byte[]> message, StompHeaders stompHeaders) {
437                if (message.getPayload().length == 0) {
438                        handler.handleFrame(stompHeaders, null);
439                        return;
440                }
441                Type type = handler.getPayloadType(stompHeaders);
442                Class<?> payloadType = ResolvableType.forType(type).resolve();
443                Object object = getMessageConverter().fromMessage(message, payloadType);
444                if (object == null) {
445                        throw new MessageConversionException("No suitable converter, payloadType=" + payloadType +
446                                        ", handlerType=" + handler.getClass());
447                }
448                handler.handleFrame(stompHeaders, object);
449        }
450
451        private void initHeartbeatTasks(StompHeaders connectedHeaders) {
452                long[] connect = this.connectHeaders.getHeartbeat();
453                long[] connected = connectedHeaders.getHeartbeat();
454                if (connect == null || connected == null) {
455                        return;
456                }
457                if (connect[0] > 0 && connected[1] > 0) {
458                        long interval = Math.max(connect[0],  connected[1]);
459                        this.connection.onWriteInactivity(new WriteInactivityTask(), interval);
460                }
461                if (connect[1] > 0 && connected[0] > 0) {
462                        final long interval = Math.max(connect[1], connected[0]) * HEARTBEAT_MULTIPLIER;
463                        this.connection.onReadInactivity(new ReadInactivityTask(), interval);
464                }
465        }
466
467        @Override
468        public void handleFailure(Throwable ex) {
469                try {
470                        this.sessionFuture.setException(ex);  // no-op if already set
471                        this.sessionHandler.handleTransportError(this, ex);
472                }
473                catch (Throwable ex2) {
474                        if (logger.isDebugEnabled()) {
475                                logger.debug("Uncaught failure while handling transport failure", ex2);
476                        }
477                }
478        }
479
480        @Override
481        public void afterConnectionClosed() {
482                if (logger.isDebugEnabled()) {
483                        logger.debug("Connection closed in session id=" + this.sessionId);
484                }
485                if (!this.closing) {
486                        resetConnection();
487                        handleFailure(new ConnectionLostException("Connection closed"));
488                }
489        }
490
491        private void resetConnection() {
492                TcpConnection<?> conn = this.connection;
493                this.connection = null;
494                if (conn != null) {
495                        try {
496                                conn.close();
497                        }
498                        catch (Throwable ex) {
499                                // ignore
500                        }
501                }
502        }
503
504
505        private class ReceiptHandler implements Receiptable {
506
507                private final String receiptId;
508
509                private final List<Runnable> receiptCallbacks = new ArrayList<Runnable>(2);
510
511                private final List<Runnable> receiptLostCallbacks = new ArrayList<Runnable>(2);
512
513                private ScheduledFuture<?> future;
514
515                private Boolean result;
516
517                public ReceiptHandler(String receiptId) {
518                        this.receiptId = receiptId;
519                        if (this.receiptId != null) {
520                                initReceiptHandling();
521                        }
522                }
523
524                private void initReceiptHandling() {
525                        Assert.notNull(getTaskScheduler(), "To track receipts, a TaskScheduler must be configured");
526                        DefaultStompSession.this.receiptHandlers.put(this.receiptId, this);
527                        Date startTime = new Date(System.currentTimeMillis() + getReceiptTimeLimit());
528                        this.future = getTaskScheduler().schedule(new Runnable() {
529                                @Override
530                                public void run() {
531                                        handleReceiptNotReceived();
532                                }
533                        }, startTime);
534                }
535
536                @Override
537                public String getReceiptId() {
538                        return this.receiptId;
539                }
540
541                @Override
542                public void addReceiptTask(Runnable task) {
543                        addTask(task, true);
544                }
545
546                @Override
547                public void addReceiptLostTask(Runnable task) {
548                        addTask(task, false);
549                }
550
551                private void addTask(Runnable task, boolean successTask) {
552                        Assert.notNull(this.receiptId,
553                                        "To track receipts, set autoReceiptEnabled=true or add 'receiptId' header");
554                        synchronized (this) {
555                                if (this.result != null && this.result == successTask) {
556                                        invoke(Collections.singletonList(task));
557                                }
558                                else {
559                                        if (successTask) {
560                                                this.receiptCallbacks.add(task);
561                                        }
562                                        else {
563                                                this.receiptLostCallbacks.add(task);
564                                        }
565                                }
566                        }
567                }
568
569                private void invoke(List<Runnable> callbacks) {
570                        for (Runnable runnable : callbacks) {
571                                try {
572                                        runnable.run();
573                                }
574                                catch (Throwable ex) {
575                                        // ignore
576                                }
577                        }
578                }
579
580                public void handleReceiptReceived() {
581                        handleInternal(true);
582                }
583
584                public void handleReceiptNotReceived() {
585                        handleInternal(false);
586                }
587
588                private void handleInternal(boolean result) {
589                        synchronized (this) {
590                                if (this.result != null) {
591                                        return;
592                                }
593                                this.result = result;
594                                invoke(result ? this.receiptCallbacks : this.receiptLostCallbacks);
595                                DefaultStompSession.this.receiptHandlers.remove(this.receiptId);
596                                if (this.future != null) {
597                                        this.future.cancel(true);
598                                }
599                        }
600                }
601        }
602
603
604        private class DefaultSubscription extends ReceiptHandler implements Subscription {
605
606                private final String id;
607
608                private final String destination;
609
610                private final StompFrameHandler handler;
611
612                public DefaultSubscription(String id, String destination, String receiptId, StompFrameHandler handler) {
613                        super(receiptId);
614                        Assert.notNull(destination, "Destination must not be null");
615                        Assert.notNull(handler, "StompFrameHandler must not be null");
616                        this.id = id;
617                        this.destination = destination;
618                        this.handler = handler;
619                        DefaultStompSession.this.subscriptions.put(id, this);
620                }
621
622                @Override
623                public String getSubscriptionId() {
624                        return this.id;
625                }
626
627                public String getDestination() {
628                        return this.destination;
629                }
630
631                public StompFrameHandler getHandler() {
632                        return this.handler;
633                }
634
635                @Override
636                public void unsubscribe() {
637                        DefaultStompSession.this.subscriptions.remove(getSubscriptionId());
638                        DefaultStompSession.this.unsubscribe(getSubscriptionId());
639                }
640
641                @Override
642                public String toString() {
643                        return "Subscription [id=" + getSubscriptionId() + ", destination='" + getDestination() +
644                                        "', receiptId='" + getReceiptId() + "', handler=" + getHandler() + "]";
645                }
646        }
647
648
649        private class WriteInactivityTask implements Runnable {
650
651                @Override
652                public void run() {
653                        TcpConnection<byte[]> conn = connection;
654                        if (conn != null) {
655                                conn.send(HEARTBEAT).addCallback(
656                                                new ListenableFutureCallback<Void>() {
657                                                        @Override
658                                                        public void onSuccess(Void result) {
659                                                        }
660                                                        @Override
661                                                        public void onFailure(Throwable ex) {
662                                                                handleFailure(ex);
663                                                        }
664                                                });
665                        }
666                }
667        }
668
669
670        private class ReadInactivityTask implements Runnable {
671
672                @Override
673                public void run() {
674                        closing = true;
675                        String error = "Server has gone quiet. Closing connection in session id=" + sessionId + ".";
676                        if (logger.isDebugEnabled()) {
677                                logger.debug(error);
678                        }
679                        resetConnection();
680                        handleFailure(new IllegalStateException(error));
681                }
682        }
683
684}