001/*
002 * Copyright 2002-2014 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.web.socket.sockjs.client;
018
import java.net.URI;
import java.security.Principal;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.regex.Pattern;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import org.springframework.context.Lifecycle;
import org.springframework.http.HttpHeaders;
import org.springframework.scheduling.TaskScheduler;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.CollectionUtils;
import org.springframework.util.concurrent.ListenableFuture;
import org.springframework.util.concurrent.SettableListenableFuture;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketHttpHeaders;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.client.WebSocketClient;
import org.springframework.web.socket.sockjs.frame.Jackson2SockJsMessageCodec;
import org.springframework.web.socket.sockjs.frame.SockJsMessageCodec;
import org.springframework.web.socket.sockjs.transport.TransportType;
import org.springframework.web.util.UriComponentsBuilder;
047
048/**
049 * A SockJS implementation of
050 * {@link org.springframework.web.socket.client.WebSocketClient WebSocketClient}
051 * with fallback alternatives that simulate a WebSocket interaction through plain
052 * HTTP streaming and long polling techniques..
053 *
054 * <p>Implements {@link Lifecycle} in order to propagate lifecycle events to
055 * the transports it is configured with.
056 *
057 * @author Rossen Stoyanchev
058 * @since 4.1
059 * @see <a href="https://github.com/sockjs/sockjs-client">https://github.com/sockjs/sockjs-client</a>
060 * @see org.springframework.web.socket.sockjs.client.Transport
061 */
062public class SockJsClient implements WebSocketClient, Lifecycle {
063
064        private static final boolean jackson2Present = ClassUtils.isPresent(
065                        "com.fasterxml.jackson.databind.ObjectMapper", SockJsClient.class.getClassLoader());
066
067        private static final Log logger = LogFactory.getLog(SockJsClient.class);
068
069        private static final Set<String> supportedProtocols = new HashSet<String>(4);
070
071        static {
072                supportedProtocols.add("ws");
073                supportedProtocols.add("wss");
074                supportedProtocols.add("http");
075                supportedProtocols.add("https");
076        }
077
078
079        private final List<Transport> transports;
080
081        private String[] httpHeaderNames;
082
083        private InfoReceiver infoReceiver;
084
085        private SockJsMessageCodec messageCodec;
086
087        private TaskScheduler connectTimeoutScheduler;
088
089        private volatile boolean running = false;
090
091        private final Map<URI, ServerInfo> serverInfoCache = new ConcurrentHashMap<URI, ServerInfo>();
092
093
094        /**
095         * Create a {@code SockJsClient} with the given transports.
096         * <p>If the list includes an {@link XhrTransport} (or more specifically an
097         * implementation of {@link InfoReceiver}) the instance is used to initialize
098         * the {@link #setInfoReceiver(InfoReceiver) infoReceiver} property, or
099         * otherwise is defaulted to {@link RestTemplateXhrTransport}.
100         * @param transports the (non-empty) list of transports to use
101         */
102        public SockJsClient(List<Transport> transports) {
103                Assert.notEmpty(transports, "No transports provided");
104                this.transports = new ArrayList<Transport>(transports);
105                this.infoReceiver = initInfoReceiver(transports);
106                if (jackson2Present) {
107                        this.messageCodec = new Jackson2SockJsMessageCodec();
108                }
109        }
110
111        private static InfoReceiver initInfoReceiver(List<Transport> transports) {
112                for (Transport transport : transports) {
113                        if (transport instanceof InfoReceiver) {
114                                return ((InfoReceiver) transport);
115                        }
116                }
117                return new RestTemplateXhrTransport();
118        }
119
120
121        /**
122         * The names of HTTP headers that should be copied from the handshake headers
123         * of each call to {@link SockJsClient#doHandshake(WebSocketHandler, WebSocketHttpHeaders, URI)}
124         * and also used with other HTTP requests issued as part of that SockJS
125         * connection, e.g. the initial info request, XHR send or receive requests.
126         *
127         * <p>By default if this property is not set, all handshake headers are also
128         * used for other HTTP requests. Set it if you want only a subset of handshake
129         * headers (e.g. auth headers) to be used for other HTTP requests.
130         *
131         * @param httpHeaderNames HTTP header names
132         */
133        public void setHttpHeaderNames(String... httpHeaderNames) {
134                this.httpHeaderNames = httpHeaderNames;
135        }
136
137        /**
138         * The configured HTTP header names to be copied from the handshake
139         * headers and also included in other HTTP requests.
140         */
141        public String[] getHttpHeaderNames() {
142                return this.httpHeaderNames;
143        }
144
145        /**
146         * Configure the {@code InfoReceiver} to use to perform the SockJS "Info"
147         * request before the SockJS session starts.
148         * <p>If the list of transports provided to the constructor contained an
149         * {@link XhrTransport} or an implementation of {@link InfoReceiver} that
150         * instance would have been used to initialize this property, or otherwise
151         * it defaults to {@link RestTemplateXhrTransport}.
152         * @param infoReceiver the transport to use for the SockJS "Info" request
153         */
154        public void setInfoReceiver(InfoReceiver infoReceiver) {
155                Assert.notNull(infoReceiver, "InfoReceiver is required");
156                this.infoReceiver = infoReceiver;
157        }
158
159        /**
160         * Return the configured {@code InfoReceiver} (never {@code null}).
161         */
162        public InfoReceiver getInfoReceiver() {
163                return this.infoReceiver;
164        }
165
166        /**
167         * Set the SockJsMessageCodec to use.
168         * <p>By default {@link org.springframework.web.socket.sockjs.frame.Jackson2SockJsMessageCodec
169         * Jackson2SockJsMessageCodec} is used if Jackson is on the classpath.
170         */
171        public void setMessageCodec(SockJsMessageCodec messageCodec) {
172                Assert.notNull(messageCodec, "'messageCodec' is required");
173                this.messageCodec = messageCodec;
174        }
175
176        /**
177         * Return the SockJsMessageCodec to use.
178         */
179        public SockJsMessageCodec getMessageCodec() {
180                return this.messageCodec;
181        }
182
183        /**
184         * Configure a {@code TaskScheduler} for scheduling a connect timeout task
185         * where the timeout value is calculated based on the duration of the initial
186         * SockJS "Info" request. The connect timeout task ensures a more timely
187         * fallback but is otherwise entirely optional.
188         * <p>By default this is not configured in which case a fallback may take longer.
189         * @param connectTimeoutScheduler the task scheduler to use
190         */
191        public void setConnectTimeoutScheduler(TaskScheduler connectTimeoutScheduler) {
192                this.connectTimeoutScheduler = connectTimeoutScheduler;
193        }
194
195
196        @Override
197        public void start() {
198                if (!isRunning()) {
199                        this.running = true;
200                        for (Transport transport : this.transports) {
201                                if (transport instanceof Lifecycle) {
202                                        if (!((Lifecycle) transport).isRunning()) {
203                                                ((Lifecycle) transport).start();
204                                        }
205                                }
206                        }
207                }
208        }
209
210        @Override
211        public void stop() {
212                if (isRunning()) {
213                        this.running = false;
214                        for (Transport transport : this.transports) {
215                                if (transport instanceof Lifecycle) {
216                                        if (((Lifecycle) transport).isRunning()) {
217                                                ((Lifecycle) transport).stop();
218                                        }
219                                }
220                        }
221                }
222        }
223
224        @Override
225        public boolean isRunning() {
226                return this.running;
227        }
228
229
230        @Override
231        public ListenableFuture<WebSocketSession> doHandshake(
232                        WebSocketHandler handler, String uriTemplate, Object... uriVars) {
233
234                Assert.notNull(uriTemplate, "uriTemplate must not be null");
235                URI uri = UriComponentsBuilder.fromUriString(uriTemplate).buildAndExpand(uriVars).encode().toUri();
236                return doHandshake(handler, null, uri);
237        }
238
239        @Override
240        public final ListenableFuture<WebSocketSession> doHandshake(
241                        WebSocketHandler handler, WebSocketHttpHeaders headers, URI url) {
242
243                Assert.notNull(handler, "WebSocketHandler is required");
244                Assert.notNull(url, "URL is required");
245
246                String scheme = url.getScheme();
247                if (!supportedProtocols.contains(scheme)) {
248                        throw new IllegalArgumentException("Invalid scheme: '" + scheme + "'");
249                }
250
251                SettableListenableFuture<WebSocketSession> connectFuture = new SettableListenableFuture<WebSocketSession>();
252                try {
253                        SockJsUrlInfo sockJsUrlInfo = new SockJsUrlInfo(url);
254                        ServerInfo serverInfo = getServerInfo(sockJsUrlInfo, getHttpRequestHeaders(headers));
255                        createRequest(sockJsUrlInfo, headers, serverInfo).connect(handler, connectFuture);
256                }
257                catch (Throwable exception) {
258                        if (logger.isErrorEnabled()) {
259                                logger.error("Initial SockJS \"Info\" request to server failed, url=" + url, exception);
260                        }
261                        connectFuture.setException(exception);
262                }
263                return connectFuture;
264        }
265
266        private HttpHeaders getHttpRequestHeaders(HttpHeaders webSocketHttpHeaders) {
267                if (getHttpHeaderNames() == null) {
268                        return webSocketHttpHeaders;
269                }
270                else {
271                        HttpHeaders httpHeaders = new HttpHeaders();
272                        for (String name : getHttpHeaderNames()) {
273                                if (webSocketHttpHeaders.containsKey(name)) {
274                                        httpHeaders.put(name, webSocketHttpHeaders.get(name));
275                                }
276                        }
277                        return httpHeaders;
278                }
279        }
280
281        private ServerInfo getServerInfo(SockJsUrlInfo sockJsUrlInfo, HttpHeaders headers) {
282                URI infoUrl = sockJsUrlInfo.getInfoUrl();
283                ServerInfo info = this.serverInfoCache.get(infoUrl);
284                if (info == null) {
285                        long start = System.currentTimeMillis();
286                        String response = this.infoReceiver.executeInfoRequest(infoUrl, headers);
287                        long infoRequestTime = System.currentTimeMillis() - start;
288                        info = new ServerInfo(response, infoRequestTime);
289                        this.serverInfoCache.put(infoUrl, info);
290                }
291                return info;
292        }
293
294        private DefaultTransportRequest createRequest(SockJsUrlInfo urlInfo, HttpHeaders headers, ServerInfo serverInfo) {
295                List<DefaultTransportRequest> requests = new ArrayList<DefaultTransportRequest>(this.transports.size());
296                for (Transport transport : this.transports) {
297                        for (TransportType type : transport.getTransportTypes()) {
298                                if (serverInfo.isWebSocketEnabled() || !TransportType.WEBSOCKET.equals(type)) {
299                                        requests.add(new DefaultTransportRequest(urlInfo, headers, getHttpRequestHeaders(headers),
300                                                        transport, type, getMessageCodec()));
301                                }
302                        }
303                }
304                if (CollectionUtils.isEmpty(requests)) {
305                        throw new IllegalStateException(
306                                        "No transports: " + urlInfo + ", webSocketEnabled=" + serverInfo.isWebSocketEnabled());
307                }
308                for (int i = 0; i < requests.size() - 1; i++) {
309                        DefaultTransportRequest request = requests.get(i);
310                        request.setUser(getUser());
311                        if (this.connectTimeoutScheduler != null) {
312                                request.setTimeoutValue(serverInfo.getRetransmissionTimeout());
313                                request.setTimeoutScheduler(this.connectTimeoutScheduler);
314                        }
315                        request.setFallbackRequest(requests.get(i + 1));
316                }
317                return requests.get(0);
318        }
319
320        /**
321         * Return the user to associate with the SockJS session and make available via
322         * {@link org.springframework.web.socket.WebSocketSession#getPrincipal()}.
323         * <p>By default this method returns {@code null}.
324         * @return the user to associate with the session (possibly {@code null})
325         */
326        protected Principal getUser() {
327                return null;
328        }
329
330        /**
331         * By default the result of a SockJS "Info" request, including whether the
332         * server has WebSocket disabled and how long the request took (used for
333         * calculating transport timeout time) is cached. This method can be used to
334         * clear that cache hence causing it to re-populate.
335         */
336        public void clearServerInfoCache() {
337                this.serverInfoCache.clear();
338        }
339
340
341        /**
342         * A simple value object holding the result from a SockJS "Info" request.
343         */
344        private static class ServerInfo {
345
346                private final boolean webSocketEnabled;
347
348                private final long responseTime;
349
350                public ServerInfo(String response, long responseTime) {
351                        this.responseTime = responseTime;
352                        this.webSocketEnabled = !response.matches(".*[\"']websocket[\"']\\s*:\\s*false.*");
353                }
354
355                public boolean isWebSocketEnabled() {
356                        return this.webSocketEnabled;
357                }
358
359                public long getRetransmissionTimeout() {
360                        return (this.responseTime > 100 ? 4 * this.responseTime : this.responseTime + 300);
361                }
362        }
363
364}