001/*
002 * Copyright 2002-2018 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.web.socket.messaging;
018
019import java.io.IOException;
020import java.net.URI;
021import java.nio.ByteBuffer;
022import java.util.ArrayList;
023import java.util.Collections;
024import java.util.List;
025import java.util.concurrent.ScheduledFuture;
026
027import org.apache.commons.logging.Log;
028import org.apache.commons.logging.LogFactory;
029
030import org.springframework.context.Lifecycle;
031import org.springframework.context.SmartLifecycle;
032import org.springframework.messaging.Message;
033import org.springframework.messaging.simp.stomp.BufferingStompDecoder;
034import org.springframework.messaging.simp.stomp.ConnectionHandlingStompSession;
035import org.springframework.messaging.simp.stomp.StompClientSupport;
036import org.springframework.messaging.simp.stomp.StompDecoder;
037import org.springframework.messaging.simp.stomp.StompEncoder;
038import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
039import org.springframework.messaging.simp.stomp.StompHeaders;
040import org.springframework.messaging.simp.stomp.StompSession;
041import org.springframework.messaging.simp.stomp.StompSessionHandler;
042import org.springframework.messaging.support.MessageHeaderAccessor;
043import org.springframework.messaging.tcp.TcpConnection;
044import org.springframework.messaging.tcp.TcpConnectionHandler;
045import org.springframework.scheduling.TaskScheduler;
046import org.springframework.util.Assert;
047import org.springframework.util.MimeTypeUtils;
048import org.springframework.util.concurrent.ListenableFuture;
049import org.springframework.util.concurrent.ListenableFutureCallback;
050import org.springframework.util.concurrent.SettableListenableFuture;
051import org.springframework.web.socket.BinaryMessage;
052import org.springframework.web.socket.CloseStatus;
053import org.springframework.web.socket.TextMessage;
054import org.springframework.web.socket.WebSocketHandler;
055import org.springframework.web.socket.WebSocketHttpHeaders;
056import org.springframework.web.socket.WebSocketMessage;
057import org.springframework.web.socket.WebSocketSession;
058import org.springframework.web.socket.client.WebSocketClient;
059import org.springframework.web.socket.sockjs.transport.SockJsSession;
060import org.springframework.web.util.UriComponentsBuilder;
061
062/**
063 * A STOMP over WebSocket client that connects using an implementation of
064 * {@link org.springframework.web.socket.client.WebSocketClient WebSocketClient}
065 * including {@link org.springframework.web.socket.sockjs.client.SockJsClient
066 * SockJsClient}.
067 *
068 * @author Rossen Stoyanchev
069 * @since 4.2
070 */
071public class WebSocketStompClient extends StompClientSupport implements SmartLifecycle {
072
073        private static final Log logger = LogFactory.getLog(WebSocketStompClient.class);
074
075        private final WebSocketClient webSocketClient;
076
077        private int inboundMessageSizeLimit = 64 * 1024;
078
079        private boolean autoStartup = true;
080
081        private int phase = Integer.MAX_VALUE;
082
083        private volatile boolean running = false;
084
085
086        /**
087         * Class constructor. Sets {@link #setDefaultHeartbeat} to "0,0" but will
088         * reset it back to the preferred "10000,10000" when a
089         * {@link #setTaskScheduler} is configured.
090         * @param webSocketClient the WebSocket client to connect with
091         */
092        public WebSocketStompClient(WebSocketClient webSocketClient) {
093                Assert.notNull(webSocketClient, "WebSocketClient is required");
094                this.webSocketClient = webSocketClient;
095                setDefaultHeartbeat(new long[] {0, 0});
096        }
097
098
099        /**
100         * Return the configured WebSocketClient.
101         */
102        public WebSocketClient getWebSocketClient() {
103                return this.webSocketClient;
104        }
105
106        /**
107         * {@inheritDoc}
108         * <p>Also automatically sets the {@link #setDefaultHeartbeat defaultHeartbeat}
109         * property to "10000,10000" if it is currently set to "0,0".
110         */
111        @Override
112        public void setTaskScheduler(TaskScheduler taskScheduler) {
113                if (taskScheduler != null && !isDefaultHeartbeatEnabled()) {
114                        setDefaultHeartbeat(new long[] {10000, 10000});
115                }
116                super.setTaskScheduler(taskScheduler);
117        }
118
119        /**
120         * Configure the maximum size allowed for inbound STOMP message.
121         * Since a STOMP message can be received in multiple WebSocket messages,
122         * buffering may be required and this property determines the maximum buffer
123         * size per message.
124         * <p>By default this is set to 64 * 1024 (64K).
125         */
126        public void setInboundMessageSizeLimit(int inboundMessageSizeLimit) {
127                this.inboundMessageSizeLimit = inboundMessageSizeLimit;
128        }
129
130        /**
131         * Get the configured inbound message buffer size in bytes.
132         */
133        public int getInboundMessageSizeLimit() {
134                return this.inboundMessageSizeLimit;
135        }
136
137        /**
138         * Set whether to auto-start the contained WebSocketClient when the Spring
139         * context has been refreshed.
140         * <p>Default is "true".
141         */
142        public void setAutoStartup(boolean autoStartup) {
143                this.autoStartup = autoStartup;
144        }
145
146        /**
147         * Return the value for the 'autoStartup' property. If "true", this client
148         * will automatically start and stop the contained WebSocketClient.
149         */
150        @Override
151        public boolean isAutoStartup() {
152                return this.autoStartup;
153        }
154
155        /**
156         * Specify the phase in which the WebSocket client should be started and
157         * subsequently closed. The startup order proceeds from lowest to highest,
158         * and the shutdown order is the reverse of that.
159         * <p>By default this is Integer.MAX_VALUE meaning that the WebSocket client
160         * is started as late as possible and stopped as soon as possible.
161         */
162        public void setPhase(int phase) {
163                this.phase = phase;
164        }
165
166        /**
167         * Return the configured phase.
168         */
169        @Override
170        public int getPhase() {
171                return this.phase;
172        }
173
174
175        @Override
176        public void start() {
177                if (!isRunning()) {
178                        this.running = true;
179                        if (getWebSocketClient() instanceof Lifecycle) {
180                                ((Lifecycle) getWebSocketClient()).start();
181                        }
182                }
183
184        }
185
186        @Override
187        public void stop() {
188                if (isRunning()) {
189                        this.running = false;
190                        if (getWebSocketClient() instanceof Lifecycle) {
191                                ((Lifecycle) getWebSocketClient()).stop();
192                        }
193                }
194        }
195
196        @Override
197        public void stop(Runnable callback) {
198                stop();
199                callback.run();
200        }
201
202        @Override
203        public boolean isRunning() {
204                return this.running;
205        }
206
207
208        /**
209         * Connect to the given WebSocket URL and notify the given
210         * {@link org.springframework.messaging.simp.stomp.StompSessionHandler}
211         * when connected on the STOMP level after the CONNECTED frame is received.
212         * @param url the url to connect to
213         * @param handler the session handler
214         * @param uriVars the URI variables to expand into the URL
215         * @return a ListenableFuture for access to the session when ready for use
216         */
217        public ListenableFuture<StompSession> connect(String url, StompSessionHandler handler, Object... uriVars) {
218                return connect(url, null, handler, uriVars);
219        }
220
221        /**
222         * An overloaded version of
223         * {@link #connect(String, StompSessionHandler, Object...)} that also
224         * accepts {@link WebSocketHttpHeaders} to use for the WebSocket handshake.
225         * @param url the url to connect to
226         * @param handshakeHeaders the headers for the WebSocket handshake
227         * @param handler the session handler
228         * @param uriVariables the URI variables to expand into the URL
229         * @return a ListenableFuture for access to the session when ready for use
230         */
231        public ListenableFuture<StompSession> connect(String url, WebSocketHttpHeaders handshakeHeaders,
232                        StompSessionHandler handler, Object... uriVariables) {
233
234                return connect(url, handshakeHeaders, null, handler, uriVariables);
235        }
236
237        /**
238         * An overloaded version of
239         * {@link #connect(String, StompSessionHandler, Object...)} that also accepts
240         * {@link WebSocketHttpHeaders} to use for the WebSocket handshake and
241         * {@link StompHeaders} for the STOMP CONNECT frame.
242         * @param url the url to connect to
243         * @param handshakeHeaders headers for the WebSocket handshake
244         * @param connectHeaders headers for the STOMP CONNECT frame
245         * @param handler the session handler
246         * @param uriVariables the URI variables to expand into the URL
247         * @return a ListenableFuture for access to the session when ready for use
248         */
249        public ListenableFuture<StompSession> connect(String url, WebSocketHttpHeaders handshakeHeaders,
250                        StompHeaders connectHeaders, StompSessionHandler handler, Object... uriVariables) {
251
252                Assert.notNull(url, "'url' must not be null");
253                URI uri = UriComponentsBuilder.fromUriString(url).buildAndExpand(uriVariables).encode().toUri();
254                return connect(uri, handshakeHeaders, connectHeaders, handler);
255        }
256
257        /**
258         * An overloaded version of
259         * {@link #connect(String, WebSocketHttpHeaders, StompSessionHandler, Object...)}
260         * that accepts a fully prepared {@link java.net.URI}.
261         * @param url the url to connect to
262         * @param handshakeHeaders the headers for the WebSocket handshake
263         * @param connectHeaders headers for the STOMP CONNECT frame
264         * @param sessionHandler the STOMP session handler
265         * @return a ListenableFuture for access to the session when ready for use
266         */
267        public ListenableFuture<StompSession> connect(URI url, WebSocketHttpHeaders handshakeHeaders,
268                        StompHeaders connectHeaders, StompSessionHandler sessionHandler) {
269
270                Assert.notNull(url, "'url' must not be null");
271                ConnectionHandlingStompSession session = createSession(connectHeaders, sessionHandler);
272                WebSocketTcpConnectionHandlerAdapter adapter = new WebSocketTcpConnectionHandlerAdapter(session);
273                getWebSocketClient().doHandshake(adapter, handshakeHeaders, url).addCallback(adapter);
274                return session.getSessionFuture();
275        }
276
277        @Override
278        protected StompHeaders processConnectHeaders(StompHeaders connectHeaders) {
279                connectHeaders = super.processConnectHeaders(connectHeaders);
280                if (connectHeaders.isHeartbeatEnabled()) {
281                        Assert.state(getTaskScheduler() != null, "TaskScheduler must be set if heartbeats are enabled");
282                }
283                return connectHeaders;
284        }
285
286
287        /**
288         * Adapt WebSocket to the TcpConnectionHandler and TcpConnection contracts.
289         */
290        private class WebSocketTcpConnectionHandlerAdapter implements ListenableFutureCallback<WebSocketSession>,
291                        WebSocketHandler, TcpConnection<byte[]> {
292
293                private final TcpConnectionHandler<byte[]> connectionHandler;
294
295                private final StompWebSocketMessageCodec codec = new StompWebSocketMessageCodec(getInboundMessageSizeLimit());
296
297                private volatile WebSocketSession session;
298
299                private volatile long lastReadTime = -1;
300
301                private volatile long lastWriteTime = -1;
302
303                private final List<ScheduledFuture<?>> inactivityTasks = new ArrayList<ScheduledFuture<?>>(2);
304
305                public WebSocketTcpConnectionHandlerAdapter(TcpConnectionHandler<byte[]> connectionHandler) {
306                        Assert.notNull(connectionHandler, "TcpConnectionHandler must not be null");
307                        this.connectionHandler = connectionHandler;
308                }
309
310                // ListenableFutureCallback implementation: handshake outcome
311
312                @Override
313                public void onSuccess(WebSocketSession webSocketSession) {
314                }
315
316                @Override
317                public void onFailure(Throwable ex) {
318                        this.connectionHandler.afterConnectFailure(ex);
319                }
320
321                // WebSocketHandler implementation
322
323                @Override
324                public void afterConnectionEstablished(WebSocketSession session) {
325                        this.session = session;
326                        this.connectionHandler.afterConnected(this);
327                }
328
329                @Override
330                public void handleMessage(WebSocketSession session, WebSocketMessage<?> webSocketMessage) {
331                        this.lastReadTime = (this.lastReadTime != -1 ? System.currentTimeMillis() : -1);
332                        List<Message<byte[]>> messages;
333                        try {
334                                messages = this.codec.decode(webSocketMessage);
335                        }
336                        catch (Throwable ex) {
337                                this.connectionHandler.handleFailure(ex);
338                                return;
339                        }
340                        for (Message<byte[]> message : messages) {
341                                this.connectionHandler.handleMessage(message);
342                        }
343                }
344
345                @Override
346                public void handleTransportError(WebSocketSession session, Throwable ex) throws Exception {
347                        this.connectionHandler.handleFailure(ex);
348                }
349
350                @Override
351                public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception {
352                        cancelInactivityTasks();
353                        this.connectionHandler.afterConnectionClosed();
354                }
355
356                private void cancelInactivityTasks() {
357                        for (ScheduledFuture<?> task : this.inactivityTasks) {
358                                try {
359                                        task.cancel(true);
360                                }
361                                catch (Throwable ex) {
362                                        // Ignore
363                                }
364                        }
365                        this.lastReadTime = -1;
366                        this.lastWriteTime = -1;
367                        this.inactivityTasks.clear();
368                }
369
370                @Override
371                public boolean supportsPartialMessages() {
372                        return false;
373                }
374
375                // TcpConnection implementation
376
377                @Override
378                public ListenableFuture<Void> send(Message<byte[]> message) {
379                        updateLastWriteTime();
380                        SettableListenableFuture<Void> future = new SettableListenableFuture<Void>();
381                        try {
382                                this.session.sendMessage(this.codec.encode(message, this.session.getClass()));
383                                future.set(null);
384                        }
385                        catch (Throwable ex) {
386                                future.setException(ex);
387                        }
388                        finally {
389                                updateLastWriteTime();
390                        }
391                        return future;
392                }
393
394                private void updateLastWriteTime() {
395                        long lastWriteTime = this.lastWriteTime;
396                        if (lastWriteTime != -1) {
397                                this.lastWriteTime = System.currentTimeMillis();
398                        }
399                }
400
401                @Override
402                public void onReadInactivity(final Runnable runnable, final long duration) {
403                        Assert.state(getTaskScheduler() != null, "No TaskScheduler configured");
404                        this.lastReadTime = System.currentTimeMillis();
405                        this.inactivityTasks.add(getTaskScheduler().scheduleWithFixedDelay(new Runnable() {
406                                @Override
407                                public void run() {
408                                        if (System.currentTimeMillis() - lastReadTime > duration) {
409                                                try {
410                                                        runnable.run();
411                                                }
412                                                catch (Throwable ex) {
413                                                        if (logger.isDebugEnabled()) {
414                                                                logger.debug("ReadInactivityTask failure", ex);
415                                                        }
416                                                }
417                                        }
418                                }
419                        }, duration / 2));
420                }
421
422                @Override
423                public void onWriteInactivity(final Runnable runnable, final long duration) {
424                        Assert.state(getTaskScheduler() != null, "No TaskScheduler configured");
425                        this.lastWriteTime = System.currentTimeMillis();
426                        this.inactivityTasks.add(getTaskScheduler().scheduleWithFixedDelay(new Runnable() {
427                                @Override
428                                public void run() {
429                                        if (System.currentTimeMillis() - lastWriteTime > duration) {
430                                                try {
431                                                        runnable.run();
432                                                }
433                                                catch (Throwable ex) {
434                                                        if (logger.isDebugEnabled()) {
435                                                                logger.debug("WriteInactivityTask failure", ex);
436                                                        }
437                                                }
438                                        }
439                                }
440                        }, duration / 2));
441                }
442
443                @Override
444                public void close() {
445                        try {
446                                this.session.close();
447                        }
448                        catch (IOException ex) {
449                                if (logger.isDebugEnabled()) {
450                                        logger.debug("Failed to close session: " + this.session.getId(), ex);
451                                }
452                        }
453                }
454        }
455
456
457        /**
458         * Encode and decode STOMP WebSocket messages.
459         */
460        private static class StompWebSocketMessageCodec {
461
462                private static final StompEncoder ENCODER = new StompEncoder();
463
464                private static final StompDecoder DECODER = new StompDecoder();
465
466                private final BufferingStompDecoder bufferingDecoder;
467
468                public StompWebSocketMessageCodec(int messageSizeLimit) {
469                        this.bufferingDecoder = new BufferingStompDecoder(DECODER, messageSizeLimit);
470                }
471
472                public List<Message<byte[]>> decode(WebSocketMessage<?> webSocketMessage) {
473                        List<Message<byte[]>> result = Collections.<Message<byte[]>>emptyList();
474                        ByteBuffer byteBuffer;
475                        if (webSocketMessage instanceof TextMessage) {
476                                byteBuffer = ByteBuffer.wrap(((TextMessage) webSocketMessage).asBytes());
477                        }
478                        else if (webSocketMessage instanceof BinaryMessage) {
479                                byteBuffer = ((BinaryMessage) webSocketMessage).getPayload();
480                        }
481                        else {
482                                return result;
483                        }
484                        result = this.bufferingDecoder.decode(byteBuffer);
485                        if (result.isEmpty()) {
486                                if (logger.isTraceEnabled()) {
487                                        logger.trace("Incomplete STOMP frame content received, bufferSize=" +
488                                                        this.bufferingDecoder.getBufferSize() + ", bufferSizeLimit=" +
489                                                        this.bufferingDecoder.getBufferSizeLimit() + ".");
490                                }
491                        }
492                        return result;
493                }
494
495                public WebSocketMessage<?> encode(Message<byte[]> message, Class<? extends WebSocketSession> sessionType) {
496                        StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
497                        Assert.notNull(accessor, "No StompHeaderAccessor available");
498                        byte[] payload = message.getPayload();
499                        byte[] bytes = ENCODER.encode(accessor.getMessageHeaders(), payload);
500
501                        boolean useBinary = (payload.length > 0  &&
502                                        !(SockJsSession.class.isAssignableFrom(sessionType)) &&
503                                        MimeTypeUtils.APPLICATION_OCTET_STREAM.isCompatibleWith(accessor.getContentType()));
504
505                        return (useBinary ? new BinaryMessage(bytes) : new TextMessage(bytes));
506                }
507        }
508
509}