001/* 002 * Copyright 2002-2017 the original author or authors. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * https://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.springframework.web.socket.server.standard; 018 019import java.io.IOException; 020import java.lang.reflect.Constructor; 021import java.lang.reflect.Method; 022import java.util.Collections; 023import java.util.List; 024import java.util.Map; 025import java.util.Set; 026import java.util.concurrent.ConcurrentHashMap; 027import javax.servlet.ServletException; 028import javax.servlet.http.HttpServletRequest; 029import javax.servlet.http.HttpServletResponse; 030import javax.websocket.Decoder; 031import javax.websocket.Encoder; 032import javax.websocket.Endpoint; 033import javax.websocket.Extension; 034import javax.websocket.server.ServerEndpointConfig; 035 036import io.undertow.server.HttpServerExchange; 037import io.undertow.server.HttpUpgradeListener; 038import io.undertow.servlet.api.InstanceFactory; 039import io.undertow.servlet.api.InstanceHandle; 040import io.undertow.servlet.websockets.ServletWebSocketHttpExchange; 041import io.undertow.util.PathTemplate; 042import io.undertow.websockets.core.WebSocketChannel; 043import io.undertow.websockets.core.WebSocketVersion; 044import io.undertow.websockets.core.protocol.Handshake; 045import io.undertow.websockets.jsr.ConfiguredServerEndpoint; 046import io.undertow.websockets.jsr.EncodingFactory; 047import io.undertow.websockets.jsr.EndpointSessionHandler; 048import io.undertow.websockets.jsr.ServerWebSocketContainer; 049import io.undertow.websockets.jsr.annotated.AnnotatedEndpointFactory; 050import io.undertow.websockets.jsr.handshake.HandshakeUtil; 051import io.undertow.websockets.jsr.handshake.JsrHybi07Handshake; 052import io.undertow.websockets.jsr.handshake.JsrHybi08Handshake; 053import io.undertow.websockets.jsr.handshake.JsrHybi13Handshake; 054import io.undertow.websockets.spi.WebSocketHttpExchange; 055import org.xnio.StreamConnection; 056 057import org.springframework.http.server.ServerHttpRequest; 058import org.springframework.http.server.ServerHttpResponse; 059import org.springframework.util.ClassUtils; 060import org.springframework.util.ReflectionUtils; 061import org.springframework.web.socket.server.HandshakeFailureException; 062 063/** 064 * A WebSocket {@code RequestUpgradeStrategy} for WildFly and its underlying 065 * Undertow web server. Also compatible with embedded Undertow usage. 066 * 067 * <p>Designed for Undertow 1.3.5+ as of Spring Framework 4.3, with a fallback 068 * strategy for Undertow 1.0 to 1.3 - as included in WildFly 8.x, 9 and 10. 069 * 070 * @author Rossen Stoyanchev 071 * @since 4.0.1 072 */ 073public class UndertowRequestUpgradeStrategy extends AbstractStandardUpgradeStrategy { 074 075 private static final boolean HAS_DO_UPGRADE = ClassUtils.hasMethod(ServerWebSocketContainer.class, "doUpgrade", 076 HttpServletRequest.class, HttpServletResponse.class, ServerEndpointConfig.class, Map.class); 077 078 private static final FallbackStrategy FALLBACK_STRATEGY = (HAS_DO_UPGRADE ? null : new FallbackStrategy()); 079 080 private static final String[] VERSIONS = new String[] { 081 WebSocketVersion.V13.toHttpHeaderValue(), 082 WebSocketVersion.V08.toHttpHeaderValue(), 083 WebSocketVersion.V07.toHttpHeaderValue() 084 }; 085 086 087 @Override 088 public String[] getSupportedVersions() { 089 return VERSIONS; 090 } 091 092 @Override 093 protected void upgradeInternal(ServerHttpRequest request, ServerHttpResponse response, 094 String selectedProtocol, List<Extension> selectedExtensions, Endpoint endpoint) 095 throws HandshakeFailureException { 096 097 if (HAS_DO_UPGRADE) { 098 HttpServletRequest servletRequest = getHttpServletRequest(request); 099 HttpServletResponse servletResponse = getHttpServletResponse(response); 100 101 StringBuffer requestUrl = servletRequest.getRequestURL(); 102 String path = servletRequest.getRequestURI(); // shouldn't matter 103 Map<String, String> pathParams = Collections.<String, String>emptyMap(); 104 105 ServerEndpointRegistration endpointConfig = new ServerEndpointRegistration(path, endpoint); 106 endpointConfig.setSubprotocols(Collections.singletonList(selectedProtocol)); 107 endpointConfig.setExtensions(selectedExtensions); 108 109 try { 110 getContainer(servletRequest).doUpgrade(servletRequest, servletResponse, endpointConfig, pathParams); 111 } 112 catch (ServletException ex) { 113 throw new HandshakeFailureException( 114 "Servlet request failed to upgrade to WebSocket: " + requestUrl, ex); 115 } 116 catch (IOException ex) { 117 throw new HandshakeFailureException( 118 "Response update failed during upgrade to WebSocket: " + requestUrl, ex); 119 } 120 } 121 else { 122 FALLBACK_STRATEGY.upgradeInternal(request, response, selectedProtocol, selectedExtensions, endpoint); 123 } 124 } 125 126 @Override 127 public ServerWebSocketContainer getContainer(HttpServletRequest request) { 128 return (ServerWebSocketContainer) super.getContainer(request); 129 } 130 131 132 /** 133 * Strategy for use with Undertow 1.0 to 1.3 before there was a public API 134 * to perform a WebSocket upgrade. 135 */ 136 private static class FallbackStrategy extends AbstractStandardUpgradeStrategy { 137 138 private static final Constructor<ServletWebSocketHttpExchange> exchangeConstructor; 139 140 private static final boolean exchangeConstructorWithPeerConnections; 141 142 private static final Constructor<ConfiguredServerEndpoint> endpointConstructor; 143 144 private static final boolean endpointConstructorWithEndpointFactory; 145 146 private static final Method getBufferPoolMethod; 147 148 private static final Method createChannelMethod; 149 150 static { 151 try { 152 Class<ServletWebSocketHttpExchange> exchangeType = ServletWebSocketHttpExchange.class; 153 Class<?>[] exchangeParamTypes = 154 new Class<?>[] {HttpServletRequest.class, HttpServletResponse.class, Set.class}; 155 Constructor<ServletWebSocketHttpExchange> exchangeCtor = 156 ClassUtils.getConstructorIfAvailable(exchangeType, exchangeParamTypes); 157 if (exchangeCtor != null) { 158 // Undertow 1.1+ 159 exchangeConstructor = exchangeCtor; 160 exchangeConstructorWithPeerConnections = true; 161 } 162 else { 163 // Undertow 1.0 164 exchangeParamTypes = new Class<?>[] {HttpServletRequest.class, HttpServletResponse.class}; 165 exchangeConstructor = exchangeType.getConstructor(exchangeParamTypes); 166 exchangeConstructorWithPeerConnections = false; 167 } 168 169 Class<ConfiguredServerEndpoint> endpointType = ConfiguredServerEndpoint.class; 170 Class<?>[] endpointParamTypes = new Class<?>[] {ServerEndpointConfig.class, InstanceFactory.class, 171 PathTemplate.class, EncodingFactory.class, AnnotatedEndpointFactory.class}; 172 Constructor<ConfiguredServerEndpoint> endpointCtor = 173 ClassUtils.getConstructorIfAvailable(endpointType, endpointParamTypes); 174 if (endpointCtor != null) { 175 // Undertow 1.1+ 176 endpointConstructor = endpointCtor; 177 endpointConstructorWithEndpointFactory = true; 178 } 179 else { 180 // Undertow 1.0 181 endpointParamTypes = new Class<?>[] {ServerEndpointConfig.class, InstanceFactory.class, 182 PathTemplate.class, EncodingFactory.class}; 183 endpointConstructor = endpointType.getConstructor(endpointParamTypes); 184 endpointConstructorWithEndpointFactory = false; 185 } 186 187 // Adapting between different Pool API types in Undertow 1.0-1.2 vs 1.3 188 getBufferPoolMethod = WebSocketHttpExchange.class.getMethod("getBufferPool"); 189 createChannelMethod = ReflectionUtils.findMethod(Handshake.class, "createChannel", (Class<?>[]) null); 190 } 191 catch (Throwable ex) { 192 throw new IllegalStateException("Incompatible Undertow API version", ex); 193 } 194 } 195 196 private final Set<WebSocketChannel> peerConnections; 197 198 public FallbackStrategy() { 199 if (exchangeConstructorWithPeerConnections) { 200 this.peerConnections = Collections.newSetFromMap(new ConcurrentHashMap<WebSocketChannel, Boolean>()); 201 } 202 else { 203 this.peerConnections = null; 204 } 205 } 206 207 @Override 208 public String[] getSupportedVersions() { 209 return VERSIONS; 210 } 211 212 @Override 213 protected void upgradeInternal(ServerHttpRequest request, ServerHttpResponse response, 214 String selectedProtocol, List<Extension> selectedExtensions, final Endpoint endpoint) 215 throws HandshakeFailureException { 216 217 HttpServletRequest servletRequest = getHttpServletRequest(request); 218 HttpServletResponse servletResponse = getHttpServletResponse(response); 219 220 final ServletWebSocketHttpExchange exchange = createHttpExchange(servletRequest, servletResponse); 221 exchange.putAttachment(HandshakeUtil.PATH_PARAMS, Collections.<String, String>emptyMap()); 222 223 ServerWebSocketContainer wsContainer = (ServerWebSocketContainer) getContainer(servletRequest); 224 final EndpointSessionHandler endpointSessionHandler = new EndpointSessionHandler(wsContainer); 225 226 final ConfiguredServerEndpoint configuredServerEndpoint = createConfiguredServerEndpoint( 227 selectedProtocol, selectedExtensions, endpoint, servletRequest); 228 229 final Handshake handshake = getHandshakeToUse(exchange, configuredServerEndpoint); 230 231 exchange.upgradeChannel(new HttpUpgradeListener() { 232 @Override 233 public void handleUpgrade(StreamConnection connection, HttpServerExchange serverExchange) { 234 Object bufferPool = ReflectionUtils.invokeMethod(getBufferPoolMethod, exchange); 235 WebSocketChannel channel = (WebSocketChannel) ReflectionUtils.invokeMethod( 236 createChannelMethod, handshake, exchange, connection, bufferPool); 237 if (peerConnections != null) { 238 peerConnections.add(channel); 239 } 240 endpointSessionHandler.onConnect(exchange, channel); 241 } 242 }); 243 244 handshake.handshake(exchange); 245 } 246 247 private ServletWebSocketHttpExchange createHttpExchange(HttpServletRequest request, HttpServletResponse response) { 248 try { 249 return (this.peerConnections != null ? 250 exchangeConstructor.newInstance(request, response, this.peerConnections) : 251 exchangeConstructor.newInstance(request, response)); 252 } 253 catch (Exception ex) { 254 throw new HandshakeFailureException("Failed to instantiate ServletWebSocketHttpExchange", ex); 255 } 256 } 257 258 private Handshake getHandshakeToUse(ServletWebSocketHttpExchange exchange, ConfiguredServerEndpoint endpoint) { 259 Handshake handshake = new JsrHybi13Handshake(endpoint); 260 if (handshake.matches(exchange)) { 261 return handshake; 262 } 263 handshake = new JsrHybi08Handshake(endpoint); 264 if (handshake.matches(exchange)) { 265 return handshake; 266 } 267 handshake = new JsrHybi07Handshake(endpoint); 268 if (handshake.matches(exchange)) { 269 return handshake; 270 } 271 // Should never occur 272 throw new HandshakeFailureException("No matching Undertow Handshake found: " + exchange.getRequestHeaders()); 273 } 274 275 private ConfiguredServerEndpoint createConfiguredServerEndpoint(String selectedProtocol, 276 List<Extension> selectedExtensions, Endpoint endpoint, HttpServletRequest servletRequest) { 277 278 String path = servletRequest.getRequestURI(); // shouldn't matter 279 ServerEndpointRegistration endpointRegistration = new ServerEndpointRegistration(path, endpoint); 280 endpointRegistration.setSubprotocols(Collections.singletonList(selectedProtocol)); 281 endpointRegistration.setExtensions(selectedExtensions); 282 283 EncodingFactory encodingFactory = new EncodingFactory( 284 Collections.<Class<?>, List<InstanceFactory<? extends Encoder>>>emptyMap(), 285 Collections.<Class<?>, List<InstanceFactory<? extends Decoder>>>emptyMap(), 286 Collections.<Class<?>, List<InstanceFactory<? extends Encoder>>>emptyMap(), 287 Collections.<Class<?>, List<InstanceFactory<? extends Decoder>>>emptyMap()); 288 try { 289 return (endpointConstructorWithEndpointFactory ? 290 endpointConstructor.newInstance(endpointRegistration, 291 new EndpointInstanceFactory(endpoint), null, encodingFactory, null) : 292 endpointConstructor.newInstance(endpointRegistration, 293 new EndpointInstanceFactory(endpoint), null, encodingFactory)); 294 } 295 catch (Exception ex) { 296 throw new HandshakeFailureException("Failed to instantiate ConfiguredServerEndpoint", ex); 297 } 298 } 299 300 301 private static class EndpointInstanceFactory implements InstanceFactory<Endpoint> { 302 303 private final Endpoint endpoint; 304 305 public EndpointInstanceFactory(Endpoint endpoint) { 306 this.endpoint = endpoint; 307 } 308 309 @Override 310 public InstanceHandle<Endpoint> createInstance() throws InstantiationException { 311 return new InstanceHandle<Endpoint>() { 312 @Override 313 public Endpoint getInstance() { 314 return endpoint; 315 } 316 @Override 317 public void release() { 318 } 319 }; 320 } 321 } 322 } 323 324}