001/* 002 * Copyright 2002-2019 the original author or authors. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * https://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.springframework.web.socket.server.support; 018 019import java.io.IOException; 020import java.nio.charset.StandardCharsets; 021import java.security.Principal; 022import java.util.ArrayList; 023import java.util.Arrays; 024import java.util.Collections; 025import java.util.List; 026import java.util.Map; 027 028import org.apache.commons.logging.Log; 029import org.apache.commons.logging.LogFactory; 030 031import org.springframework.context.Lifecycle; 032import org.springframework.http.HttpMethod; 033import org.springframework.http.HttpStatus; 034import org.springframework.http.server.ServerHttpRequest; 035import org.springframework.http.server.ServerHttpResponse; 036import org.springframework.lang.Nullable; 037import org.springframework.util.Assert; 038import org.springframework.util.ClassUtils; 039import org.springframework.util.ReflectionUtils; 040import org.springframework.util.StringUtils; 041import org.springframework.web.socket.SubProtocolCapable; 042import org.springframework.web.socket.WebSocketExtension; 043import org.springframework.web.socket.WebSocketHandler; 044import org.springframework.web.socket.WebSocketHttpHeaders; 045import org.springframework.web.socket.handler.WebSocketHandlerDecorator; 046import org.springframework.web.socket.server.HandshakeFailureException; 047import 
org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.server.RequestUpgradeStrategy;

/**
 * A base class for {@link HandshakeHandler} implementations, independent from the Servlet API.
 *
 * <p>Performs initial validation of the WebSocket handshake request - possibly rejecting it
 * through the appropriate HTTP status code - while also allowing its subclasses to override
 * various parts of the negotiation process (e.g. origin validation, sub-protocol negotiation,
 * extensions negotiation, etc).
 *
 * <p>If the negotiation succeeds, the actual upgrade is delegated to a server-specific
 * {@link org.springframework.web.socket.server.RequestUpgradeStrategy}, which will update
 * the response as necessary and initialize the WebSocket. Currently supported servers are
 * Jetty 9.0-9.3, Tomcat 7.0.47+ and 8.x, Undertow 1.0-1.3, GlassFish 4.1+, WebLogic 12.1.3+,
 * and WebSphere.
 *
 * <p>Matching of sub-protocol names and of the {@code Connection: Upgrade} token is
 * case-insensitive and independent of the default JVM locale.
 *
 * @author Rossen Stoyanchev
 * @author Juergen Hoeller
 * @since 4.2
 * @see org.springframework.web.socket.server.jetty.JettyRequestUpgradeStrategy
 * @see org.springframework.web.socket.server.standard.TomcatRequestUpgradeStrategy
 * @see org.springframework.web.socket.server.standard.UndertowRequestUpgradeStrategy
 * @see org.springframework.web.socket.server.standard.GlassFishRequestUpgradeStrategy
 * @see org.springframework.web.socket.server.standard.WebLogicRequestUpgradeStrategy
 * @see org.springframework.web.socket.server.standard.WebSphereRequestUpgradeStrategy
 */
public abstract class AbstractHandshakeHandler implements HandshakeHandler, Lifecycle {

	// Runtime detection flags: each one records whether a container-specific marker class
	// is loadable from the class loader that loaded this class.

	private static final boolean jettyWsPresent;

	private static final boolean tomcatWsPresent;

	private static final boolean undertowWsPresent;

	private static final boolean glassfishWsPresent;

	private static final boolean weblogicWsPresent;

	private static final boolean websphereWsPresent;

	static {
		ClassLoader classLoader = AbstractHandshakeHandler.class.getClassLoader();
		jettyWsPresent = ClassUtils.isPresent(
				"org.eclipse.jetty.websocket.server.WebSocketServerFactory", classLoader);
		tomcatWsPresent = ClassUtils.isPresent(
				"org.apache.tomcat.websocket.server.WsHttpUpgradeHandler", classLoader);
		undertowWsPresent = ClassUtils.isPresent(
				"io.undertow.websockets.jsr.ServerWebSocketContainer", classLoader);
		glassfishWsPresent = ClassUtils.isPresent(
				"org.glassfish.tyrus.servlet.TyrusHttpUpgradeHandler", classLoader);
		weblogicWsPresent = ClassUtils.isPresent(
				"weblogic.websocket.tyrus.TyrusServletWriter", classLoader);
		websphereWsPresent = ClassUtils.isPresent(
				"com.ibm.websphere.wsoc.WsWsocServerContainer", classLoader);
	}


