001/* 002 * Copyright 2002-2018 the original author or authors. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * https://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.springframework.web.reactive.socket.server.support; 018 019import java.net.InetSocketAddress; 020import java.net.URI; 021import java.security.Principal; 022import java.util.Collections; 023import java.util.List; 024import java.util.Map; 025import java.util.function.Predicate; 026import java.util.stream.Collectors; 027 028import org.apache.commons.logging.Log; 029import org.apache.commons.logging.LogFactory; 030import reactor.core.publisher.Mono; 031 032import org.springframework.context.Lifecycle; 033import org.springframework.http.HttpHeaders; 034import org.springframework.http.HttpMethod; 035import org.springframework.http.server.reactive.ServerHttpRequest; 036import org.springframework.lang.Nullable; 037import org.springframework.util.Assert; 038import org.springframework.util.ClassUtils; 039import org.springframework.util.ReflectionUtils; 040import org.springframework.util.StringUtils; 041import org.springframework.web.reactive.socket.HandshakeInfo; 042import org.springframework.web.reactive.socket.WebSocketHandler; 043import org.springframework.web.reactive.socket.server.RequestUpgradeStrategy; 044import org.springframework.web.reactive.socket.server.WebSocketService; 045import org.springframework.web.server.MethodNotAllowedException; 046import org.springframework.web.server.ServerWebExchange; 047import org.springframework.web.server.ServerWebInputException; 048 049/** 050 * {@code WebSocketService} implementation that handles a WebSocket HTTP 051 * handshake request by delegating to a {@link RequestUpgradeStrategy} which 052 * is either auto-detected (no-arg constructor) from the classpath but can 053 * also be explicitly configured. 054 * 055 * @author Rossen Stoyanchev 056 * @since 5.0 057 */ 058public class HandshakeWebSocketService implements WebSocketService, Lifecycle { 059 060 private static final String SEC_WEBSOCKET_KEY = "Sec-WebSocket-Key"; 061 062 private static final String SEC_WEBSOCKET_PROTOCOL = "Sec-WebSocket-Protocol"; 063 064 private static final Mono<Map<String, Object>> EMPTY_ATTRIBUTES = Mono.just(Collections.emptyMap()); 065 066 067 private static final boolean tomcatPresent; 068 069 private static final boolean jettyPresent; 070 071 private static final boolean undertowPresent; 072 073 private static final boolean reactorNettyPresent; 074 075 static { 076 ClassLoader classLoader = HandshakeWebSocketService.class.getClassLoader(); 077 tomcatPresent = ClassUtils.isPresent("org.apache.tomcat.websocket.server.WsHttpUpgradeHandler", classLoader); 078 jettyPresent = ClassUtils.isPresent("org.eclipse.jetty.websocket.server.WebSocketServerFactory", classLoader); 079 undertowPresent = ClassUtils.isPresent("io.undertow.websockets.WebSocketProtocolHandshakeHandler", classLoader); 080 reactorNettyPresent = ClassUtils.isPresent("reactor.netty.http.server.HttpServerResponse", classLoader); 081 } 082 083 084 protected static final Log logger = LogFactory.getLog(HandshakeWebSocketService.class); 085 086 087 private final RequestUpgradeStrategy upgradeStrategy; 088 089 @Nullable 090 private Predicate<String> sessionAttributePredicate; 091 092 private volatile boolean running = false; 093 094 095 /** 096 * Default constructor automatic, classpath detection based discovery of the 097 * {@link RequestUpgradeStrategy} to use. 098 */ 099 public HandshakeWebSocketService() { 100 this(initUpgradeStrategy()); 101 } 102 103 /** 104 * Alternative constructor with the {@link RequestUpgradeStrategy} to use. 105 * @param upgradeStrategy the strategy to use 106 */ 107 public HandshakeWebSocketService(RequestUpgradeStrategy upgradeStrategy) { 108 Assert.notNull(upgradeStrategy, "RequestUpgradeStrategy is required"); 109 this.upgradeStrategy = upgradeStrategy; 110 } 111 112 private static RequestUpgradeStrategy initUpgradeStrategy() { 113 String className; 114 if (tomcatPresent) { 115 className = "TomcatRequestUpgradeStrategy"; 116 } 117 else if (jettyPresent) { 118 className = "JettyRequestUpgradeStrategy"; 119 } 120 else if (undertowPresent) { 121 className = "UndertowRequestUpgradeStrategy"; 122 } 123 else if (reactorNettyPresent) { 124 // As late as possible (Reactor Netty commonly used for WebClient) 125 className = "ReactorNettyRequestUpgradeStrategy"; 126 } 127 else { 128 throw new IllegalStateException("No suitable default RequestUpgradeStrategy found"); 129 } 130 131 try { 132 className = "org.springframework.web.reactive.socket.server.upgrade." + className; 133 Class<?> clazz = ClassUtils.forName(className, HandshakeWebSocketService.class.getClassLoader()); 134 return (RequestUpgradeStrategy) ReflectionUtils.accessibleConstructor(clazz).newInstance(); 135 } 136 catch (Throwable ex) { 137 throw new IllegalStateException( 138 "Failed to instantiate RequestUpgradeStrategy: " + className, ex); 139 } 140 } 141 142 143 /** 144 * Return the {@link RequestUpgradeStrategy} for WebSocket requests. 145 */ 146 public RequestUpgradeStrategy getUpgradeStrategy() { 147 return this.upgradeStrategy; 148 } 149 150 /** 151 * Configure a predicate to use to extract 152 * {@link org.springframework.web.server.WebSession WebSession} attributes 153 * and use them to initialize the WebSocket session with. 154 * <p>By default this is not set in which case no attributes are passed. 155 * @param predicate the predicate 156 * @since 5.1 157 */ 158 public void setSessionAttributePredicate(@Nullable Predicate<String> predicate) { 159 this.sessionAttributePredicate = predicate; 160 } 161 162 /** 163 * Return the configured predicate for initialization WebSocket session 164 * attributes from {@code WebSession} attributes. 165 * @since 5.1 166 */ 167 @Nullable 168 public Predicate<String> getSessionAttributePredicate() { 169 return this.sessionAttributePredicate; 170 } 171 172 173 @Override 174 public void start() { 175 if (!isRunning()) { 176 this.running = true; 177 doStart(); 178 } 179 } 180 181 protected void doStart() { 182 if (getUpgradeStrategy() instanceof Lifecycle) { 183 ((Lifecycle) getUpgradeStrategy()).start(); 184 } 185 } 186 187 @Override 188 public void stop() { 189 if (isRunning()) { 190 this.running = false; 191 doStop(); 192 } 193 } 194 195 protected void doStop() { 196 if (getUpgradeStrategy() instanceof Lifecycle) { 197 ((Lifecycle) getUpgradeStrategy()).stop(); 198 } 199 } 200 201 @Override 202 public boolean isRunning() { 203 return this.running; 204 } 205 206 207 @Override 208 public Mono<Void> handleRequest(ServerWebExchange exchange, WebSocketHandler handler) { 209 ServerHttpRequest request = exchange.getRequest(); 210 HttpMethod method = request.getMethod(); 211 HttpHeaders headers = request.getHeaders(); 212 213 if (HttpMethod.GET != method) { 214 return Mono.error(new MethodNotAllowedException( 215 request.getMethodValue(), Collections.singleton(HttpMethod.GET))); 216 } 217 218 if (!"WebSocket".equalsIgnoreCase(headers.getUpgrade())) { 219 return handleBadRequest(exchange, "Invalid 'Upgrade' header: " + headers); 220 } 221 222 List<String> connectionValue = headers.getConnection(); 223 if (!connectionValue.contains("Upgrade") && !connectionValue.contains("upgrade")) { 224 return handleBadRequest(exchange, "Invalid 'Connection' header: " + headers); 225 } 226 227 String key = headers.getFirst(SEC_WEBSOCKET_KEY); 228 if (key == null) { 229 return handleBadRequest(exchange, "Missing \"Sec-WebSocket-Key\" header"); 230 } 231 232 String protocol = selectProtocol(headers, handler); 233 234 return initAttributes(exchange).flatMap(attributes -> 235 this.upgradeStrategy.upgrade(exchange, handler, protocol, 236 () -> createHandshakeInfo(exchange, request, protocol, attributes)) 237 ); 238 } 239 240 private Mono<Void> handleBadRequest(ServerWebExchange exchange, String reason) { 241 if (logger.isDebugEnabled()) { 242 logger.debug(exchange.getLogPrefix() + reason); 243 } 244 return Mono.error(new ServerWebInputException(reason)); 245 } 246 247 @Nullable 248 private String selectProtocol(HttpHeaders headers, WebSocketHandler handler) { 249 String protocolHeader = headers.getFirst(SEC_WEBSOCKET_PROTOCOL); 250 if (protocolHeader != null) { 251 List<String> supportedProtocols = handler.getSubProtocols(); 252 for (String protocol : StringUtils.commaDelimitedListToStringArray(protocolHeader)) { 253 if (supportedProtocols.contains(protocol)) { 254 return protocol; 255 } 256 } 257 } 258 return null; 259 } 260 261 private Mono<Map<String, Object>> initAttributes(ServerWebExchange exchange) { 262 if (this.sessionAttributePredicate == null) { 263 return EMPTY_ATTRIBUTES; 264 } 265 return exchange.getSession().map(session -> 266 session.getAttributes().entrySet().stream() 267 .filter(entry -> this.sessionAttributePredicate.test(entry.getKey())) 268 .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue))); 269 } 270 271 private HandshakeInfo createHandshakeInfo(ServerWebExchange exchange, ServerHttpRequest request, 272 @Nullable String protocol, Map<String, Object> attributes) { 273 274 URI uri = request.getURI(); 275 // Copy request headers, as they might be pooled and recycled by 276 // the server implementation once the handshake HTTP exchange is done. 277 HttpHeaders headers = new HttpHeaders(); 278 headers.addAll(request.getHeaders()); 279 Mono<Principal> principal = exchange.getPrincipal(); 280 String logPrefix = exchange.getLogPrefix(); 281 InetSocketAddress remoteAddress = request.getRemoteAddress(); 282 return new HandshakeInfo(uri, headers, principal, protocol, remoteAddress, attributes, logPrefix); 283 } 284 285}