001/*
002 * Copyright 2002-2020 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.messaging.simp.stomp;
018
019import java.lang.reflect.Type;
020import java.util.ArrayList;
021import java.util.Collections;
022import java.util.Date;
023import java.util.List;
024import java.util.Map;
025import java.util.concurrent.ConcurrentHashMap;
026import java.util.concurrent.ExecutionException;
027import java.util.concurrent.ScheduledFuture;
028import java.util.concurrent.TimeUnit;
029import java.util.concurrent.atomic.AtomicInteger;
030
031import org.apache.commons.logging.Log;
032
033import org.springframework.core.ResolvableType;
034import org.springframework.lang.Nullable;
035import org.springframework.messaging.Message;
036import org.springframework.messaging.MessageDeliveryException;
037import org.springframework.messaging.converter.MessageConversionException;
038import org.springframework.messaging.converter.MessageConverter;
039import org.springframework.messaging.converter.SimpleMessageConverter;
040import org.springframework.messaging.simp.SimpLogging;
041import org.springframework.messaging.support.MessageBuilder;
042import org.springframework.messaging.support.MessageHeaderAccessor;
043import org.springframework.messaging.tcp.TcpConnection;
044import org.springframework.scheduling.TaskScheduler;
045import org.springframework.util.AlternativeJdkIdGenerator;
046import org.springframework.util.Assert;
047import org.springframework.util.IdGenerator;
048import org.springframework.util.StringUtils;
049import org.springframework.util.concurrent.ListenableFuture;
050import org.springframework.util.concurrent.ListenableFutureCallback;
051import org.springframework.util.concurrent.SettableListenableFuture;
052
053/**
054 * Default implementation of {@link ConnectionHandlingStompSession}.
055 *
056 * @author Rossen Stoyanchev
057 * @since 4.2
058 */
059public class DefaultStompSession implements ConnectionHandlingStompSession {
060
061        private static final Log logger = SimpLogging.forLogName(DefaultStompSession.class);
062
063        private static final IdGenerator idGenerator = new AlternativeJdkIdGenerator();
064
065        /**
066         * An empty payload.
067         */
068        public static final byte[] EMPTY_PAYLOAD = new byte[0];
069
070        /* STOMP spec: receiver SHOULD take into account an error margin */
071        private static final long HEARTBEAT_MULTIPLIER = 3;
072
073        private static final Message<byte[]> HEARTBEAT;
074
075        static {
076                StompHeaderAccessor accessor = StompHeaderAccessor.createForHeartbeat();
077                HEARTBEAT = MessageBuilder.createMessage(StompDecoder.HEARTBEAT_PAYLOAD, accessor.getMessageHeaders());
078        }
079
080
081        private final String sessionId;
082
083        private final StompSessionHandler sessionHandler;
084
085        private final StompHeaders connectHeaders;
086
087        private final SettableListenableFuture<StompSession> sessionFuture = new SettableListenableFuture<>();
088
089        private MessageConverter converter = new SimpleMessageConverter();
090
091        @Nullable
092        private TaskScheduler taskScheduler;
093
094        private long receiptTimeLimit = TimeUnit.SECONDS.toMillis(15);
095
096        private volatile boolean autoReceiptEnabled;
097
098
099        @Nullable
100        private volatile TcpConnection<byte[]> connection;
101
102        @Nullable
103        private volatile String version;
104
105        private final AtomicInteger subscriptionIndex = new AtomicInteger();
106
107        private final Map<String, DefaultSubscription> subscriptions = new ConcurrentHashMap<>(4);
108
109        private final AtomicInteger receiptIndex = new AtomicInteger();
110
111        private final Map<String, ReceiptHandler> receiptHandlers = new ConcurrentHashMap<>(4);
112
113        /* Whether the client is willfully closing the connection */
114        private volatile boolean closing;
115
116
117        /**
118         * Create a new session.
119         * @param sessionHandler the application handler for the session
120         * @param connectHeaders headers for the STOMP CONNECT frame
121         */
122        public DefaultStompSession(StompSessionHandler sessionHandler, StompHeaders connectHeaders) {
123                Assert.notNull(sessionHandler, "StompSessionHandler must not be null");
124                Assert.notNull(connectHeaders, "StompHeaders must not be null");
125                this.sessionId = idGenerator.generateId().toString();
126                this.sessionHandler = sessionHandler;
127                this.connectHeaders = connectHeaders;
128        }
129
130
131        @Override
132        public String getSessionId() {
133                return this.sessionId;
134        }
135
136        /**
137         * Return the configured session handler.
138         */
139        public StompSessionHandler getSessionHandler() {
140                return this.sessionHandler;
141        }
142
143        @Override
144        public ListenableFuture<StompSession> getSessionFuture() {
145                return this.sessionFuture;
146        }
147
148        /**
149         * Set the {@link MessageConverter} to use to convert the payload of incoming
150         * and outgoing messages to and from {@code byte[]} based on object type, or
151         * expected object type, and the "content-type" header.
152         * <p>By default, {@link SimpleMessageConverter} is configured.
153         * @param messageConverter the message converter to use
154         */
155        public void setMessageConverter(MessageConverter messageConverter) {
156                Assert.notNull(messageConverter, "MessageConverter must not be null");
157                this.converter = messageConverter;
158        }
159
160        /**
161         * Return the configured {@link MessageConverter}.
162         */
163        public MessageConverter getMessageConverter() {
164                return this.converter;
165        }
166
167        /**
168         * Configure the TaskScheduler to use for receipt tracking.
169         */
170        public void setTaskScheduler(@Nullable TaskScheduler taskScheduler) {
171                this.taskScheduler = taskScheduler;
172        }
173
174        /**
175         * Return the configured TaskScheduler to use for receipt tracking.
176         */
177        @Nullable
178        public TaskScheduler getTaskScheduler() {
179                return this.taskScheduler;
180        }
181
182        /**
183         * Configure the time in milliseconds before a receipt expires.
184         * <p>By default set to 15,000 (15 seconds).
185         */
186        public void setReceiptTimeLimit(long receiptTimeLimit) {
187                Assert.isTrue(receiptTimeLimit > 0, "Receipt time limit must be larger than zero");
188                this.receiptTimeLimit = receiptTimeLimit;
189        }
190
191        /**
192         * Return the configured time limit before a receipt expires.
193         */
194        public long getReceiptTimeLimit() {
195                return this.receiptTimeLimit;
196        }
197
198        @Override
199        public void setAutoReceipt(boolean autoReceiptEnabled) {
200                this.autoReceiptEnabled = autoReceiptEnabled;
201        }
202
203        /**
204         * Whether receipt headers should be automatically added.
205         */
206        public boolean isAutoReceiptEnabled() {
207                return this.autoReceiptEnabled;
208        }
209
210
211        @Override
212        public boolean isConnected() {
213                return (this.connection != null);
214        }
215
216        @Override
217        public Receiptable send(String destination, Object payload) {
218                StompHeaders headers = new StompHeaders();
219                headers.setDestination(destination);
220                return send(headers, payload);
221        }
222
223        @Override
224        public Receiptable send(StompHeaders headers, Object payload) {
225                Assert.hasText(headers.getDestination(), "Destination header is required");
226
227                String receiptId = checkOrAddReceipt(headers);
228                Receiptable receiptable = new ReceiptHandler(receiptId);
229
230                StompHeaderAccessor accessor = createHeaderAccessor(StompCommand.SEND);
231                accessor.addNativeHeaders(headers);
232                Message<byte[]> message = createMessage(accessor, payload);
233                execute(message);
234
235                return receiptable;
236        }
237
238        @Nullable
239        private String checkOrAddReceipt(StompHeaders headers) {
240                String receiptId = headers.getReceipt();
241                if (isAutoReceiptEnabled() && receiptId == null) {
242                        receiptId = String.valueOf(DefaultStompSession.this.receiptIndex.getAndIncrement());
243                        headers.setReceipt(receiptId);
244                }
245                return receiptId;
246        }
247
248        private StompHeaderAccessor createHeaderAccessor(StompCommand command) {
249                StompHeaderAccessor accessor = StompHeaderAccessor.create(command);
250                accessor.setSessionId(this.sessionId);
251                accessor.setLeaveMutable(true);
252                return accessor;
253        }
254
255        @SuppressWarnings("unchecked")
256        private Message<byte[]> createMessage(StompHeaderAccessor accessor, @Nullable Object payload) {
257                accessor.updateSimpMessageHeadersFromStompHeaders();
258                Message<byte[]> message;
259                if (StringUtils.isEmpty(payload) || (payload instanceof byte[] && ((byte[]) payload).length == 0)) {
260                        message = MessageBuilder.createMessage(EMPTY_PAYLOAD, accessor.getMessageHeaders());
261                }
262                else {
263                        message = (Message<byte[]>) getMessageConverter().toMessage(payload, accessor.getMessageHeaders());
264                        accessor.updateStompHeadersFromSimpMessageHeaders();
265                        if (message == null) {
266                                throw new MessageConversionException("Unable to convert payload with type='" +
267                                                payload.getClass().getName() + "', contentType='" + accessor.getContentType() +
268                                                "', converter=[" + getMessageConverter() + "]");
269                        }
270                }
271                return message;
272        }
273
274        private void execute(Message<byte[]> message) {
275                if (logger.isTraceEnabled()) {
276                        StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
277                        if (accessor != null) {
278                                logger.trace("Sending " + accessor.getDetailedLogMessage(message.getPayload()));
279                        }
280                }
281                TcpConnection<byte[]> conn = this.connection;
282                Assert.state(conn != null, "Connection closed");
283                try {
284                        conn.send(message).get();
285                }
286                catch (ExecutionException ex) {
287                        throw new MessageDeliveryException(message, ex.getCause());
288                }
289                catch (Throwable ex) {
290                        throw new MessageDeliveryException(message, ex);
291                }
292        }
293
294        @Override
295        public Subscription subscribe(String destination, StompFrameHandler handler) {
296                StompHeaders headers = new StompHeaders();
297                headers.setDestination(destination);
298                return subscribe(headers, handler);
299        }
300
301        @Override
302        public Subscription subscribe(StompHeaders headers, StompFrameHandler handler) {
303                Assert.hasText(headers.getDestination(), "Destination header is required");
304                Assert.notNull(handler, "StompFrameHandler must not be null");
305
306                String subscriptionId = headers.getId();
307                if (!StringUtils.hasText(subscriptionId)) {
308                        subscriptionId = String.valueOf(DefaultStompSession.this.subscriptionIndex.getAndIncrement());
309                        headers.setId(subscriptionId);
310                }
311                checkOrAddReceipt(headers);
312                Subscription subscription = new DefaultSubscription(headers, handler);
313
314                StompHeaderAccessor accessor = createHeaderAccessor(StompCommand.SUBSCRIBE);
315                accessor.addNativeHeaders(headers);
316                Message<byte[]> message = createMessage(accessor, EMPTY_PAYLOAD);
317                execute(message);
318
319                return subscription;
320        }
321
322        @Override
323        public Receiptable acknowledge(String messageId, boolean consumed) {
324                StompHeaders headers = new StompHeaders();
325                if ("1.1".equals(this.version)) {
326                        headers.setMessageId(messageId);
327                }
328                else {
329                        headers.setId(messageId);
330                }
331                return acknowledge(headers, consumed);
332        }
333
334        @Override
335        public Receiptable acknowledge(StompHeaders headers, boolean consumed) {
336                String receiptId = checkOrAddReceipt(headers);
337                Receiptable receiptable = new ReceiptHandler(receiptId);
338
339                StompCommand command = (consumed ? StompCommand.ACK : StompCommand.NACK);
340                StompHeaderAccessor accessor = createHeaderAccessor(command);
341                accessor.addNativeHeaders(headers);
342                Message<byte[]> message = createMessage(accessor, null);
343                execute(message);
344
345                return receiptable;
346        }
347
348        private void unsubscribe(String id, @Nullable StompHeaders headers) {
349                StompHeaderAccessor accessor = createHeaderAccessor(StompCommand.UNSUBSCRIBE);
350                if (headers != null) {
351                        accessor.addNativeHeaders(headers);
352                }
353                accessor.setSubscriptionId(id);
354                Message<byte[]> message = createMessage(accessor, EMPTY_PAYLOAD);
355                execute(message);
356        }
357
358        @Override
359        public void disconnect() {
360                disconnect(null);
361        }
362
363        @Override
364        public void disconnect(@Nullable StompHeaders headers) {
365                this.closing = true;
366                try {
367                        StompHeaderAccessor accessor = createHeaderAccessor(StompCommand.DISCONNECT);
368                        if (headers != null) {
369                                accessor.addNativeHeaders(headers);
370                        }
371                        Message<byte[]> message = createMessage(accessor, EMPTY_PAYLOAD);
372                        execute(message);
373                }
374                finally {
375                        resetConnection();
376                }
377        }
378
379
380        // TcpConnectionHandler
381
382        @Override
383        public void afterConnected(TcpConnection<byte[]> connection) {
384                this.connection = connection;
385                if (logger.isDebugEnabled()) {
386                        logger.debug("Connection established in session id=" + this.sessionId);
387                }
388                StompHeaderAccessor accessor = createHeaderAccessor(StompCommand.CONNECT);
389                accessor.addNativeHeaders(this.connectHeaders);
390                if (this.connectHeaders.getAcceptVersion() == null) {
391                        accessor.setAcceptVersion("1.1,1.2");
392                }
393                Message<byte[]> message = createMessage(accessor, EMPTY_PAYLOAD);
394                execute(message);
395        }
396
397        @Override
398        public void afterConnectFailure(Throwable ex) {
399                if (logger.isDebugEnabled()) {
400                        logger.debug("Failed to connect session id=" + this.sessionId, ex);
401                }
402                this.sessionFuture.setException(ex);
403                this.sessionHandler.handleTransportError(this, ex);
404        }
405
406        @Override
407        public void handleMessage(Message<byte[]> message) {
408                StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
409                Assert.state(accessor != null, "No StompHeaderAccessor");
410
411                accessor.setSessionId(this.sessionId);
412                StompCommand command = accessor.getCommand();
413                Map<String, List<String>> nativeHeaders = accessor.getNativeHeaders();
414                StompHeaders headers = StompHeaders.readOnlyStompHeaders(nativeHeaders);
415                boolean isHeartbeat = accessor.isHeartbeat();
416                if (logger.isTraceEnabled()) {
417                        logger.trace("Received " + accessor.getDetailedLogMessage(message.getPayload()));
418                }
419
420                try {
421                        if (StompCommand.MESSAGE.equals(command)) {
422                                DefaultSubscription subscription = this.subscriptions.get(headers.getSubscription());
423                                if (subscription != null) {
424                                        invokeHandler(subscription.getHandler(), message, headers);
425                                }
426                                else if (logger.isDebugEnabled()) {
427                                        logger.debug("No handler for: " + accessor.getDetailedLogMessage(message.getPayload()) +
428                                                        ". Perhaps just unsubscribed?");
429                                }
430                        }
431                        else {
432                                if (StompCommand.RECEIPT.equals(command)) {
433                                        String receiptId = headers.getReceiptId();
434                                        ReceiptHandler handler = this.receiptHandlers.get(receiptId);
435                                        if (handler != null) {
436                                                handler.handleReceiptReceived();
437                                        }
438                                        else if (logger.isDebugEnabled()) {
439                                                logger.debug("No matching receipt: " + accessor.getDetailedLogMessage(message.getPayload()));
440                                        }
441                                }
442                                else if (StompCommand.CONNECTED.equals(command)) {
443                                        initHeartbeatTasks(headers);
444                                        this.version = headers.getFirst("version");
445                                        this.sessionFuture.set(this);
446                                        this.sessionHandler.afterConnected(this, headers);
447                                }
448                                else if (StompCommand.ERROR.equals(command)) {
449                                        invokeHandler(this.sessionHandler, message, headers);
450                                }
451                                else if (!isHeartbeat && logger.isTraceEnabled()) {
452                                        logger.trace("Message not handled.");
453                                }
454                        }
455                }
456                catch (Throwable ex) {
457                        this.sessionHandler.handleException(this, command, headers, message.getPayload(), ex);
458                }
459        }
460
461        private void invokeHandler(StompFrameHandler handler, Message<byte[]> message, StompHeaders headers) {
462                if (message.getPayload().length == 0) {
463                        handler.handleFrame(headers, null);
464                        return;
465                }
466                Type payloadType = handler.getPayloadType(headers);
467                Class<?> resolvedType = ResolvableType.forType(payloadType).resolve();
468                if (resolvedType == null) {
469                        throw new MessageConversionException("Unresolvable payload type [" + payloadType +
470                                        "] from handler type [" + handler.getClass() + "]");
471                }
472                Object object = getMessageConverter().fromMessage(message, resolvedType);
473                if (object == null) {
474                        throw new MessageConversionException("No suitable converter for payload type [" + payloadType +
475                                        "] from handler type [" + handler.getClass() + "]");
476                }
477                handler.handleFrame(headers, object);
478        }
479
480        private void initHeartbeatTasks(StompHeaders connectedHeaders) {
481                long[] connect = this.connectHeaders.getHeartbeat();
482                long[] connected = connectedHeaders.getHeartbeat();
483                if (connect == null || connected == null) {
484                        return;
485                }
486                TcpConnection<byte[]> con = this.connection;
487                Assert.state(con != null, "No TcpConnection available");
488                if (connect[0] > 0 && connected[1] > 0) {
489                        long interval = Math.max(connect[0],  connected[1]);
490                        con.onWriteInactivity(new WriteInactivityTask(), interval);
491                }
492                if (connect[1] > 0 && connected[0] > 0) {
493                        long interval = Math.max(connect[1], connected[0]) * HEARTBEAT_MULTIPLIER;
494                        con.onReadInactivity(new ReadInactivityTask(), interval);
495                }
496        }
497
498        @Override
499        public void handleFailure(Throwable ex) {
500                try {
501                        this.sessionFuture.setException(ex);  // no-op if already set
502                        this.sessionHandler.handleTransportError(this, ex);
503                }
504                catch (Throwable ex2) {
505                        if (logger.isDebugEnabled()) {
506                                logger.debug("Uncaught failure while handling transport failure", ex2);
507                        }
508                }
509        }
510
511        @Override
512        public void afterConnectionClosed() {
513                if (logger.isDebugEnabled()) {
514                        logger.debug("Connection closed in session id=" + this.sessionId);
515                }
516                if (!this.closing) {
517                        resetConnection();
518                        handleFailure(new ConnectionLostException("Connection closed"));
519                }
520        }
521
522        private void resetConnection() {
523                TcpConnection<?> conn = this.connection;
524                this.connection = null;
525                if (conn != null) {
526                        try {
527                                conn.close();
528                        }
529                        catch (Throwable ex) {
530                                // ignore
531                        }
532                }
533        }
534
535
536        private class ReceiptHandler implements Receiptable {
537
538                @Nullable
539                private final String receiptId;
540
541                private final List<Runnable> receiptCallbacks = new ArrayList<>(2);
542
543                private final List<Runnable> receiptLostCallbacks = new ArrayList<>(2);
544
545                @Nullable
546                private ScheduledFuture<?> future;
547
548                @Nullable
549                private Boolean result;
550
551                public ReceiptHandler(@Nullable String receiptId) {
552                        this.receiptId = receiptId;
553                        if (receiptId != null) {
554                                initReceiptHandling();
555                        }
556                }
557
558                private void initReceiptHandling() {
559                        Assert.notNull(getTaskScheduler(), "To track receipts, a TaskScheduler must be configured");
560                        DefaultStompSession.this.receiptHandlers.put(this.receiptId, this);
561                        Date startTime = new Date(System.currentTimeMillis() + getReceiptTimeLimit());
562                        this.future = getTaskScheduler().schedule(this::handleReceiptNotReceived, startTime);
563                }
564
565                @Override
566                @Nullable
567                public String getReceiptId() {
568                        return this.receiptId;
569                }
570
571                @Override
572                public void addReceiptTask(Runnable task) {
573                        addTask(task, true);
574                }
575
576                @Override
577                public void addReceiptLostTask(Runnable task) {
578                        addTask(task, false);
579                }
580
581                private void addTask(Runnable task, boolean successTask) {
582                        Assert.notNull(this.receiptId,
583                                        "To track receipts, set autoReceiptEnabled=true or add 'receiptId' header");
584                        synchronized (this) {
585                                if (this.result != null && this.result == successTask) {
586                                        invoke(Collections.singletonList(task));
587                                }
588                                else {
589                                        if (successTask) {
590                                                this.receiptCallbacks.add(task);
591                                        }
592                                        else {
593                                                this.receiptLostCallbacks.add(task);
594                                        }
595                                }
596                        }
597                }
598
599                private void invoke(List<Runnable> callbacks) {
600                        for (Runnable runnable : callbacks) {
601                                try {
602                                        runnable.run();
603                                }
604                                catch (Throwable ex) {
605                                        // ignore
606                                }
607                        }
608                }
609
610                public void handleReceiptReceived() {
611                        handleInternal(true);
612                }
613
614                public void handleReceiptNotReceived() {
615                        handleInternal(false);
616                }
617
618                private void handleInternal(boolean result) {
619                        synchronized (this) {
620                                if (this.result != null) {
621                                        return;
622                                }
623                                this.result = result;
624                                invoke(result ? this.receiptCallbacks : this.receiptLostCallbacks);
625                                DefaultStompSession.this.receiptHandlers.remove(this.receiptId);
626                                if (this.future != null) {
627                                        this.future.cancel(true);
628                                }
629                        }
630                }
631        }
632
633
634        private class DefaultSubscription extends ReceiptHandler implements Subscription {
635
636                private final StompHeaders headers;
637
638                private final StompFrameHandler handler;
639
640                public DefaultSubscription(StompHeaders headers, StompFrameHandler handler) {
641                        super(headers.getReceipt());
642                        Assert.notNull(headers.getDestination(), "Destination must not be null");
643                        Assert.notNull(handler, "StompFrameHandler must not be null");
644                        this.headers = headers;
645                        this.handler = handler;
646                        DefaultStompSession.this.subscriptions.put(headers.getId(), this);
647                }
648
649                @Override
650                @Nullable
651                public String getSubscriptionId() {
652                        return this.headers.getId();
653                }
654
655                @Override
656                public StompHeaders getSubscriptionHeaders() {
657                        return this.headers;
658                }
659
660                public StompFrameHandler getHandler() {
661                        return this.handler;
662                }
663
664                @Override
665                public void unsubscribe() {
666                        unsubscribe(null);
667                }
668
669                @Override
670                public void unsubscribe(@Nullable StompHeaders headers) {
671                        String id = this.headers.getId();
672                        if (id != null) {
673                                DefaultStompSession.this.subscriptions.remove(id);
674                                DefaultStompSession.this.unsubscribe(id, headers);
675                        }
676                }
677
678                @Override
679                public String toString() {
680                        return "Subscription [id=" + getSubscriptionId() +
681                                        ", destination='" + this.headers.getDestination() +
682                                        "', receiptId='" + getReceiptId() + "', handler=" + getHandler() + "]";
683                }
684        }
685
686
687        private class WriteInactivityTask implements Runnable {
688
689                @Override
690                public void run() {
691                        TcpConnection<byte[]> conn = connection;
692                        if (conn != null) {
693                                conn.send(HEARTBEAT).addCallback(
694                                                new ListenableFutureCallback<Void>() {
695                                                        @Override
696                                                        public void onSuccess(@Nullable Void result) {
697                                                        }
698                                                        @Override
699                                                        public void onFailure(Throwable ex) {
700                                                                handleFailure(ex);
701                                                        }
702                                                });
703                        }
704                }
705        }
706
707
708        private class ReadInactivityTask implements Runnable {
709
710                @Override
711                public void run() {
712                        closing = true;
713                        String error = "Server has gone quiet. Closing connection in session id=" + sessionId + ".";
714                        if (logger.isDebugEnabled()) {
715                                logger.debug(error);
716                        }
717                        resetConnection();
718                        handleFailure(new IllegalStateException(error));
719                }
720        }
721
722}