	protected final Log logger = LogFactory.getLog(getClass());

	private final RequestUpgradeStrategy requestUpgradeStrategy;

	// Explicitly configured sub-protocols, always stored in lower case (Locale.ROOT).
	private final List<String> supportedProtocols = new ArrayList<>();

	private volatile boolean running = false;


	/**
	 * Default constructor that auto-detects and instantiates a
	 * {@link RequestUpgradeStrategy} suitable for the runtime container.
	 * @throws IllegalStateException if no {@link RequestUpgradeStrategy} can be found.
	 */
	protected AbstractHandshakeHandler() {
		this(initRequestUpgradeStrategy());
	}

	/**
	 * A constructor that accepts a runtime-specific {@link RequestUpgradeStrategy}.
	 * @param requestUpgradeStrategy the upgrade strategy to use
	 */
	protected AbstractHandshakeHandler(RequestUpgradeStrategy requestUpgradeStrategy) {
		Assert.notNull(requestUpgradeStrategy, "RequestUpgradeStrategy must not be null");
		this.requestUpgradeStrategy = requestUpgradeStrategy;
	}


	/**
	 * Pick a {@link RequestUpgradeStrategy} implementation based on the detection flags and
	 * instantiate it reflectively through its accessible no-arg constructor.
	 * <p>The checks run in a fixed order (Tomcat, Jetty, Undertow, GlassFish, WebLogic,
	 * WebSphere), so the first detected container wins if several are on the classpath.
	 * @throws IllegalStateException if no container is detected or instantiation fails
	 */
	private static RequestUpgradeStrategy initRequestUpgradeStrategy() {
		String className;
		if (tomcatWsPresent) {
			className = "org.springframework.web.socket.server.standard.TomcatRequestUpgradeStrategy";
		}
		else if (jettyWsPresent) {
			className = "org.springframework.web.socket.server.jetty.JettyRequestUpgradeStrategy";
		}
		else if (undertowWsPresent) {
			className = "org.springframework.web.socket.server.standard.UndertowRequestUpgradeStrategy";
		}
		else if (glassfishWsPresent) {
			className = "org.springframework.web.socket.server.standard.GlassFishRequestUpgradeStrategy";
		}
		else if (weblogicWsPresent) {
			className = "org.springframework.web.socket.server.standard.WebLogicRequestUpgradeStrategy";
		}
		else if (websphereWsPresent) {
			className = "org.springframework.web.socket.server.standard.WebSphereRequestUpgradeStrategy";
		}
		else {
			throw new IllegalStateException("No suitable default RequestUpgradeStrategy found");
		}

		try {
			Class<?> clazz = ClassUtils.forName(className, AbstractHandshakeHandler.class.getClassLoader());
			return (RequestUpgradeStrategy) ReflectionUtils.accessibleConstructor(clazz).newInstance();
		}
		catch (Exception ex) {
			throw new IllegalStateException(
					"Failed to instantiate RequestUpgradeStrategy: " + className, ex);
		}
	}


	/**
	 * Return the {@link RequestUpgradeStrategy} for WebSocket requests.
	 */
	public RequestUpgradeStrategy getRequestUpgradeStrategy() {
		return this.requestUpgradeStrategy;
	}

	/**
	 * Use this property to configure the list of supported sub-protocols.
	 * The first configured sub-protocol that matches a client-requested sub-protocol
	 * is accepted. If there are no matches the response will not contain a
	 * {@literal Sec-WebSocket-Protocol} header.
	 * <p>Note that if the WebSocketHandler passed in at runtime is an instance of
	 * {@link SubProtocolCapable} then there is no need to explicitly configure
	 * this property. That is certainly the case with the built-in STOMP over
	 * WebSocket support. Therefore this property should be configured explicitly
	 * only if the WebSocketHandler does not implement {@code SubProtocolCapable}.
	 * <p>Protocols are stored lower-cased using {@code Locale.ROOT}, so matching does
	 * not depend on the default JVM locale.
	 */
	public void setSupportedProtocols(String... protocols) {
		this.supportedProtocols.clear();
		for (String protocol : protocols) {
			this.supportedProtocols.add(toLowerCaseRoot(protocol));
		}
	}

