/*
 * Copyright 2002-2019 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.web.socket.sockjs.client;

import java.net.URI;
import java.security.Principal;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.regex.Pattern;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import org.springframework.context.Lifecycle;
import org.springframework.http.HttpHeaders;
import org.springframework.lang.Nullable;
import org.springframework.scheduling.TaskScheduler;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.CollectionUtils;
import org.springframework.util.concurrent.ListenableFuture;
import org.springframework.util.concurrent.SettableListenableFuture;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketHttpHeaders;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.client.WebSocketClient;
import org.springframework.web.socket.sockjs.frame.Jackson2SockJsMessageCodec;
import org.springframework.web.socket.sockjs.frame.SockJsMessageCodec;
import org.springframework.web.socket.sockjs.transport.TransportType;
import org.springframework.web.util.UriComponentsBuilder;

/**
 * A SockJS implementation of
 * {@link org.springframework.web.socket.client.WebSocketClient WebSocketClient}
 * with fallback alternatives that simulate a WebSocket interaction through plain
 * HTTP streaming and long polling techniques.
 *
 * <p>Implements {@link Lifecycle} in order to propagate lifecycle events to
 * the transports it is configured with.
 *
 * @author Rossen Stoyanchev
 * @since 4.1
 * @see <a href="https://github.com/sockjs/sockjs-client">https://github.com/sockjs/sockjs-client</a>
 * @see org.springframework.web.socket.sockjs.client.Transport
 */
public class SockJsClient implements WebSocketClient, Lifecycle {

	private static final boolean jackson2Present = ClassUtils.isPresent(
			"com.fasterxml.jackson.databind.ObjectMapper", SockJsClient.class.getClassLoader());

	private static final Log logger = LogFactory.getLog(SockJsClient.class);

	private static final Set<String> supportedProtocols = new HashSet<>(4);

	static {
		supportedProtocols.add("ws");
		supportedProtocols.add("wss");
		supportedProtocols.add("http");
		supportedProtocols.add("https");
	}


	private final List<Transport> transports;

	@Nullable
	private String[] httpHeaderNames;

	private InfoReceiver infoReceiver;

	@Nullable
	private SockJsMessageCodec messageCodec;

	@Nullable
	private TaskScheduler connectTimeoutScheduler;

	private volatile boolean running = false;

	private final Map<URI, ServerInfo> serverInfoCache = new ConcurrentHashMap<>();


	/**
	 * Create a {@code SockJsClient} with the given transports.
	 * <p>If the list includes an {@link XhrTransport} (or more specifically an
	 * implementation of {@link InfoReceiver}) the instance is used to initialize
	 * the {@link #setInfoReceiver(InfoReceiver) infoReceiver} property, or
	 * otherwise is defaulted to {@link RestTemplateXhrTransport}.
	 * @param transports the (non-empty) list of transports to use
	 */
	public SockJsClient(List<Transport> transports) {
		Assert.notEmpty(transports, "No transports provided");
		this.transports = new ArrayList<>(transports);
		this.infoReceiver = initInfoReceiver(transports);
		if (jackson2Present) {
			this.messageCodec = new Jackson2SockJsMessageCodec();
		}
	}

	/**
	 * Pick the first transport that is also an {@link InfoReceiver}, or fall
	 * back on a new {@link RestTemplateXhrTransport} if there is none.
	 */
	private static InfoReceiver initInfoReceiver(List<Transport> transports) {
		for (Transport transport : transports) {
			if (transport instanceof InfoReceiver) {
				return ((InfoReceiver) transport);
			}
		}
		return new RestTemplateXhrTransport();
	}


