001/*
002 * Copyright 2002-2017 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.web.socket.server.standard;
018
019import java.io.IOException;
020import java.lang.reflect.Constructor;
021import java.lang.reflect.Method;
022import java.util.Collections;
023import java.util.List;
024import java.util.Map;
025import java.util.Set;
026import java.util.concurrent.ConcurrentHashMap;
027import javax.servlet.ServletException;
028import javax.servlet.http.HttpServletRequest;
029import javax.servlet.http.HttpServletResponse;
030import javax.websocket.Decoder;
031import javax.websocket.Encoder;
032import javax.websocket.Endpoint;
033import javax.websocket.Extension;
034import javax.websocket.server.ServerEndpointConfig;
035
036import io.undertow.server.HttpServerExchange;
037import io.undertow.server.HttpUpgradeListener;
038import io.undertow.servlet.api.InstanceFactory;
039import io.undertow.servlet.api.InstanceHandle;
040import io.undertow.servlet.websockets.ServletWebSocketHttpExchange;
041import io.undertow.util.PathTemplate;
042import io.undertow.websockets.core.WebSocketChannel;
043import io.undertow.websockets.core.WebSocketVersion;
044import io.undertow.websockets.core.protocol.Handshake;
045import io.undertow.websockets.jsr.ConfiguredServerEndpoint;
046import io.undertow.websockets.jsr.EncodingFactory;
047import io.undertow.websockets.jsr.EndpointSessionHandler;
048import io.undertow.websockets.jsr.ServerWebSocketContainer;
049import io.undertow.websockets.jsr.annotated.AnnotatedEndpointFactory;
050import io.undertow.websockets.jsr.handshake.HandshakeUtil;
051import io.undertow.websockets.jsr.handshake.JsrHybi07Handshake;
052import io.undertow.websockets.jsr.handshake.JsrHybi08Handshake;
053import io.undertow.websockets.jsr.handshake.JsrHybi13Handshake;
054import io.undertow.websockets.spi.WebSocketHttpExchange;
055import org.xnio.StreamConnection;
056
057import org.springframework.http.server.ServerHttpRequest;
058import org.springframework.http.server.ServerHttpResponse;
059import org.springframework.util.ClassUtils;
060import org.springframework.util.ReflectionUtils;
061import org.springframework.web.socket.server.HandshakeFailureException;
062
063/**
064 * A WebSocket {@code RequestUpgradeStrategy} for WildFly and its underlying
065 * Undertow web server. Also compatible with embedded Undertow usage.
066 *
067 * <p>Designed for Undertow 1.3.5+ as of Spring Framework 4.3, with a fallback
068 * strategy for Undertow 1.0 to 1.3 - as included in WildFly 8.x, 9 and 10.
069 *
070 * @author Rossen Stoyanchev
071 * @since 4.0.1
072 */
073public class UndertowRequestUpgradeStrategy extends AbstractStandardUpgradeStrategy {
074
075        private static final boolean HAS_DO_UPGRADE = ClassUtils.hasMethod(ServerWebSocketContainer.class, "doUpgrade",
076                        HttpServletRequest.class, HttpServletResponse.class, ServerEndpointConfig.class, Map.class);
077
078        private static final FallbackStrategy FALLBACK_STRATEGY = (HAS_DO_UPGRADE ? null : new FallbackStrategy());
079
080        private static final String[] VERSIONS = new String[] {
081                        WebSocketVersion.V13.toHttpHeaderValue(),
082                        WebSocketVersion.V08.toHttpHeaderValue(),
083                        WebSocketVersion.V07.toHttpHeaderValue()
084        };
085
086
087        @Override
088        public String[] getSupportedVersions() {
089                return VERSIONS;
090        }
091
092        @Override
093        protected void upgradeInternal(ServerHttpRequest request, ServerHttpResponse response,
094                        String selectedProtocol, List<Extension> selectedExtensions, Endpoint endpoint)
095                        throws HandshakeFailureException {
096
097                if (HAS_DO_UPGRADE) {
098                        HttpServletRequest servletRequest = getHttpServletRequest(request);
099                        HttpServletResponse servletResponse = getHttpServletResponse(response);
100
101                        StringBuffer requestUrl = servletRequest.getRequestURL();
102                        String path = servletRequest.getRequestURI();  // shouldn't matter
103                        Map<String, String> pathParams = Collections.<String, String>emptyMap();
104
105                        ServerEndpointRegistration endpointConfig = new ServerEndpointRegistration(path, endpoint);
106                        endpointConfig.setSubprotocols(Collections.singletonList(selectedProtocol));
107                        endpointConfig.setExtensions(selectedExtensions);
108
109                        try {
110                                getContainer(servletRequest).doUpgrade(servletRequest, servletResponse, endpointConfig, pathParams);
111                        }
112                        catch (ServletException ex) {
113                                throw new HandshakeFailureException(
114                                                "Servlet request failed to upgrade to WebSocket: " + requestUrl, ex);
115                        }
116                        catch (IOException ex) {
117                                throw new HandshakeFailureException(
118                                                "Response update failed during upgrade to WebSocket: " + requestUrl, ex);
119                        }
120                }
121                else {
122                        FALLBACK_STRATEGY.upgradeInternal(request, response, selectedProtocol, selectedExtensions, endpoint);
123                }
124        }
125
126        @Override
127        public ServerWebSocketContainer getContainer(HttpServletRequest request) {
128                return (ServerWebSocketContainer) super.getContainer(request);
129        }
130
131
132        /**
133         * Strategy for use with Undertow 1.0 to 1.3 before there was a public API
134         * to perform a WebSocket upgrade.
135         */
136        private static class FallbackStrategy extends AbstractStandardUpgradeStrategy {
137
138                private static final Constructor<ServletWebSocketHttpExchange> exchangeConstructor;
139
140                private static final boolean exchangeConstructorWithPeerConnections;
141
142                private static final Constructor<ConfiguredServerEndpoint> endpointConstructor;
143
144                private static final boolean endpointConstructorWithEndpointFactory;
145
146                private static final Method getBufferPoolMethod;
147
148                private static final Method createChannelMethod;
149
150                static {
151                        try {
152                                Class<ServletWebSocketHttpExchange> exchangeType = ServletWebSocketHttpExchange.class;
153                                Class<?>[] exchangeParamTypes =
154                                                new Class<?>[] {HttpServletRequest.class, HttpServletResponse.class, Set.class};
155                                Constructor<ServletWebSocketHttpExchange> exchangeCtor =
156                                                ClassUtils.getConstructorIfAvailable(exchangeType, exchangeParamTypes);
157                                if (exchangeCtor != null) {
158                                        // Undertow 1.1+
159                                        exchangeConstructor = exchangeCtor;
160                                        exchangeConstructorWithPeerConnections = true;
161                                }
162                                else {
163                                        // Undertow 1.0
164                                        exchangeParamTypes = new Class<?>[] {HttpServletRequest.class, HttpServletResponse.class};
165                                        exchangeConstructor = exchangeType.getConstructor(exchangeParamTypes);
166                                        exchangeConstructorWithPeerConnections = false;
167                                }
168
169                                Class<ConfiguredServerEndpoint> endpointType = ConfiguredServerEndpoint.class;
170                                Class<?>[] endpointParamTypes = new Class<?>[] {ServerEndpointConfig.class, InstanceFactory.class,
171                                                PathTemplate.class, EncodingFactory.class, AnnotatedEndpointFactory.class};
172                                Constructor<ConfiguredServerEndpoint> endpointCtor =
173                                                ClassUtils.getConstructorIfAvailable(endpointType, endpointParamTypes);
174                                if (endpointCtor != null) {
175                                        // Undertow 1.1+
176                                        endpointConstructor = endpointCtor;
177                                        endpointConstructorWithEndpointFactory = true;
178                                }
179                                else {
180                                        // Undertow 1.0
181                                        endpointParamTypes = new Class<?>[] {ServerEndpointConfig.class, InstanceFactory.class,
182                                                        PathTemplate.class, EncodingFactory.class};
183                                        endpointConstructor = endpointType.getConstructor(endpointParamTypes);
184                                        endpointConstructorWithEndpointFactory = false;
185                                }
186
187                                // Adapting between different Pool API types in Undertow 1.0-1.2 vs 1.3
188                                getBufferPoolMethod = WebSocketHttpExchange.class.getMethod("getBufferPool");
189                                createChannelMethod = ReflectionUtils.findMethod(Handshake.class, "createChannel", (Class<?>[]) null);
190                        }
191                        catch (Throwable ex) {
192                                throw new IllegalStateException("Incompatible Undertow API version", ex);
193                        }
194                }
195
196                private final Set<WebSocketChannel> peerConnections;
197
198                public FallbackStrategy() {
199                        if (exchangeConstructorWithPeerConnections) {
200                                this.peerConnections = Collections.newSetFromMap(new ConcurrentHashMap<WebSocketChannel, Boolean>());
201                        }
202                        else {
203                                this.peerConnections = null;
204                        }
205                }
206
207                @Override
208                public String[] getSupportedVersions() {
209                        return VERSIONS;
210                }
211
212                @Override
213                protected void upgradeInternal(ServerHttpRequest request, ServerHttpResponse response,
214                                String selectedProtocol, List<Extension> selectedExtensions, final Endpoint endpoint)
215                                throws HandshakeFailureException {
216
217                        HttpServletRequest servletRequest = getHttpServletRequest(request);
218                        HttpServletResponse servletResponse = getHttpServletResponse(response);
219
220                        final ServletWebSocketHttpExchange exchange = createHttpExchange(servletRequest, servletResponse);
221                        exchange.putAttachment(HandshakeUtil.PATH_PARAMS, Collections.<String, String>emptyMap());
222
223                        ServerWebSocketContainer wsContainer = (ServerWebSocketContainer) getContainer(servletRequest);
224                        final EndpointSessionHandler endpointSessionHandler = new EndpointSessionHandler(wsContainer);
225
226                        final ConfiguredServerEndpoint configuredServerEndpoint = createConfiguredServerEndpoint(
227                                        selectedProtocol, selectedExtensions, endpoint, servletRequest);
228
229                        final Handshake handshake = getHandshakeToUse(exchange, configuredServerEndpoint);
230
231                        exchange.upgradeChannel(new HttpUpgradeListener() {
232                                @Override
233                                public void handleUpgrade(StreamConnection connection, HttpServerExchange serverExchange) {
234                                        Object bufferPool = ReflectionUtils.invokeMethod(getBufferPoolMethod, exchange);
235                                        WebSocketChannel channel = (WebSocketChannel) ReflectionUtils.invokeMethod(
236                                                        createChannelMethod, handshake, exchange, connection, bufferPool);
237                                        if (peerConnections != null) {
238                                                peerConnections.add(channel);
239                                        }
240                                        endpointSessionHandler.onConnect(exchange, channel);
241                                }
242                        });
243
244                        handshake.handshake(exchange);
245                }
246
247                private ServletWebSocketHttpExchange createHttpExchange(HttpServletRequest request, HttpServletResponse response) {
248                        try {
249                                return (this.peerConnections != null ?
250                                                exchangeConstructor.newInstance(request, response, this.peerConnections) :
251                                                exchangeConstructor.newInstance(request, response));
252                        }
253                        catch (Exception ex) {
254                                throw new HandshakeFailureException("Failed to instantiate ServletWebSocketHttpExchange", ex);
255                        }
256                }
257
258                private Handshake getHandshakeToUse(ServletWebSocketHttpExchange exchange, ConfiguredServerEndpoint endpoint) {
259                        Handshake handshake = new JsrHybi13Handshake(endpoint);
260                        if (handshake.matches(exchange)) {
261                                return handshake;
262                        }
263                        handshake = new JsrHybi08Handshake(endpoint);
264                        if (handshake.matches(exchange)) {
265                                return handshake;
266                        }
267                        handshake = new JsrHybi07Handshake(endpoint);
268                        if (handshake.matches(exchange)) {
269                                return handshake;
270                        }
271                        // Should never occur
272                        throw new HandshakeFailureException("No matching Undertow Handshake found: " + exchange.getRequestHeaders());
273                }
274
275                private ConfiguredServerEndpoint createConfiguredServerEndpoint(String selectedProtocol,
276                                List<Extension> selectedExtensions, Endpoint endpoint, HttpServletRequest servletRequest) {
277
278                        String path = servletRequest.getRequestURI();  // shouldn't matter
279                        ServerEndpointRegistration endpointRegistration = new ServerEndpointRegistration(path, endpoint);
280                        endpointRegistration.setSubprotocols(Collections.singletonList(selectedProtocol));
281                        endpointRegistration.setExtensions(selectedExtensions);
282
283                        EncodingFactory encodingFactory = new EncodingFactory(
284                                        Collections.<Class<?>, List<InstanceFactory<? extends Encoder>>>emptyMap(),
285                                        Collections.<Class<?>, List<InstanceFactory<? extends Decoder>>>emptyMap(),
286                                        Collections.<Class<?>, List<InstanceFactory<? extends Encoder>>>emptyMap(),
287                                        Collections.<Class<?>, List<InstanceFactory<? extends Decoder>>>emptyMap());
288                        try {
289                                return (endpointConstructorWithEndpointFactory ?
290                                                endpointConstructor.newInstance(endpointRegistration,
291                                                                new EndpointInstanceFactory(endpoint), null, encodingFactory, null) :
292                                                endpointConstructor.newInstance(endpointRegistration,
293                                                                new EndpointInstanceFactory(endpoint), null, encodingFactory));
294                        }
295                        catch (Exception ex) {
296                                throw new HandshakeFailureException("Failed to instantiate ConfiguredServerEndpoint", ex);
297                        }
298                }
299
300
301                private static class EndpointInstanceFactory implements InstanceFactory<Endpoint> {
302
303                        private final Endpoint endpoint;
304
305                        public EndpointInstanceFactory(Endpoint endpoint) {
306                                this.endpoint = endpoint;
307                        }
308
309                        @Override
310                        public InstanceHandle<Endpoint> createInstance() throws InstantiationException {
311                                return new InstanceHandle<Endpoint>() {
312                                        @Override
313                                        public Endpoint getInstance() {
314                                                return endpoint;
315                                        }
316                                        @Override
317                                        public void release() {
318                                        }
319                                };
320                        }
321                }
322        }
323
324}