	/**
	 * Return the list of supported sub-protocols (lower-cased).
	 */
	public String[] getSupportedProtocols() {
		return StringUtils.toStringArray(this.supportedProtocols);
	}


	@Override
	public void start() {
		if (!isRunning()) {
			this.running = true;
			doStart();
		}
	}

	/**
	 * Start hook: propagates {@link Lifecycle#start()} to the upgrade strategy if it
	 * implements {@link Lifecycle}.
	 */
	protected void doStart() {
		if (this.requestUpgradeStrategy instanceof Lifecycle) {
			((Lifecycle) this.requestUpgradeStrategy).start();
		}
	}

	@Override
	public void stop() {
		if (isRunning()) {
			this.running = false;
			doStop();
		}
	}

	/**
	 * Stop hook: propagates {@link Lifecycle#stop()} to the upgrade strategy if it
	 * implements {@link Lifecycle}.
	 */
	protected void doStop() {
		if (this.requestUpgradeStrategy instanceof Lifecycle) {
			((Lifecycle) this.requestUpgradeStrategy).stop();
		}
	}

	@Override
	public boolean isRunning() {
		return this.running;
	}


	/**
	 * Validate the handshake request and, if valid, delegate the upgrade to the
	 * configured {@link RequestUpgradeStrategy}.
	 * <p>Checks, in order: the method is GET; {@code Upgrade} is "WebSocket"
	 * (case-insensitive); {@code Connection} contains the "Upgrade" token
	 * (case-insensitive); the WebSocket version is supported; the origin is valid;
	 * {@code Sec-WebSocket-Key} is present. The first failing check sets an error
	 * status on the response and {@code false} is returned.
	 * @return {@code true} if the upgrade was delegated, {@code false} if the request was rejected
	 * @throws HandshakeFailureException if writing the rejection response fails
	 */
	@Override
	public final boolean doHandshake(ServerHttpRequest request, ServerHttpResponse response,
			WebSocketHandler wsHandler, Map<String, Object> attributes) throws HandshakeFailureException {

		WebSocketHttpHeaders headers = new WebSocketHttpHeaders(request.getHeaders());
		if (logger.isTraceEnabled()) {
			logger.trace("Processing request " + request.getURI() + " with headers=" + headers);
		}
		try {
			if (HttpMethod.GET != request.getMethod()) {
				response.setStatusCode(HttpStatus.METHOD_NOT_ALLOWED);
				response.getHeaders().setAllow(Collections.singleton(HttpMethod.GET));
				if (logger.isErrorEnabled()) {
					logger.error("Handshake failed due to unexpected HTTP method: " + request.getMethod());
				}
				return false;
			}
			if (!"WebSocket".equalsIgnoreCase(headers.getUpgrade())) {
				handleInvalidUpgradeHeader(request, response);
				return false;
			}
			// RFC 6455 requires the "Upgrade" token to be matched case-insensitively.
			if (!containsUpgradeToken(headers.getConnection())) {
				handleInvalidConnectHeader(request, response);
				return false;
			}
			if (!isWebSocketVersionSupported(headers)) {
				handleWebSocketVersionNotSupported(request, response);
				return false;
			}
			if (!isValidOrigin(request)) {
				response.setStatusCode(HttpStatus.FORBIDDEN);
				return false;
			}
			String wsKey = headers.getSecWebSocketKey();
			if (wsKey == null) {
				if (logger.isErrorEnabled()) {
					logger.error("Missing \"Sec-WebSocket-Key\" header");
				}
				response.setStatusCode(HttpStatus.BAD_REQUEST);
				return false;
			}
		}
		catch (IOException ex) {
			throw new HandshakeFailureException(
					"Response update failed during upgrade to WebSocket: " + request.getURI(), ex);
		}

		String subProtocol = selectProtocol(headers.getSecWebSocketProtocol(), wsHandler);
		List<WebSocketExtension> requested = headers.getSecWebSocketExtensions();
		List<WebSocketExtension> supported = this.requestUpgradeStrategy.getSupportedExtensions(request);
		List<WebSocketExtension> extensions = filterRequestedExtensions(request, requested, supported);
		Principal user = determineUser(request, wsHandler, attributes);

		if (logger.isTraceEnabled()) {
			logger.trace("Upgrading to WebSocket, subProtocol=" + subProtocol + ", extensions=" + extensions);
		}
		this.requestUpgradeStrategy.upgrade(request, response, subProtocol, extensions, user, wsHandler, attributes);
		return true;
	}

