001/*
002 * Copyright 2002-2019 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.web.socket.sockjs.client;
018
import java.net.URI;
import java.security.Principal;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.regex.Pattern;
027
028import org.apache.commons.logging.Log;
029import org.apache.commons.logging.LogFactory;
030
031import org.springframework.context.Lifecycle;
032import org.springframework.http.HttpHeaders;
033import org.springframework.lang.Nullable;
034import org.springframework.scheduling.TaskScheduler;
035import org.springframework.util.Assert;
036import org.springframework.util.ClassUtils;
037import org.springframework.util.CollectionUtils;
038import org.springframework.util.concurrent.ListenableFuture;
039import org.springframework.util.concurrent.SettableListenableFuture;
040import org.springframework.web.socket.WebSocketHandler;
041import org.springframework.web.socket.WebSocketHttpHeaders;
042import org.springframework.web.socket.WebSocketSession;
043import org.springframework.web.socket.client.WebSocketClient;
044import org.springframework.web.socket.sockjs.frame.Jackson2SockJsMessageCodec;
045import org.springframework.web.socket.sockjs.frame.SockJsMessageCodec;
046import org.springframework.web.socket.sockjs.transport.TransportType;
047import org.springframework.web.util.UriComponentsBuilder;
048
049/**
050 * A SockJS implementation of
051 * {@link org.springframework.web.socket.client.WebSocketClient WebSocketClient}
052 * with fallback alternatives that simulate a WebSocket interaction through plain
053 * HTTP streaming and long polling techniques..
054 *
055 * <p>Implements {@link Lifecycle} in order to propagate lifecycle events to
056 * the transports it is configured with.
057 *
058 * @author Rossen Stoyanchev
059 * @since 4.1
060 * @see <a href="https://github.com/sockjs/sockjs-client">https://github.com/sockjs/sockjs-client</a>
061 * @see org.springframework.web.socket.sockjs.client.Transport
062 */
063public class SockJsClient implements WebSocketClient, Lifecycle {
064
065        private static final boolean jackson2Present = ClassUtils.isPresent(
066                        "com.fasterxml.jackson.databind.ObjectMapper", SockJsClient.class.getClassLoader());
067
068        private static final Log logger = LogFactory.getLog(SockJsClient.class);
069
070        private static final Set<String> supportedProtocols = new HashSet<>(4);
071
072        static {
073                supportedProtocols.add("ws");
074                supportedProtocols.add("wss");
075                supportedProtocols.add("http");
076                supportedProtocols.add("https");
077        }
078
079
080        private final List<Transport> transports;
081
082        @Nullable
083        private String[] httpHeaderNames;
084
085        private InfoReceiver infoReceiver;
086
087        @Nullable
088        private SockJsMessageCodec messageCodec;
089
090        @Nullable
091        private TaskScheduler connectTimeoutScheduler;
092
093        private volatile boolean running = false;
094
095        private final Map<URI, ServerInfo> serverInfoCache = new ConcurrentHashMap<>();
096
097
098        /**
099         * Create a {@code SockJsClient} with the given transports.
100         * <p>If the list includes an {@link XhrTransport} (or more specifically an
101         * implementation of {@link InfoReceiver}) the instance is used to initialize
102         * the {@link #setInfoReceiver(InfoReceiver) infoReceiver} property, or
103         * otherwise is defaulted to {@link RestTemplateXhrTransport}.
104         * @param transports the (non-empty) list of transports to use
105         */
106        public SockJsClient(List<Transport> transports) {
107                Assert.notEmpty(transports, "No transports provided");
108                this.transports = new ArrayList<>(transports);
109                this.infoReceiver = initInfoReceiver(transports);
110                if (jackson2Present) {
111                        this.messageCodec = new Jackson2SockJsMessageCodec();
112                }
113        }
114
115        private static InfoReceiver initInfoReceiver(List<Transport> transports) {
116                for (Transport transport : transports) {
117                        if (transport instanceof InfoReceiver) {
118                                return ((InfoReceiver) transport);
119                        }
120                }
121                return new RestTemplateXhrTransport();
122        }
123
124
125        /**
126         * The names of HTTP headers that should be copied from the handshake headers
127         * of each call to {@link SockJsClient#doHandshake(WebSocketHandler, WebSocketHttpHeaders, URI)}
128         * and also used with other HTTP requests issued as part of that SockJS
129         * connection, e.g. the initial info request, XHR send or receive requests.
130         * <p>By default if this property is not set, all handshake headers are also
131         * used for other HTTP requests. Set it if you want only a subset of handshake
132         * headers (e.g. auth headers) to be used for other HTTP requests.
133         * @param httpHeaderNames the HTTP header names
134         */
135        public void setHttpHeaderNames(@Nullable String... httpHeaderNames) {
136                this.httpHeaderNames = httpHeaderNames;
137        }
138
139        /**
140         * The configured HTTP header names to be copied from the handshake
141         * headers and also included in other HTTP requests.
142         */
143        @Nullable
144        public String[] getHttpHeaderNames() {
145                return this.httpHeaderNames;
146        }
147
148        /**
149         * Configure the {@code InfoReceiver} to use to perform the SockJS "Info"
150         * request before the SockJS session starts.
151         * <p>If the list of transports provided to the constructor contained an
152         * {@link XhrTransport} or an implementation of {@link InfoReceiver} that
153         * instance would have been used to initialize this property, or otherwise
154         * it defaults to {@link RestTemplateXhrTransport}.
155         * @param infoReceiver the transport to use for the SockJS "Info" request
156         */
157        public void setInfoReceiver(InfoReceiver infoReceiver) {
158                Assert.notNull(infoReceiver, "InfoReceiver is required");
159                this.infoReceiver = infoReceiver;
160        }
161
162        /**
163         * Return the configured {@code InfoReceiver} (never {@code null}).
164         */
165        public InfoReceiver getInfoReceiver() {
166                return this.infoReceiver;
167        }
168
169        /**
170         * Set the SockJsMessageCodec to use.
171         * <p>By default {@link org.springframework.web.socket.sockjs.frame.Jackson2SockJsMessageCodec
172         * Jackson2SockJsMessageCodec} is used if Jackson is on the classpath.
173         */
174        public void setMessageCodec(SockJsMessageCodec messageCodec) {
175                Assert.notNull(messageCodec, "SockJsMessageCodec is required");
176                this.messageCodec = messageCodec;
177        }
178
179        /**
180         * Return the SockJsMessageCodec to use.
181         */
182        public SockJsMessageCodec getMessageCodec() {
183                Assert.state(this.messageCodec != null, "No SockJsMessageCodec set");
184                return this.messageCodec;
185        }
186
187        /**
188         * Configure a {@code TaskScheduler} for scheduling a connect timeout task
189         * where the timeout value is calculated based on the duration of the initial
190         * SockJS "Info" request. The connect timeout task ensures a more timely
191         * fallback but is otherwise entirely optional.
192         * <p>By default this is not configured in which case a fallback may take longer.
193         * @param connectTimeoutScheduler the task scheduler to use
194         */
195        public void setConnectTimeoutScheduler(TaskScheduler connectTimeoutScheduler) {
196                this.connectTimeoutScheduler = connectTimeoutScheduler;
197        }
198
199
200        @Override
201        public void start() {
202                if (!isRunning()) {
203                        this.running = true;
204                        for (Transport transport : this.transports) {
205                                if (transport instanceof Lifecycle) {
206                                        Lifecycle lifecycle = (Lifecycle) transport;
207                                        if (!lifecycle.isRunning()) {
208                                                lifecycle.start();
209                                        }
210                                }
211                        }
212                }
213        }
214
215        @Override
216        public void stop() {
217                if (isRunning()) {
218                        this.running = false;
219                        for (Transport transport : this.transports) {
220                                if (transport instanceof Lifecycle) {
221                                        Lifecycle lifecycle = (Lifecycle) transport;
222                                        if (lifecycle.isRunning()) {
223                                                lifecycle.stop();
224                                        }
225                                }
226                        }
227                }
228        }
229
230        @Override
231        public boolean isRunning() {
232                return this.running;
233        }
234
235
236        @Override
237        public ListenableFuture<WebSocketSession> doHandshake(
238                        WebSocketHandler handler, String uriTemplate, Object... uriVars) {
239
240                Assert.notNull(uriTemplate, "uriTemplate must not be null");
241                URI uri = UriComponentsBuilder.fromUriString(uriTemplate).buildAndExpand(uriVars).encode().toUri();
242                return doHandshake(handler, null, uri);
243        }
244
245        @Override
246        public final ListenableFuture<WebSocketSession> doHandshake(
247                        WebSocketHandler handler, @Nullable WebSocketHttpHeaders headers, URI url) {
248
249                Assert.notNull(handler, "WebSocketHandler is required");
250                Assert.notNull(url, "URL is required");
251
252                String scheme = url.getScheme();
253                if (!supportedProtocols.contains(scheme)) {
254                        throw new IllegalArgumentException("Invalid scheme: '" + scheme + "'");
255                }
256
257                SettableListenableFuture<WebSocketSession> connectFuture = new SettableListenableFuture<>();
258                try {
259                        SockJsUrlInfo sockJsUrlInfo = new SockJsUrlInfo(url);
260                        ServerInfo serverInfo = getServerInfo(sockJsUrlInfo, getHttpRequestHeaders(headers));
261                        createRequest(sockJsUrlInfo, headers, serverInfo).connect(handler, connectFuture);
262                }
263                catch (Exception exception) {
264                        if (logger.isErrorEnabled()) {
265                                logger.error("Initial SockJS \"Info\" request to server failed, url=" + url, exception);
266                        }
267                        connectFuture.setException(exception);
268                }
269                return connectFuture;
270        }
271
272        @Nullable
273        private HttpHeaders getHttpRequestHeaders(@Nullable HttpHeaders webSocketHttpHeaders) {
274                if (getHttpHeaderNames() == null || webSocketHttpHeaders == null) {
275                        return webSocketHttpHeaders;
276                }
277                else {
278                        HttpHeaders httpHeaders = new HttpHeaders();
279                        for (String name : getHttpHeaderNames()) {
280                                List<String> values = webSocketHttpHeaders.get(name);
281                                if (values != null) {
282                                        httpHeaders.put(name, values);
283                                }
284                        }
285                        return httpHeaders;
286                }
287        }
288
289        private ServerInfo getServerInfo(SockJsUrlInfo sockJsUrlInfo, @Nullable HttpHeaders headers) {
290                URI infoUrl = sockJsUrlInfo.getInfoUrl();
291                ServerInfo info = this.serverInfoCache.get(infoUrl);
292                if (info == null) {
293                        long start = System.currentTimeMillis();
294                        String response = this.infoReceiver.executeInfoRequest(infoUrl, headers);
295                        long infoRequestTime = System.currentTimeMillis() - start;
296                        info = new ServerInfo(response, infoRequestTime);
297                        this.serverInfoCache.put(infoUrl, info);
298                }
299                return info;
300        }
301
302        private DefaultTransportRequest createRequest(
303                        SockJsUrlInfo urlInfo, @Nullable HttpHeaders headers, ServerInfo serverInfo) {
304
305                List<DefaultTransportRequest> requests = new ArrayList<>(this.transports.size());
306                for (Transport transport : this.transports) {
307                        for (TransportType type : transport.getTransportTypes()) {
308                                if (serverInfo.isWebSocketEnabled() || !TransportType.WEBSOCKET.equals(type)) {
309                                        requests.add(new DefaultTransportRequest(urlInfo, headers, getHttpRequestHeaders(headers),
310                                                        transport, type, getMessageCodec()));
311                                }
312                        }
313                }
314                if (CollectionUtils.isEmpty(requests)) {
315                        throw new IllegalStateException(
316                                        "No transports: " + urlInfo + ", webSocketEnabled=" + serverInfo.isWebSocketEnabled());
317                }
318                for (int i = 0; i < requests.size() - 1; i++) {
319                        DefaultTransportRequest request = requests.get(i);
320                        Principal user = getUser();
321                        if (user != null) {
322                                request.setUser(user);
323                        }
324                        if (this.connectTimeoutScheduler != null) {
325                                request.setTimeoutValue(serverInfo.getRetransmissionTimeout());
326                                request.setTimeoutScheduler(this.connectTimeoutScheduler);
327                        }
328                        request.setFallbackRequest(requests.get(i + 1));
329                }
330                return requests.get(0);
331        }
332
333        /**
334         * Return the user to associate with the SockJS session and make available via
335         * {@link org.springframework.web.socket.WebSocketSession#getPrincipal()}.
336         * <p>By default this method returns {@code null}.
337         * @return the user to associate with the session (possibly {@code null})
338         */
339        @Nullable
340        protected Principal getUser() {
341                return null;
342        }
343
344        /**
345         * By default the result of a SockJS "Info" request, including whether the
346         * server has WebSocket disabled and how long the request took (used for
347         * calculating transport timeout time) is cached. This method can be used to
348         * clear that cache hence causing it to re-populate.
349         */
350        public void clearServerInfoCache() {
351                this.serverInfoCache.clear();
352        }
353
354
355        /**
356         * A simple value object holding the result from a SockJS "Info" request.
357         */
358        private static class ServerInfo {
359
360                private final boolean webSocketEnabled;
361
362                private final long responseTime;
363
364                public ServerInfo(String response, long responseTime) {
365                        this.responseTime = responseTime;
366                        this.webSocketEnabled = !response.matches(".*[\"']websocket[\"']\\s*:\\s*false.*");
367                }
368
369                public boolean isWebSocketEnabled() {
370                        return this.webSocketEnabled;
371                }
372
373                public long getRetransmissionTimeout() {
374                        return (this.responseTime > 100 ? 4 * this.responseTime : this.responseTime + 300);
375                }
376        }
377
378}