	/**
	 * The names of HTTP headers that should be copied from the handshake headers
	 * of each call to {@link SockJsClient#doHandshake(WebSocketHandler, WebSocketHttpHeaders, URI)}
	 * and also used with other HTTP requests issued as part of that SockJS
	 * connection, e.g. the initial info request, XHR send or receive requests.
	 * <p>By default if this property is not set, all handshake headers are also
	 * used for other HTTP requests. Set it if you want only a subset of handshake
	 * headers (e.g. auth headers) to be used for other HTTP requests.
	 * @param httpHeaderNames the HTTP header names
	 */
	public void setHttpHeaderNames(@Nullable String... httpHeaderNames) {
		this.httpHeaderNames = httpHeaderNames;
	}

	/**
	 * The configured HTTP header names to be copied from the handshake
	 * headers and also included in other HTTP requests.
	 */
	@Nullable
	public String[] getHttpHeaderNames() {
		return this.httpHeaderNames;
	}

	/**
	 * Configure the {@code InfoReceiver} to use to perform the SockJS "Info"
	 * request before the SockJS session starts.
	 * <p>If the list of transports provided to the constructor contained an
	 * {@link XhrTransport} or an implementation of {@link InfoReceiver} that
	 * instance would have been used to initialize this property, or otherwise
	 * it defaults to {@link RestTemplateXhrTransport}.
	 * @param infoReceiver the transport to use for the SockJS "Info" request
	 */
	public void setInfoReceiver(InfoReceiver infoReceiver) {
		Assert.notNull(infoReceiver, "InfoReceiver is required");
		this.infoReceiver = infoReceiver;
	}

	/**
	 * Return the configured {@code InfoReceiver} (never {@code null}).
	 */
	public InfoReceiver getInfoReceiver() {
		return this.infoReceiver;
	}

	/**
	 * Set the SockJsMessageCodec to use.
	 * <p>By default {@link org.springframework.web.socket.sockjs.frame.Jackson2SockJsMessageCodec
	 * Jackson2SockJsMessageCodec} is used if Jackson is on the classpath.
	 */
	public void setMessageCodec(SockJsMessageCodec messageCodec) {
		Assert.notNull(messageCodec, "SockJsMessageCodec is required");
		this.messageCodec = messageCodec;
	}

	/**
	 * Return the SockJsMessageCodec to use.
	 * @throws IllegalStateException if no codec has been set
	 */
	public SockJsMessageCodec getMessageCodec() {
		Assert.state(this.messageCodec != null, "No SockJsMessageCodec set");
		return this.messageCodec;
	}

	/**
	 * Configure a {@code TaskScheduler} for scheduling a connect timeout task
	 * where the timeout value is calculated based on the duration of the initial
	 * SockJS "Info" request. The connect timeout task ensures a more timely
	 * fallback but is otherwise entirely optional.
	 * <p>By default this is not configured in which case a fallback may take longer.
	 * @param connectTimeoutScheduler the task scheduler to use
	 */
	public void setConnectTimeoutScheduler(TaskScheduler connectTimeoutScheduler) {
		this.connectTimeoutScheduler = connectTimeoutScheduler;
	}


	/**
	 * Mark this client as running and start every configured transport that
	 * implements {@link Lifecycle} and is not already running.
	 */
	@Override
	public void start() {
		if (!isRunning()) {
			this.running = true;
			for (Transport transport : this.transports) {
				if (transport instanceof Lifecycle) {
					Lifecycle lifecycle = (Lifecycle) transport;
					if (!lifecycle.isRunning()) {
						lifecycle.start();
					}
				}
			}
		}
	}

	/**
	 * Mark this client as stopped and stop every configured transport that
	 * implements {@link Lifecycle} and is currently running.
	 */
	@Override
	public void stop() {
		if (isRunning()) {
			this.running = false;
			for (Transport transport : this.transports) {
				if (transport instanceof Lifecycle) {
					Lifecycle lifecycle = (Lifecycle) transport;
					if (lifecycle.isRunning()) {
						lifecycle.stop();
					}
				}
			}
		}
	}

	@Override
	public boolean isRunning() {
		return this.running;
	}