	/**
	 * Whether any of the given {@code Connection} header values contains the "Upgrade"
	 * token. Values are additionally split on commas (in case they were not tokenized
	 * upstream), trimmed, and compared case-insensitively.
	 */
	private static boolean containsUpgradeToken(List<String> connectionValues) {
		for (String value : connectionValues) {
			if (value == null) {
				continue;
			}
			for (String token : value.split(",")) {
				if ("Upgrade".equalsIgnoreCase(token.trim())) {
					return true;
				}
			}
		}
		return false;
	}

	/**
	 * Reject a request whose {@code Upgrade} header is not "WebSocket": logs at error
	 * level, sets 400 and writes an explanatory body.
	 */
	protected void handleInvalidUpgradeHeader(ServerHttpRequest request, ServerHttpResponse response) throws IOException {
		if (logger.isErrorEnabled()) {
			logger.error("Handshake failed due to invalid Upgrade header: " + request.getHeaders().getUpgrade());
		}
		response.setStatusCode(HttpStatus.BAD_REQUEST);
		response.getBody().write("Can \"Upgrade\" only to \"WebSocket\".".getBytes(StandardCharsets.UTF_8));
	}

	/**
	 * Reject a request whose {@code Connection} header lacks the "Upgrade" token:
	 * logs at error level, sets 400 and writes an explanatory body.
	 */
	protected void handleInvalidConnectHeader(ServerHttpRequest request, ServerHttpResponse response) throws IOException {
		if (logger.isErrorEnabled()) {
			logger.error("Handshake failed due to invalid Connection header: " + request.getHeaders().getConnection());
		}
		response.setStatusCode(HttpStatus.BAD_REQUEST);
		response.getBody().write("\"Connection\" must be \"upgrade\".".getBytes(StandardCharsets.UTF_8));
	}

	/**
	 * Whether the request's {@code Sec-WebSocket-Version} exactly equals one of
	 * {@link #getSupportedVersions()} (each supported entry is trimmed first).
	 */
	protected boolean isWebSocketVersionSupported(WebSocketHttpHeaders httpHeaders) {
		String version = httpHeaders.getSecWebSocketVersion();
		String[] supportedVersions = getSupportedVersions();
		for (String supportedVersion : supportedVersions) {
			if (supportedVersion.trim().equals(version)) {
				return true;
			}
		}
		return false;
	}

	/**
	 * Return the WebSocket versions supported by the configured upgrade strategy.
	 */
	protected String[] getSupportedVersions() {
		return this.requestUpgradeStrategy.getSupportedVersions();
	}

	/**
	 * Reject an unsupported WebSocket version with 426 (Upgrade Required) and advertise
	 * the supported versions in the {@code Sec-WebSocket-Version} response header.
	 */
	protected void handleWebSocketVersionNotSupported(ServerHttpRequest request, ServerHttpResponse response) {
		if (logger.isErrorEnabled()) {
			String version = request.getHeaders().getFirst("Sec-WebSocket-Version");
			logger.error("Handshake failed due to unsupported WebSocket version: " + version +
					". Supported versions: " + Arrays.toString(getSupportedVersions()));
		}
		response.setStatusCode(HttpStatus.UPGRADE_REQUIRED);
		response.getHeaders().set(WebSocketHttpHeaders.SEC_WEBSOCKET_VERSION,
				StringUtils.arrayToCommaDelimitedString(getSupportedVersions()));
	}

	/**
	 * Return whether the request {@code Origin} header value is valid or not.
	 * By default, all origins are considered valid. Consider using an
	 * {@link OriginHandshakeInterceptor} for filtering origins if needed.
	 */
	protected boolean isValidOrigin(ServerHttpRequest request) {
		return true;
	}

