001/*
002 * Copyright 2002-2019 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.web.socket.messaging;
018
019import java.io.IOException;
020import java.net.URI;
021import java.nio.ByteBuffer;
022import java.util.ArrayList;
023import java.util.Collections;
024import java.util.List;
025import java.util.concurrent.ScheduledFuture;
026
027import org.apache.commons.logging.Log;
028import org.apache.commons.logging.LogFactory;
029
030import org.springframework.context.Lifecycle;
031import org.springframework.context.SmartLifecycle;
032import org.springframework.lang.Nullable;
033import org.springframework.messaging.Message;
034import org.springframework.messaging.simp.stomp.BufferingStompDecoder;
035import org.springframework.messaging.simp.stomp.ConnectionHandlingStompSession;
036import org.springframework.messaging.simp.stomp.StompClientSupport;
037import org.springframework.messaging.simp.stomp.StompDecoder;
038import org.springframework.messaging.simp.stomp.StompEncoder;
039import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
040import org.springframework.messaging.simp.stomp.StompHeaders;
041import org.springframework.messaging.simp.stomp.StompSession;
042import org.springframework.messaging.simp.stomp.StompSessionHandler;
043import org.springframework.messaging.support.MessageHeaderAccessor;
044import org.springframework.messaging.tcp.TcpConnection;
045import org.springframework.messaging.tcp.TcpConnectionHandler;
046import org.springframework.scheduling.TaskScheduler;
047import org.springframework.util.Assert;
048import org.springframework.util.MimeTypeUtils;
049import org.springframework.util.concurrent.ListenableFuture;
050import org.springframework.util.concurrent.ListenableFutureCallback;
051import org.springframework.util.concurrent.SettableListenableFuture;
052import org.springframework.web.socket.BinaryMessage;
053import org.springframework.web.socket.CloseStatus;
054import org.springframework.web.socket.TextMessage;
055import org.springframework.web.socket.WebSocketHandler;
056import org.springframework.web.socket.WebSocketHttpHeaders;
057import org.springframework.web.socket.WebSocketMessage;
058import org.springframework.web.socket.WebSocketSession;
059import org.springframework.web.socket.client.WebSocketClient;
060import org.springframework.web.socket.handler.LoggingWebSocketHandlerDecorator;
061import org.springframework.web.socket.sockjs.transport.SockJsSession;
062import org.springframework.web.util.UriComponentsBuilder;
063
064/**
065 * A STOMP over WebSocket client that connects using an implementation of
066 * {@link org.springframework.web.socket.client.WebSocketClient WebSocketClient}
067 * including {@link org.springframework.web.socket.sockjs.client.SockJsClient
068 * SockJsClient}.
069 *
070 * @author Rossen Stoyanchev
071 * @since 4.2
072 */
073public class WebSocketStompClient extends StompClientSupport implements SmartLifecycle {
074
075        private static final Log logger = LogFactory.getLog(WebSocketStompClient.class);
076
077        private final WebSocketClient webSocketClient;
078
079        private int inboundMessageSizeLimit = 64 * 1024;
080
081        private boolean autoStartup = true;
082
083        private int phase = DEFAULT_PHASE;
084
085        private volatile boolean running = false;
086
087
088        /**
089         * Class constructor. Sets {@link #setDefaultHeartbeat} to "0,0" but will
090         * reset it back to the preferred "10000,10000" when a
091         * {@link #setTaskScheduler} is configured.
092         * @param webSocketClient the WebSocket client to connect with
093         */
094        public WebSocketStompClient(WebSocketClient webSocketClient) {
095                Assert.notNull(webSocketClient, "WebSocketClient is required");
096                this.webSocketClient = webSocketClient;
097                setDefaultHeartbeat(new long[] {0, 0});
098        }
099
100
101        /**
102         * Return the configured WebSocketClient.
103         */
104        public WebSocketClient getWebSocketClient() {
105                return this.webSocketClient;
106        }
107
108        /**
109         * {@inheritDoc}
110         * <p>Also automatically sets the {@link #setDefaultHeartbeat defaultHeartbeat}
111         * property to "10000,10000" if it is currently set to "0,0".
112         */
113        @Override
114        public void setTaskScheduler(@Nullable TaskScheduler taskScheduler) {
115                if (!isDefaultHeartbeatEnabled()) {
116                        setDefaultHeartbeat(new long[] {10000, 10000});
117                }
118                super.setTaskScheduler(taskScheduler);
119        }
120
121        /**
122         * Configure the maximum size allowed for inbound STOMP message.
123         * Since a STOMP message can be received in multiple WebSocket messages,
124         * buffering may be required and this property determines the maximum buffer
125         * size per message.
126         * <p>By default this is set to 64 * 1024 (64K).
127         */
128        public void setInboundMessageSizeLimit(int inboundMessageSizeLimit) {
129                this.inboundMessageSizeLimit = inboundMessageSizeLimit;
130        }
131
132        /**
133         * Get the configured inbound message buffer size in bytes.
134         */
135        public int getInboundMessageSizeLimit() {
136                return this.inboundMessageSizeLimit;
137        }
138
139        /**
140         * Set whether to auto-start the contained WebSocketClient when the Spring
141         * context has been refreshed.
142         * <p>Default is "true".
143         */
144        public void setAutoStartup(boolean autoStartup) {
145                this.autoStartup = autoStartup;
146        }
147
148        /**
149         * Return the value for the 'autoStartup' property. If "true", this client
150         * will automatically start and stop the contained WebSocketClient.
151         */
152        @Override
153        public boolean isAutoStartup() {
154                return this.autoStartup;
155        }
156
157        /**
158         * Specify the phase in which the WebSocket client should be started and
159         * subsequently closed. The startup order proceeds from lowest to highest,
160         * and the shutdown order is the reverse of that.
161         * <p>By default this is Integer.MAX_VALUE meaning that the WebSocket client
162         * is started as late as possible and stopped as soon as possible.
163         */
164        public void setPhase(int phase) {
165                this.phase = phase;
166        }
167
168        /**
169         * Return the configured phase.
170         */
171        @Override
172        public int getPhase() {
173                return this.phase;
174        }
175
176
177        @Override
178        public void start() {
179                if (!isRunning()) {
180                        this.running = true;
181                        if (getWebSocketClient() instanceof Lifecycle) {
182                                ((Lifecycle) getWebSocketClient()).start();
183                        }
184                }
185
186        }
187
188        @Override
189        public void stop() {
190                if (isRunning()) {
191                        this.running = false;
192                        if (getWebSocketClient() instanceof Lifecycle) {
193                                ((Lifecycle) getWebSocketClient()).stop();
194                        }
195                }
196        }
197
198        @Override
199        public boolean isRunning() {
200                return this.running;
201        }
202
203
204        /**
205         * Connect to the given WebSocket URL and notify the given
206         * {@link org.springframework.messaging.simp.stomp.StompSessionHandler}
207         * when connected on the STOMP level after the CONNECTED frame is received.
208         * @param url the url to connect to
209         * @param handler the session handler
210         * @param uriVars the URI variables to expand into the URL
211         * @return a ListenableFuture for access to the session when ready for use
212         */
213        public ListenableFuture<StompSession> connect(String url, StompSessionHandler handler, Object... uriVars) {
214                return connect(url, null, handler, uriVars);
215        }
216
217        /**
218         * An overloaded version of
219         * {@link #connect(String, StompSessionHandler, Object...)} that also
220         * accepts {@link WebSocketHttpHeaders} to use for the WebSocket handshake.
221         * @param url the url to connect to
222         * @param handshakeHeaders the headers for the WebSocket handshake
223         * @param handler the session handler
224         * @param uriVariables the URI variables to expand into the URL
225         * @return a ListenableFuture for access to the session when ready for use
226         */
227        public ListenableFuture<StompSession> connect(String url, @Nullable WebSocketHttpHeaders handshakeHeaders,
228                        StompSessionHandler handler, Object... uriVariables) {
229
230                return connect(url, handshakeHeaders, null, handler, uriVariables);
231        }
232
233        /**
234         * An overloaded version of
235         * {@link #connect(String, StompSessionHandler, Object...)} that also accepts
236         * {@link WebSocketHttpHeaders} to use for the WebSocket handshake and
237         * {@link StompHeaders} for the STOMP CONNECT frame.
238         * @param url the url to connect to
239         * @param handshakeHeaders headers for the WebSocket handshake
240         * @param connectHeaders headers for the STOMP CONNECT frame
241         * @param handler the session handler
242         * @param uriVariables the URI variables to expand into the URL
243         * @return a ListenableFuture for access to the session when ready for use
244         */
245        public ListenableFuture<StompSession> connect(String url, @Nullable WebSocketHttpHeaders handshakeHeaders,
246                        @Nullable StompHeaders connectHeaders, StompSessionHandler handler, Object... uriVariables) {
247
248                Assert.notNull(url, "'url' must not be null");
249                URI uri = UriComponentsBuilder.fromUriString(url).buildAndExpand(uriVariables).encode().toUri();
250                return connect(uri, handshakeHeaders, connectHeaders, handler);
251        }
252
253        /**
254         * An overloaded version of
255         * {@link #connect(String, WebSocketHttpHeaders, StompSessionHandler, Object...)}
256         * that accepts a fully prepared {@link java.net.URI}.
257         * @param url the url to connect to
258         * @param handshakeHeaders the headers for the WebSocket handshake
259         * @param connectHeaders headers for the STOMP CONNECT frame
260         * @param sessionHandler the STOMP session handler
261         * @return a ListenableFuture for access to the session when ready for use
262         */
263        public ListenableFuture<StompSession> connect(URI url, @Nullable WebSocketHttpHeaders handshakeHeaders,
264                        @Nullable StompHeaders connectHeaders, StompSessionHandler sessionHandler) {
265
266                Assert.notNull(url, "'url' must not be null");
267                ConnectionHandlingStompSession session = createSession(connectHeaders, sessionHandler);
268                WebSocketTcpConnectionHandlerAdapter adapter = new WebSocketTcpConnectionHandlerAdapter(session);
269                getWebSocketClient()
270                                .doHandshake(new LoggingWebSocketHandlerDecorator(adapter), handshakeHeaders, url)
271                                .addCallback(adapter);
272                return session.getSessionFuture();
273        }
274
275        @Override
276        protected StompHeaders processConnectHeaders(@Nullable StompHeaders connectHeaders) {
277                connectHeaders = super.processConnectHeaders(connectHeaders);
278                if (connectHeaders.isHeartbeatEnabled()) {
279                        Assert.state(getTaskScheduler() != null, "TaskScheduler must be set if heartbeats are enabled");
280                }
281                return connectHeaders;
282        }
283
284
285        /**
286         * Adapt WebSocket to the TcpConnectionHandler and TcpConnection contracts.
287         */
288        private class WebSocketTcpConnectionHandlerAdapter implements ListenableFutureCallback<WebSocketSession>,
289                        WebSocketHandler, TcpConnection<byte[]> {
290
291                private final TcpConnectionHandler<byte[]> connectionHandler;
292
293                private final StompWebSocketMessageCodec codec = new StompWebSocketMessageCodec(getInboundMessageSizeLimit());
294
295                @Nullable
296                private volatile WebSocketSession session;
297
298                private volatile long lastReadTime = -1;
299
300                private volatile long lastWriteTime = -1;
301
302                private final List<ScheduledFuture<?>> inactivityTasks = new ArrayList<>(2);
303
304                public WebSocketTcpConnectionHandlerAdapter(TcpConnectionHandler<byte[]> connectionHandler) {
305                        Assert.notNull(connectionHandler, "TcpConnectionHandler must not be null");
306                        this.connectionHandler = connectionHandler;
307                }
308
309                // ListenableFutureCallback implementation: handshake outcome
310
311                @Override
312                public void onSuccess(@Nullable WebSocketSession webSocketSession) {
313                }
314
315                @Override
316                public void onFailure(Throwable ex) {
317                        this.connectionHandler.afterConnectFailure(ex);
318                }
319
320                // WebSocketHandler implementation
321
322                @Override
323                public void afterConnectionEstablished(WebSocketSession session) {
324                        this.session = session;
325                        this.connectionHandler.afterConnected(this);
326                }
327
328                @Override
329                public void handleMessage(WebSocketSession session, WebSocketMessage<?> webSocketMessage) {
330                        this.lastReadTime = (this.lastReadTime != -1 ? System.currentTimeMillis() : -1);
331                        List<Message<byte[]>> messages;
332                        try {
333                                messages = this.codec.decode(webSocketMessage);
334                        }
335                        catch (Throwable ex) {
336                                this.connectionHandler.handleFailure(ex);
337                                return;
338                        }
339                        for (Message<byte[]> message : messages) {
340                                this.connectionHandler.handleMessage(message);
341                        }
342                }
343
344                @Override
345                public void handleTransportError(WebSocketSession session, Throwable ex) throws Exception {
346                        this.connectionHandler.handleFailure(ex);
347                }
348
349                @Override
350                public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception {
351                        cancelInactivityTasks();
352                        this.connectionHandler.afterConnectionClosed();
353                }
354
355                private void cancelInactivityTasks() {
356                        for (ScheduledFuture<?> task : this.inactivityTasks) {
357                                try {
358                                        task.cancel(true);
359                                }
360                                catch (Throwable ex) {
361                                        // Ignore
362                                }
363                        }
364                        this.lastReadTime = -1;
365                        this.lastWriteTime = -1;
366                        this.inactivityTasks.clear();
367                }
368
369                @Override
370                public boolean supportsPartialMessages() {
371                        return false;
372                }
373
374                // TcpConnection implementation
375
376                @Override
377                public ListenableFuture<Void> send(Message<byte[]> message) {
378                        updateLastWriteTime();
379                        SettableListenableFuture<Void> future = new SettableListenableFuture<>();
380                        try {
381                                WebSocketSession session = this.session;
382                                Assert.state(session != null, "No WebSocketSession available");
383                                session.sendMessage(this.codec.encode(message, session.getClass()));
384                                future.set(null);
385                        }
386                        catch (Throwable ex) {
387                                future.setException(ex);
388                        }
389                        finally {
390                                updateLastWriteTime();
391                        }
392                        return future;
393                }
394
395                private void updateLastWriteTime() {
396                        long lastWriteTime = this.lastWriteTime;
397                        if (lastWriteTime != -1) {
398                                this.lastWriteTime = System.currentTimeMillis();
399                        }
400                }
401
402                @Override
403                public void onReadInactivity(final Runnable runnable, final long duration) {
404                        Assert.state(getTaskScheduler() != null, "No TaskScheduler configured");
405                        this.lastReadTime = System.currentTimeMillis();
406                        this.inactivityTasks.add(getTaskScheduler().scheduleWithFixedDelay(() -> {
407                                if (System.currentTimeMillis() - this.lastReadTime > duration) {
408                                        try {
409                                                runnable.run();
410                                        }
411                                        catch (Throwable ex) {
412                                                if (logger.isDebugEnabled()) {
413                                                        logger.debug("ReadInactivityTask failure", ex);
414                                                }
415                                        }
416                                }
417                        }, duration / 2));
418                }
419
420                @Override
421                public void onWriteInactivity(final Runnable runnable, final long duration) {
422                        Assert.state(getTaskScheduler() != null, "No TaskScheduler configured");
423                        this.lastWriteTime = System.currentTimeMillis();
424                        this.inactivityTasks.add(getTaskScheduler().scheduleWithFixedDelay(() -> {
425                                if (System.currentTimeMillis() - this.lastWriteTime > duration) {
426                                        try {
427                                                runnable.run();
428                                        }
429                                        catch (Throwable ex) {
430                                                if (logger.isDebugEnabled()) {
431                                                        logger.debug("WriteInactivityTask failure", ex);
432                                                }
433                                        }
434                                }
435                        }, duration / 2));
436                }
437
438                @Override
439                public void close() {
440                        WebSocketSession session = this.session;
441                        if (session != null) {
442                                try {
443                                        session.close();
444                                }
445                                catch (IOException ex) {
446                                        if (logger.isDebugEnabled()) {
447                                                logger.debug("Failed to close session: " + session.getId(), ex);
448                                        }
449                                }
450                        }
451                }
452        }
453
454
455        /**
456         * Encode and decode STOMP WebSocket messages.
457         */
458        private static class StompWebSocketMessageCodec {
459
460                private static final StompEncoder ENCODER = new StompEncoder();
461
462                private static final StompDecoder DECODER = new StompDecoder();
463
464                private final BufferingStompDecoder bufferingDecoder;
465
466                public StompWebSocketMessageCodec(int messageSizeLimit) {
467                        this.bufferingDecoder = new BufferingStompDecoder(DECODER, messageSizeLimit);
468                }
469
470                public List<Message<byte[]>> decode(WebSocketMessage<?> webSocketMessage) {
471                        List<Message<byte[]>> result = Collections.emptyList();
472                        ByteBuffer byteBuffer;
473                        if (webSocketMessage instanceof TextMessage) {
474                                byteBuffer = ByteBuffer.wrap(((TextMessage) webSocketMessage).asBytes());
475                        }
476                        else if (webSocketMessage instanceof BinaryMessage) {
477                                byteBuffer = ((BinaryMessage) webSocketMessage).getPayload();
478                        }
479                        else {
480                                return result;
481                        }
482                        result = this.bufferingDecoder.decode(byteBuffer);
483                        if (result.isEmpty()) {
484                                if (logger.isTraceEnabled()) {
485                                        logger.trace("Incomplete STOMP frame content received, bufferSize=" +
486                                                        this.bufferingDecoder.getBufferSize() + ", bufferSizeLimit=" +
487                                                        this.bufferingDecoder.getBufferSizeLimit() + ".");
488                                }
489                        }
490                        return result;
491                }
492
493                public WebSocketMessage<?> encode(Message<byte[]> message, Class<? extends WebSocketSession> sessionType) {
494                        StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
495                        Assert.notNull(accessor, "No StompHeaderAccessor available");
496                        byte[] payload = message.getPayload();
497                        byte[] bytes = ENCODER.encode(accessor.getMessageHeaders(), payload);
498
499                        boolean useBinary = (payload.length > 0  &&
500                                        !(SockJsSession.class.isAssignableFrom(sessionType)) &&
501                                        MimeTypeUtils.APPLICATION_OCTET_STREAM.isCompatibleWith(accessor.getContentType()));
502
503                        return (useBinary ? new BinaryMessage(bytes) : new TextMessage(bytes));
504                }
505        }
506
507}