001/* 002 * Copyright 2002-2018 the original author or authors. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * https://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.springframework.web.socket.server.support; 018 019import java.io.IOException; 020import java.nio.charset.Charset; 021import java.security.Principal; 022import java.util.ArrayList; 023import java.util.Arrays; 024import java.util.Collections; 025import java.util.List; 026import java.util.Map; 027 028import org.apache.commons.logging.Log; 029import org.apache.commons.logging.LogFactory; 030 031import org.springframework.context.Lifecycle; 032import org.springframework.http.HttpMethod; 033import org.springframework.http.HttpStatus; 034import org.springframework.http.server.ServerHttpRequest; 035import org.springframework.http.server.ServerHttpResponse; 036import org.springframework.util.Assert; 037import org.springframework.util.ClassUtils; 038import org.springframework.util.StringUtils; 039import org.springframework.web.socket.SubProtocolCapable; 040import org.springframework.web.socket.WebSocketExtension; 041import org.springframework.web.socket.WebSocketHandler; 042import org.springframework.web.socket.WebSocketHttpHeaders; 043import org.springframework.web.socket.handler.WebSocketHandlerDecorator; 044import org.springframework.web.socket.server.HandshakeFailureException; 045import org.springframework.web.socket.server.HandshakeHandler; 046import org.springframework.web.socket.server.RequestUpgradeStrategy; 047 048/** 049 * A base class for {@link HandshakeHandler} implementations, independent from the Servlet API. 050 * 051 * <p>Performs initial validation of the WebSocket handshake request - possibly rejecting it 052 * through the appropriate HTTP status code - while also allowing its subclasses to override 053 * various parts of the negotiation process (e.g. origin validation, sub-protocol negotiation, 054 * extensions negotiation, etc). 055 * 056 * <p>If the negotiation succeeds, the actual upgrade is delegated to a server-specific 057 * {@link org.springframework.web.socket.server.RequestUpgradeStrategy}, which will update 058 * the response as necessary and initialize the WebSocket. Currently supported servers are 059 * Jetty 9.0-9.3, Tomcat 7.0.47+ and 8.x, Undertow 1.0-1.3, GlassFish 4.1+, WebLogic 12.1.3+. 060 * 061 * @author Rossen Stoyanchev 062 * @author Juergen Hoeller 063 * @since 4.2 064 * @see org.springframework.web.socket.server.jetty.JettyRequestUpgradeStrategy 065 * @see org.springframework.web.socket.server.standard.TomcatRequestUpgradeStrategy 066 * @see org.springframework.web.socket.server.standard.UndertowRequestUpgradeStrategy 067 * @see org.springframework.web.socket.server.standard.GlassFishRequestUpgradeStrategy 068 * @see org.springframework.web.socket.server.standard.WebLogicRequestUpgradeStrategy 069 */ 070public abstract class AbstractHandshakeHandler implements HandshakeHandler, Lifecycle { 071 072 private static final Charset UTF8_CHARSET = Charset.forName("UTF-8"); 073 074 075 private static final boolean jettyWsPresent = ClassUtils.isPresent( 076 "org.eclipse.jetty.websocket.server.WebSocketServerFactory", AbstractHandshakeHandler.class.getClassLoader()); 077 078 private static final boolean tomcatWsPresent = ClassUtils.isPresent( 079 "org.apache.tomcat.websocket.server.WsHttpUpgradeHandler", AbstractHandshakeHandler.class.getClassLoader()); 080 081 private static final boolean undertowWsPresent = ClassUtils.isPresent( 082 "io.undertow.websockets.jsr.ServerWebSocketContainer", AbstractHandshakeHandler.class.getClassLoader()); 083 084 private static final boolean glassfishWsPresent = ClassUtils.isPresent( 085 "org.glassfish.tyrus.servlet.TyrusHttpUpgradeHandler", AbstractHandshakeHandler.class.getClassLoader()); 086 087 private static final boolean weblogicWsPresent = ClassUtils.isPresent( 088 "weblogic.websocket.tyrus.TyrusServletWriter", AbstractHandshakeHandler.class.getClassLoader()); 089 090 private static final boolean websphereWsPresent = ClassUtils.isPresent( 091 "com.ibm.websphere.wsoc.WsWsocServerContainer", AbstractHandshakeHandler.class.getClassLoader()); 092 093 094 protected final Log logger = LogFactory.getLog(getClass()); 095 096 private final RequestUpgradeStrategy requestUpgradeStrategy; 097 098 private final List<String> supportedProtocols = new ArrayList<String>(); 099 100 private volatile boolean running = false; 101 102 103 /** 104 * Default constructor that auto-detects and instantiates a 105 * {@link RequestUpgradeStrategy} suitable for the runtime container. 106 * @throws IllegalStateException if no {@link RequestUpgradeStrategy} can be found. 107 */ 108 protected AbstractHandshakeHandler() { 109 this(initRequestUpgradeStrategy()); 110 } 111 112 /** 113 * A constructor that accepts a runtime-specific {@link RequestUpgradeStrategy}. 114 * @param requestUpgradeStrategy the upgrade strategy to use 115 */ 116 protected AbstractHandshakeHandler(RequestUpgradeStrategy requestUpgradeStrategy) { 117 Assert.notNull(requestUpgradeStrategy, "RequestUpgradeStrategy must not be null"); 118 this.requestUpgradeStrategy = requestUpgradeStrategy; 119 } 120 121 122 private static RequestUpgradeStrategy initRequestUpgradeStrategy() { 123 String className; 124 if (tomcatWsPresent) { 125 className = "org.springframework.web.socket.server.standard.TomcatRequestUpgradeStrategy"; 126 } 127 else if (jettyWsPresent) { 128 className = "org.springframework.web.socket.server.jetty.JettyRequestUpgradeStrategy"; 129 } 130 else if (undertowWsPresent) { 131 className = "org.springframework.web.socket.server.standard.UndertowRequestUpgradeStrategy"; 132 } 133 else if (glassfishWsPresent) { 134 className = "org.springframework.web.socket.server.standard.GlassFishRequestUpgradeStrategy"; 135 } 136 else if (weblogicWsPresent) { 137 className = "org.springframework.web.socket.server.standard.WebLogicRequestUpgradeStrategy"; 138 } 139 else if (websphereWsPresent) { 140 className = "org.springframework.web.socket.server.standard.WebSphereRequestUpgradeStrategy"; 141 } 142 else { 143 throw new IllegalStateException("No suitable default RequestUpgradeStrategy found"); 144 } 145 146 try { 147 Class<?> clazz = ClassUtils.forName(className, AbstractHandshakeHandler.class.getClassLoader()); 148 return (RequestUpgradeStrategy) clazz.newInstance(); 149 } 150 catch (Throwable ex) { 151 throw new IllegalStateException( 152 "Failed to instantiate RequestUpgradeStrategy: " + className, ex); 153 } 154 } 155 156 157 /** 158 * Return the {@link RequestUpgradeStrategy} for WebSocket requests. 159 */ 160 public RequestUpgradeStrategy getRequestUpgradeStrategy() { 161 return this.requestUpgradeStrategy; 162 } 163 164 /** 165 * Use this property to configure the list of supported sub-protocols. 166 * The first configured sub-protocol that matches a client-requested sub-protocol 167 * is accepted. If there are no matches the response will not contain a 168 * {@literal Sec-WebSocket-Protocol} header. 169 * <p>Note that if the WebSocketHandler passed in at runtime is an instance of 170 * {@link SubProtocolCapable} then there is not need to explicitly configure 171 * this property. That is certainly the case with the built-in STOMP over 172 * WebSocket support. Therefore this property should be configured explicitly 173 * only if the WebSocketHandler does not implement {@code SubProtocolCapable}. 174 */ 175 public void setSupportedProtocols(String... protocols) { 176 this.supportedProtocols.clear(); 177 for (String protocol : protocols) { 178 this.supportedProtocols.add(protocol.toLowerCase()); 179 } 180 } 181 182 /** 183 * Return the list of supported sub-protocols. 184 */ 185 public String[] getSupportedProtocols() { 186 return StringUtils.toStringArray(this.supportedProtocols); 187 } 188 189 190 @Override 191 public void start() { 192 if (!isRunning()) { 193 this.running = true; 194 doStart(); 195 } 196 } 197 198 protected void doStart() { 199 if (this.requestUpgradeStrategy instanceof Lifecycle) { 200 ((Lifecycle) this.requestUpgradeStrategy).start(); 201 } 202 } 203 204 @Override 205 public void stop() { 206 if (isRunning()) { 207 this.running = false; 208 doStop(); 209 } 210 } 211 212 protected void doStop() { 213 if (this.requestUpgradeStrategy instanceof Lifecycle) { 214 ((Lifecycle) this.requestUpgradeStrategy).stop(); 215 } 216 } 217 218 @Override 219 public boolean isRunning() { 220 return this.running; 221 } 222 223 224 @Override 225 public final boolean doHandshake(ServerHttpRequest request, ServerHttpResponse response, 226 WebSocketHandler wsHandler, Map<String, Object> attributes) throws HandshakeFailureException { 227 228 WebSocketHttpHeaders headers = new WebSocketHttpHeaders(request.getHeaders()); 229 if (logger.isTraceEnabled()) { 230 logger.trace("Processing request " + request.getURI() + " with headers=" + headers); 231 } 232 try { 233 if (HttpMethod.GET != request.getMethod()) { 234 response.setStatusCode(HttpStatus.METHOD_NOT_ALLOWED); 235 response.getHeaders().setAllow(Collections.singleton(HttpMethod.GET)); 236 if (logger.isErrorEnabled()) { 237 logger.error("Handshake failed due to unexpected HTTP method: " + request.getMethod()); 238 } 239 return false; 240 } 241 if (!"WebSocket".equalsIgnoreCase(headers.getUpgrade())) { 242 handleInvalidUpgradeHeader(request, response); 243 return false; 244 } 245 if (!headers.getConnection().contains("Upgrade") && !headers.getConnection().contains("upgrade")) { 246 handleInvalidConnectHeader(request, response); 247 return false; 248 } 249 if (!isWebSocketVersionSupported(headers)) { 250 handleWebSocketVersionNotSupported(request, response); 251 return false; 252 } 253 if (!isValidOrigin(request)) { 254 response.setStatusCode(HttpStatus.FORBIDDEN); 255 return false; 256 } 257 String wsKey = headers.getSecWebSocketKey(); 258 if (wsKey == null) { 259 if (logger.isErrorEnabled()) { 260 logger.error("Missing \"Sec-WebSocket-Key\" header"); 261 } 262 response.setStatusCode(HttpStatus.BAD_REQUEST); 263 return false; 264 } 265 } 266 catch (IOException ex) { 267 throw new HandshakeFailureException( 268 "Response update failed during upgrade to WebSocket: " + request.getURI(), ex); 269 } 270 271 String subProtocol = selectProtocol(headers.getSecWebSocketProtocol(), wsHandler); 272 List<WebSocketExtension> requested = headers.getSecWebSocketExtensions(); 273 List<WebSocketExtension> supported = this.requestUpgradeStrategy.getSupportedExtensions(request); 274 List<WebSocketExtension> extensions = filterRequestedExtensions(request, requested, supported); 275 Principal user = determineUser(request, wsHandler, attributes); 276 277 if (logger.isTraceEnabled()) { 278 logger.trace("Upgrading to WebSocket, subProtocol=" + subProtocol + ", extensions=" + extensions); 279 } 280 this.requestUpgradeStrategy.upgrade(request, response, subProtocol, extensions, user, wsHandler, attributes); 281 return true; 282 } 283 284 protected void handleInvalidUpgradeHeader(ServerHttpRequest request, ServerHttpResponse response) throws IOException { 285 if (logger.isErrorEnabled()) { 286 logger.error("Handshake failed due to invalid Upgrade header: " + request.getHeaders().getUpgrade()); 287 } 288 response.setStatusCode(HttpStatus.BAD_REQUEST); 289 response.getBody().write("Can \"Upgrade\" only to \"WebSocket\".".getBytes(UTF8_CHARSET)); 290 } 291 292 protected void handleInvalidConnectHeader(ServerHttpRequest request, ServerHttpResponse response) throws IOException { 293 if (logger.isErrorEnabled()) { 294 logger.error("Handshake failed due to invalid Connection header " + request.getHeaders().getConnection()); 295 } 296 response.setStatusCode(HttpStatus.BAD_REQUEST); 297 response.getBody().write("\"Connection\" must be \"upgrade\".".getBytes(UTF8_CHARSET)); 298 } 299 300 protected boolean isWebSocketVersionSupported(WebSocketHttpHeaders httpHeaders) { 301 String version = httpHeaders.getSecWebSocketVersion(); 302 String[] supportedVersions = getSupportedVersions(); 303 for (String supportedVersion : supportedVersions) { 304 if (supportedVersion.trim().equals(version)) { 305 return true; 306 } 307 } 308 return false; 309 } 310 311 protected String[] getSupportedVersions() { 312 return this.requestUpgradeStrategy.getSupportedVersions(); 313 } 314 315 protected void handleWebSocketVersionNotSupported(ServerHttpRequest request, ServerHttpResponse response) { 316 if (logger.isErrorEnabled()) { 317 String version = request.getHeaders().getFirst("Sec-WebSocket-Version"); 318 logger.error("Handshake failed due to unsupported WebSocket version: " + version + 319 ". Supported versions: " + Arrays.toString(getSupportedVersions())); 320 } 321 response.setStatusCode(HttpStatus.UPGRADE_REQUIRED); 322 response.getHeaders().set(WebSocketHttpHeaders.SEC_WEBSOCKET_VERSION, 323 StringUtils.arrayToCommaDelimitedString(getSupportedVersions())); 324 } 325 326 /** 327 * Return whether the request {@code Origin} header value is valid or not. 328 * By default, all origins as considered as valid. Consider using an 329 * {@link OriginHandshakeInterceptor} for filtering origins if needed. 330 */ 331 protected boolean isValidOrigin(ServerHttpRequest request) { 332 return true; 333 } 334 335 /** 336 * Perform the sub-protocol negotiation based on requested and supported sub-protocols. 337 * For the list of supported sub-protocols, this method first checks if the target 338 * WebSocketHandler is a {@link SubProtocolCapable} and then also checks if any 339 * sub-protocols have been explicitly configured with 340 * {@link #setSupportedProtocols(String...)}. 341 * @param requestedProtocols the requested sub-protocols 342 * @param webSocketHandler the WebSocketHandler that will be used 343 * @return the selected protocols or {@code null} 344 * @see #determineHandlerSupportedProtocols(WebSocketHandler) 345 */ 346 protected String selectProtocol(List<String> requestedProtocols, WebSocketHandler webSocketHandler) { 347 if (requestedProtocols != null) { 348 List<String> handlerProtocols = determineHandlerSupportedProtocols(webSocketHandler); 349 for (String protocol : requestedProtocols) { 350 if (handlerProtocols.contains(protocol.toLowerCase())) { 351 return protocol; 352 } 353 if (this.supportedProtocols.contains(protocol.toLowerCase())) { 354 return protocol; 355 } 356 } 357 } 358 return null; 359 } 360 361 /** 362 * Determine the sub-protocols supported by the given WebSocketHandler by 363 * checking whether it is an instance of {@link SubProtocolCapable}. 364 * @param handler the handler to check 365 * @return a list of supported protocols, or an empty list if none available 366 */ 367 protected final List<String> determineHandlerSupportedProtocols(WebSocketHandler handler) { 368 WebSocketHandler handlerToCheck = WebSocketHandlerDecorator.unwrap(handler); 369 List<String> subProtocols = null; 370 if (handlerToCheck instanceof SubProtocolCapable) { 371 subProtocols = ((SubProtocolCapable) handlerToCheck).getSubProtocols(); 372 } 373 return (subProtocols != null ? subProtocols : Collections.<String>emptyList()); 374 } 375 376 /** 377 * Filter the list of requested WebSocket extensions. 378 * <p>As of 4.1, the default implementation of this method filters the list to 379 * leave only extensions that are both requested and supported. 380 * @param request the current request 381 * @param requestedExtensions the list of extensions requested by the client 382 * @param supportedExtensions the list of extensions supported by the server 383 * @return the selected extensions or an empty list 384 */ 385 protected List<WebSocketExtension> filterRequestedExtensions(ServerHttpRequest request, 386 List<WebSocketExtension> requestedExtensions, List<WebSocketExtension> supportedExtensions) { 387 388 List<WebSocketExtension> result = new ArrayList<WebSocketExtension>(requestedExtensions.size()); 389 for (WebSocketExtension extension : requestedExtensions) { 390 if (supportedExtensions.contains(extension)) { 391 result.add(extension); 392 } 393 } 394 return result; 395 } 396 397 /** 398 * A method that can be used to associate a user with the WebSocket session 399 * in the process of being established. The default implementation calls 400 * {@link ServerHttpRequest#getPrincipal()} 401 * <p>Subclasses can provide custom logic for associating a user with a session, 402 * for example for assigning a name to anonymous users (i.e. not fully authenticated). 403 * @param request the handshake request 404 * @param wsHandler the WebSocket handler that will handle messages 405 * @param attributes handshake attributes to pass to the WebSocket session 406 * @return the user for the WebSocket session, or {@code null} if not available 407 */ 408 protected Principal determineUser( 409 ServerHttpRequest request, WebSocketHandler wsHandler, Map<String, Object> attributes) { 410 411 return request.getPrincipal(); 412 } 413 414}