	/**
	 * Perform the sub-protocol negotiation based on requested and supported sub-protocols.
	 * For the list of supported sub-protocols, this method first checks if the target
	 * WebSocketHandler is a {@link SubProtocolCapable} and then also checks if any
	 * sub-protocols have been explicitly configured with
	 * {@link #setSupportedProtocols(String...)}.
	 * <p>Requested protocols are tried in order; comparison is case-insensitive against
	 * both sources. The protocol is returned exactly as the client spelled it.
	 * @param requestedProtocols the requested sub-protocols
	 * @param webSocketHandler the WebSocketHandler that will be used
	 * @return the selected protocol, or {@code null} if none matched
	 * @see #determineHandlerSupportedProtocols(WebSocketHandler)
	 */
	@Nullable
	protected String selectProtocol(List<String> requestedProtocols, WebSocketHandler webSocketHandler) {
		List<String> handlerProtocols = determineHandlerSupportedProtocols(webSocketHandler);
		for (String protocol : requestedProtocols) {
			String normalized = toLowerCaseRoot(protocol);
			// Handler protocols are not normalized on the way in, so compare ignoring case.
			if (containsIgnoreCase(handlerProtocols, normalized) || this.supportedProtocols.contains(normalized)) {
				return protocol;
			}
		}
		return null;
	}

	/**
	 * Determine the sub-protocols supported by the given WebSocketHandler by
	 * checking whether it is an instance of {@link SubProtocolCapable}.
	 * @param handler the handler to check
	 * @return a list of supported protocols, or an empty list if none available
	 */
	protected final List<String> determineHandlerSupportedProtocols(WebSocketHandler handler) {
		WebSocketHandler handlerToCheck = WebSocketHandlerDecorator.unwrap(handler);
		List<String> subProtocols = null;
		if (handlerToCheck instanceof SubProtocolCapable) {
			subProtocols = ((SubProtocolCapable) handlerToCheck).getSubProtocols();
		}
		return (subProtocols != null ? subProtocols : Collections.emptyList());
	}

	/**
	 * Filter the list of requested WebSocket extensions.
	 * <p>As of 4.1, the default implementation of this method filters the list to
	 * leave only extensions that are both requested and supported.
	 * @param request the current request
	 * @param requestedExtensions the list of extensions requested by the client
	 * @param supportedExtensions the list of extensions supported by the server
	 * @return the selected extensions or an empty list
	 */
	protected List<WebSocketExtension> filterRequestedExtensions(ServerHttpRequest request,
			List<WebSocketExtension> requestedExtensions, List<WebSocketExtension> supportedExtensions) {

		List<WebSocketExtension> result = new ArrayList<>(requestedExtensions.size());
		for (WebSocketExtension extension : requestedExtensions) {
			if (supportedExtensions.contains(extension)) {
				result.add(extension);
			}
		}
		return result;
	}

	/**
	 * A method that can be used to associate a user with the WebSocket session
	 * in the process of being established. The default implementation calls
	 * {@link ServerHttpRequest#getPrincipal()}
	 * <p>Subclasses can provide custom logic for associating a user with a session,
	 * for example for assigning a name to anonymous users (i.e. not fully authenticated).
	 * @param request the handshake request
	 * @param wsHandler the WebSocket handler that will handle messages
	 * @param attributes handshake attributes to pass to the WebSocket session
	 * @return the user for the WebSocket session, or {@code null} if not available
	 */
	@Nullable
	protected Principal determineUser(
			ServerHttpRequest request, WebSocketHandler wsHandler, Map<String, Object> attributes) {

		return request.getPrincipal();
	}

	/**
	 * Lower-case a protocol name independently of the default JVM locale
	 * (avoids e.g. the Turkish dotless-i mapping of {@code 'I'}).
	 */
	private static String toLowerCaseRoot(String value) {
		return value.toLowerCase(java.util.Locale.ROOT);
	}

	/**
	 * Whether {@code candidates} contains an element equal to {@code value}, ignoring case.
	 */
	private static boolean containsIgnoreCase(List<String> candidates, String value) {
		for (String candidate : candidates) {
			if (value.equalsIgnoreCase(candidate)) {
				return true;
			}
		}
		return false;
	}

}