	/**
	 * Expand and encode the given URI template and delegate to
	 * {@link #doHandshake(WebSocketHandler, WebSocketHttpHeaders, URI)}
	 * without any handshake headers.
	 */
	@Override
	public ListenableFuture<WebSocketSession> doHandshake(
			WebSocketHandler handler, String uriTemplate, Object... uriVars) {

		Assert.notNull(uriTemplate, "uriTemplate must not be null");
		URI uri = UriComponentsBuilder.fromUriString(uriTemplate).buildAndExpand(uriVars).encode().toUri();
		return doHandshake(handler, null, uri);
	}

	/**
	 * Obtain the server info (from the cache or via a SockJS "Info" request),
	 * build the chain of transport requests, and connect with the first one.
	 * <p>Any exception raised along the way is logged and reported through the
	 * returned future rather than thrown.
	 * @throws IllegalArgumentException if the URL scheme is not one of
	 * "ws", "wss", "http", or "https"
	 */
	@Override
	public final ListenableFuture<WebSocketSession> doHandshake(
			WebSocketHandler handler, @Nullable WebSocketHttpHeaders headers, URI url) {

		Assert.notNull(handler, "WebSocketHandler is required");
		Assert.notNull(url, "URL is required");

		String scheme = url.getScheme();
		if (!supportedProtocols.contains(scheme)) {
			throw new IllegalArgumentException("Invalid scheme: '" + scheme + "'");
		}

		SettableListenableFuture<WebSocketSession> connectFuture = new SettableListenableFuture<>();
		try {
			SockJsUrlInfo sockJsUrlInfo = new SockJsUrlInfo(url);
			ServerInfo serverInfo = getServerInfo(sockJsUrlInfo, getHttpRequestHeaders(headers));
			createRequest(sockJsUrlInfo, headers, serverInfo).connect(handler, connectFuture);
		}
		catch (Exception exception) {
			if (logger.isErrorEnabled()) {
				logger.error("Initial SockJS \"Info\" request to server failed, url=" + url, exception);
			}
			connectFuture.setException(exception);
		}
		return connectFuture;
	}

	/**
	 * Select the headers to use for plain HTTP requests: the given handshake
	 * headers as-is if no header names are configured (or no headers given),
	 * or otherwise a new instance with only the configured header names.
	 */
	@Nullable
	private HttpHeaders getHttpRequestHeaders(@Nullable HttpHeaders webSocketHttpHeaders) {
		// Read the mutable field once so that the null check and the iteration
		// below are guaranteed to see the same array.
		String[] headerNames = getHttpHeaderNames();
		if (headerNames == null || webSocketHttpHeaders == null) {
			return webSocketHttpHeaders;
		}
		HttpHeaders httpHeaders = new HttpHeaders();
		for (String name : headerNames) {
			List<String> values = webSocketHttpHeaders.get(name);
			if (values != null) {
				httpHeaders.put(name, values);
			}
		}
		return httpHeaders;
	}

	/**
	 * Return the cached {@link ServerInfo} for the info URL, or execute the
	 * SockJS "Info" request, measure how long it took in milliseconds, and
	 * cache the result.
	 */
	private ServerInfo getServerInfo(SockJsUrlInfo sockJsUrlInfo, @Nullable HttpHeaders headers) {
		URI infoUrl = sockJsUrlInfo.getInfoUrl();
		ServerInfo info = this.serverInfoCache.get(infoUrl);
		if (info == null) {
			long start = System.currentTimeMillis();
			String response = this.infoReceiver.executeInfoRequest(infoUrl, headers);
			long infoRequestTime = System.currentTimeMillis() - start;
			info = new ServerInfo(response, infoRequestTime);
			this.serverInfoCache.put(infoUrl, info);
		}
		return info;
	}

	/**
	 * Create one transport request per supported transport type (skipping the
	 * WebSocket type if the server reported it as disabled), chain each request
	 * to the next one as its fallback, and return the head of the chain.
	 * @throws IllegalStateException if no transport request could be created
	 */
	private DefaultTransportRequest createRequest(
			SockJsUrlInfo urlInfo, @Nullable HttpHeaders headers, ServerInfo serverInfo) {

		List<DefaultTransportRequest> requests = new ArrayList<>(this.transports.size());
		for (Transport transport : this.transports) {
			for (TransportType type : transport.getTransportTypes()) {
				if (serverInfo.isWebSocketEnabled() || !TransportType.WEBSOCKET.equals(type)) {
					requests.add(new DefaultTransportRequest(urlInfo, headers, getHttpRequestHeaders(headers),
							transport, type, getMessageCodec()));
				}
			}
		}
		if (CollectionUtils.isEmpty(requests)) {
			throw new IllegalStateException(
					"No transports: " + urlInfo + ", webSocketEnabled=" + serverInfo.isWebSocketEnabled());
		}

		Principal user = getUser();
		int lastIndex = requests.size() - 1;
		for (int i = 0; i <= lastIndex; i++) {
			DefaultTransportRequest request = requests.get(i);
			// The user applies to every request, including the last one (and the
			// only one when a single transport is configured); any of them may
			// end up establishing the session.
			if (user != null) {
				request.setUser(user);
			}
			// The connect timeout and the fallback are only configured
			// for requests that have a successor in the chain.
			if (i < lastIndex) {
				if (this.connectTimeoutScheduler != null) {
					request.setTimeoutValue(serverInfo.getRetransmissionTimeout());
					request.setTimeoutScheduler(this.connectTimeoutScheduler);
				}
				request.setFallbackRequest(requests.get(i + 1));
			}
		}
		return requests.get(0);
	}

	/**
	 * Return the user to associate with the SockJS session and make available via
	 * {@link org.springframework.web.socket.WebSocketSession#getPrincipal()}.
	 * <p>By default this method returns {@code null}.
	 * @return the user to associate with the session (possibly {@code null})
	 */
	@Nullable
	protected Principal getUser() {
		return null;
	}

	/**
	 * By default the result of a SockJS "Info" request, including whether the
	 * server has WebSocket disabled and how long the request took (used for
	 * calculating transport timeout time) is cached. This method can be used to
	 * clear that cache hence causing it to re-populate.
	 */
	public void clearServerInfoCache() {
		this.serverInfoCache.clear();
	}


	/**
	 * A simple value object holding the result from a SockJS "Info" request.
	 */
	private static class ServerInfo {

		/**
		 * Matches a {@code "websocket": false} entry (single or double quoted
		 * key) anywhere in the info response. Used with {@code find()} so that
		 * the entry is also detected in a response that spans several lines.
		 */
		private static final Pattern WEBSOCKET_DISABLED_PATTERN =
				Pattern.compile("[\"']websocket[\"']\\s*:\\s*false");

		private final boolean webSocketEnabled;

		private final long responseTime;

		/**
		 * Create an instance from the raw info response.
		 * @param response the body of the "Info" response
		 * @param responseTime how long the "Info" request took, in milliseconds
		 */
		public ServerInfo(String response, long responseTime) {
			this.responseTime = responseTime;
			this.webSocketEnabled = !WEBSOCKET_DISABLED_PATTERN.matcher(response).find();
		}

		/**
		 * Whether WebSocket is enabled, i.e. the info response did not
		 * contain a {@code "websocket": false} entry.
		 */
		public boolean isWebSocketEnabled() {
			return this.webSocketEnabled;
		}

		/**
		 * Return the timeout derived from the info request duration: four times
		 * the response time if it exceeded 100 ms, or otherwise the response
		 * time plus 300 ms.
		 */
		public long getRetransmissionTimeout() {
			return (this.responseTime > 100 ? 4 * this.responseTime : this.responseTime + 300);
		}
	}

}