001/*
002 * Copyright 2002-2019 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.web.socket.server.support;
018
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.security.Principal;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import org.springframework.context.Lifecycle;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.ReflectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.socket.SubProtocolCapable;
import org.springframework.web.socket.WebSocketExtension;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketHttpHeaders;
import org.springframework.web.socket.handler.WebSocketHandlerDecorator;
import org.springframework.web.socket.server.HandshakeFailureException;
import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.server.RequestUpgradeStrategy;
049
050/**
051 * A base class for {@link HandshakeHandler} implementations, independent from the Servlet API.
052 *
053 * <p>Performs initial validation of the WebSocket handshake request - possibly rejecting it
054 * through the appropriate HTTP status code - while also allowing its subclasses to override
055 * various parts of the negotiation process (e.g. origin validation, sub-protocol negotiation,
056 * extensions negotiation, etc).
057 *
 * <p>If the negotiation succeeds, the actual upgrade is delegated to a server-specific
 * {@link org.springframework.web.socket.server.RequestUpgradeStrategy}, which will update
 * the response as necessary and initialize the WebSocket. Currently supported servers are
 * Jetty 9.0-9.3, Tomcat 7.0.47+ and 8.x, Undertow 1.0-1.3, GlassFish 4.1+, WebLogic 12.1.3+,
 * and WebSphere.
062 *
063 * @author Rossen Stoyanchev
064 * @author Juergen Hoeller
065 * @since 4.2
066 * @see org.springframework.web.socket.server.jetty.JettyRequestUpgradeStrategy
067 * @see org.springframework.web.socket.server.standard.TomcatRequestUpgradeStrategy
068 * @see org.springframework.web.socket.server.standard.UndertowRequestUpgradeStrategy
069 * @see org.springframework.web.socket.server.standard.GlassFishRequestUpgradeStrategy
070 * @see org.springframework.web.socket.server.standard.WebLogicRequestUpgradeStrategy
071 */
072public abstract class AbstractHandshakeHandler implements HandshakeHandler, Lifecycle {
073
074        private static final boolean jettyWsPresent;
075
076        private static final boolean tomcatWsPresent;
077
078        private static final boolean undertowWsPresent;
079
080        private static final boolean glassfishWsPresent;
081
082        private static final boolean weblogicWsPresent;
083
084        private static final boolean websphereWsPresent;
085
086        static {
087                ClassLoader classLoader = AbstractHandshakeHandler.class.getClassLoader();
088                jettyWsPresent = ClassUtils.isPresent(
089                                "org.eclipse.jetty.websocket.server.WebSocketServerFactory", classLoader);
090                tomcatWsPresent = ClassUtils.isPresent(
091                                "org.apache.tomcat.websocket.server.WsHttpUpgradeHandler", classLoader);
092                undertowWsPresent = ClassUtils.isPresent(
093                                "io.undertow.websockets.jsr.ServerWebSocketContainer", classLoader);
094                glassfishWsPresent = ClassUtils.isPresent(
095                                "org.glassfish.tyrus.servlet.TyrusHttpUpgradeHandler", classLoader);
096                weblogicWsPresent = ClassUtils.isPresent(
097                                "weblogic.websocket.tyrus.TyrusServletWriter", classLoader);
098                websphereWsPresent = ClassUtils.isPresent(
099                                "com.ibm.websphere.wsoc.WsWsocServerContainer", classLoader);
100
101        }
102
103
104        protected final Log logger = LogFactory.getLog(getClass());
105
106        private final RequestUpgradeStrategy requestUpgradeStrategy;
107
108        private final List<String> supportedProtocols = new ArrayList<>();
109
110        private volatile boolean running = false;
111
112
113        /**
114         * Default constructor that auto-detects and instantiates a
115         * {@link RequestUpgradeStrategy} suitable for the runtime container.
116         * @throws IllegalStateException if no {@link RequestUpgradeStrategy} can be found.
117         */
118        protected AbstractHandshakeHandler() {
119                this(initRequestUpgradeStrategy());
120        }
121
122        /**
123         * A constructor that accepts a runtime-specific {@link RequestUpgradeStrategy}.
124         * @param requestUpgradeStrategy the upgrade strategy to use
125         */
126        protected AbstractHandshakeHandler(RequestUpgradeStrategy requestUpgradeStrategy) {
127                Assert.notNull(requestUpgradeStrategy, "RequestUpgradeStrategy must not be null");
128                this.requestUpgradeStrategy = requestUpgradeStrategy;
129        }
130
131
132        private static RequestUpgradeStrategy initRequestUpgradeStrategy() {
133                String className;
134                if (tomcatWsPresent) {
135                        className = "org.springframework.web.socket.server.standard.TomcatRequestUpgradeStrategy";
136                }
137                else if (jettyWsPresent) {
138                        className = "org.springframework.web.socket.server.jetty.JettyRequestUpgradeStrategy";
139                }
140                else if (undertowWsPresent) {
141                        className = "org.springframework.web.socket.server.standard.UndertowRequestUpgradeStrategy";
142                }
143                else if (glassfishWsPresent) {
144                        className = "org.springframework.web.socket.server.standard.GlassFishRequestUpgradeStrategy";
145                }
146                else if (weblogicWsPresent) {
147                        className = "org.springframework.web.socket.server.standard.WebLogicRequestUpgradeStrategy";
148                }
149                else if (websphereWsPresent) {
150                        className = "org.springframework.web.socket.server.standard.WebSphereRequestUpgradeStrategy";
151                }
152                else {
153                        throw new IllegalStateException("No suitable default RequestUpgradeStrategy found");
154                }
155
156                try {
157                        Class<?> clazz = ClassUtils.forName(className, AbstractHandshakeHandler.class.getClassLoader());
158                        return (RequestUpgradeStrategy) ReflectionUtils.accessibleConstructor(clazz).newInstance();
159                }
160                catch (Exception ex) {
161                        throw new IllegalStateException(
162                                        "Failed to instantiate RequestUpgradeStrategy: " + className, ex);
163                }
164        }
165
166
167        /**
168         * Return the {@link RequestUpgradeStrategy} for WebSocket requests.
169         */
170        public RequestUpgradeStrategy getRequestUpgradeStrategy() {
171                return this.requestUpgradeStrategy;
172        }
173
174        /**
175         * Use this property to configure the list of supported sub-protocols.
176         * The first configured sub-protocol that matches a client-requested sub-protocol
177         * is accepted. If there are no matches the response will not contain a
178         * {@literal Sec-WebSocket-Protocol} header.
179         * <p>Note that if the WebSocketHandler passed in at runtime is an instance of
180         * {@link SubProtocolCapable} then there is not need to explicitly configure
181         * this property. That is certainly the case with the built-in STOMP over
182         * WebSocket support. Therefore this property should be configured explicitly
183         * only if the WebSocketHandler does not implement {@code SubProtocolCapable}.
184         */
185        public void setSupportedProtocols(String... protocols) {
186                this.supportedProtocols.clear();
187                for (String protocol : protocols) {
188                        this.supportedProtocols.add(protocol.toLowerCase());
189                }
190        }
191
192        /**
193         * Return the list of supported sub-protocols.
194         */
195        public String[] getSupportedProtocols() {
196                return StringUtils.toStringArray(this.supportedProtocols);
197        }
198
199
200        @Override
201        public void start() {
202                if (!isRunning()) {
203                        this.running = true;
204                        doStart();
205                }
206        }
207
208        protected void doStart() {
209                if (this.requestUpgradeStrategy instanceof Lifecycle) {
210                        ((Lifecycle) this.requestUpgradeStrategy).start();
211                }
212        }
213
214        @Override
215        public void stop() {
216                if (isRunning()) {
217                        this.running = false;
218                        doStop();
219                }
220        }
221
222        protected void doStop() {
223                if (this.requestUpgradeStrategy instanceof Lifecycle) {
224                        ((Lifecycle) this.requestUpgradeStrategy).stop();
225                }
226        }
227
228        @Override
229        public boolean isRunning() {
230                return this.running;
231        }
232
233
234        @Override
235        public final boolean doHandshake(ServerHttpRequest request, ServerHttpResponse response,
236                        WebSocketHandler wsHandler, Map<String, Object> attributes) throws HandshakeFailureException {
237
238                WebSocketHttpHeaders headers = new WebSocketHttpHeaders(request.getHeaders());
239                if (logger.isTraceEnabled()) {
240                        logger.trace("Processing request " + request.getURI() + " with headers=" + headers);
241                }
242                try {
243                        if (HttpMethod.GET != request.getMethod()) {
244                                response.setStatusCode(HttpStatus.METHOD_NOT_ALLOWED);
245                                response.getHeaders().setAllow(Collections.singleton(HttpMethod.GET));
246                                if (logger.isErrorEnabled()) {
247                                        logger.error("Handshake failed due to unexpected HTTP method: " + request.getMethod());
248                                }
249                                return false;
250                        }
251                        if (!"WebSocket".equalsIgnoreCase(headers.getUpgrade())) {
252                                handleInvalidUpgradeHeader(request, response);
253                                return false;
254                        }
255                        if (!headers.getConnection().contains("Upgrade") && !headers.getConnection().contains("upgrade")) {
256                                handleInvalidConnectHeader(request, response);
257                                return false;
258                        }
259                        if (!isWebSocketVersionSupported(headers)) {
260                                handleWebSocketVersionNotSupported(request, response);
261                                return false;
262                        }
263                        if (!isValidOrigin(request)) {
264                                response.setStatusCode(HttpStatus.FORBIDDEN);
265                                return false;
266                        }
267                        String wsKey = headers.getSecWebSocketKey();
268                        if (wsKey == null) {
269                                if (logger.isErrorEnabled()) {
270                                        logger.error("Missing \"Sec-WebSocket-Key\" header");
271                                }
272                                response.setStatusCode(HttpStatus.BAD_REQUEST);
273                                return false;
274                        }
275                }
276                catch (IOException ex) {
277                        throw new HandshakeFailureException(
278                                        "Response update failed during upgrade to WebSocket: " + request.getURI(), ex);
279                }
280
281                String subProtocol = selectProtocol(headers.getSecWebSocketProtocol(), wsHandler);
282                List<WebSocketExtension> requested = headers.getSecWebSocketExtensions();
283                List<WebSocketExtension> supported = this.requestUpgradeStrategy.getSupportedExtensions(request);
284                List<WebSocketExtension> extensions = filterRequestedExtensions(request, requested, supported);
285                Principal user = determineUser(request, wsHandler, attributes);
286
287                if (logger.isTraceEnabled()) {
288                        logger.trace("Upgrading to WebSocket, subProtocol=" + subProtocol + ", extensions=" + extensions);
289                }
290                this.requestUpgradeStrategy.upgrade(request, response, subProtocol, extensions, user, wsHandler, attributes);
291                return true;
292        }
293
294        protected void handleInvalidUpgradeHeader(ServerHttpRequest request, ServerHttpResponse response) throws IOException {
295                if (logger.isErrorEnabled()) {
296                        logger.error("Handshake failed due to invalid Upgrade header: " + request.getHeaders().getUpgrade());
297                }
298                response.setStatusCode(HttpStatus.BAD_REQUEST);
299                response.getBody().write("Can \"Upgrade\" only to \"WebSocket\".".getBytes(StandardCharsets.UTF_8));
300        }
301
302        protected void handleInvalidConnectHeader(ServerHttpRequest request, ServerHttpResponse response) throws IOException {
303                if (logger.isErrorEnabled()) {
304                        logger.error("Handshake failed due to invalid Connection header " + request.getHeaders().getConnection());
305                }
306                response.setStatusCode(HttpStatus.BAD_REQUEST);
307                response.getBody().write("\"Connection\" must be \"upgrade\".".getBytes(StandardCharsets.UTF_8));
308        }
309
310        protected boolean isWebSocketVersionSupported(WebSocketHttpHeaders httpHeaders) {
311                String version = httpHeaders.getSecWebSocketVersion();
312                String[] supportedVersions = getSupportedVersions();
313                for (String supportedVersion : supportedVersions) {
314                        if (supportedVersion.trim().equals(version)) {
315                                return true;
316                        }
317                }
318                return false;
319        }
320
321        protected String[] getSupportedVersions() {
322                return this.requestUpgradeStrategy.getSupportedVersions();
323        }
324
325        protected void handleWebSocketVersionNotSupported(ServerHttpRequest request, ServerHttpResponse response) {
326                if (logger.isErrorEnabled()) {
327                        String version = request.getHeaders().getFirst("Sec-WebSocket-Version");
328                        logger.error("Handshake failed due to unsupported WebSocket version: " + version +
329                                        ". Supported versions: " + Arrays.toString(getSupportedVersions()));
330                }
331                response.setStatusCode(HttpStatus.UPGRADE_REQUIRED);
332                response.getHeaders().set(WebSocketHttpHeaders.SEC_WEBSOCKET_VERSION,
333                                StringUtils.arrayToCommaDelimitedString(getSupportedVersions()));
334        }
335
336        /**
337         * Return whether the request {@code Origin} header value is valid or not.
338         * By default, all origins as considered as valid. Consider using an
339         * {@link OriginHandshakeInterceptor} for filtering origins if needed.
340         */
341        protected boolean isValidOrigin(ServerHttpRequest request) {
342                return true;
343        }
344
345        /**
346         * Perform the sub-protocol negotiation based on requested and supported sub-protocols.
347         * For the list of supported sub-protocols, this method first checks if the target
348         * WebSocketHandler is a {@link SubProtocolCapable} and then also checks if any
349         * sub-protocols have been explicitly configured with
350         * {@link #setSupportedProtocols(String...)}.
351         * @param requestedProtocols the requested sub-protocols
352         * @param webSocketHandler the WebSocketHandler that will be used
353         * @return the selected protocols or {@code null}
354         * @see #determineHandlerSupportedProtocols(WebSocketHandler)
355         */
356        @Nullable
357        protected String selectProtocol(List<String> requestedProtocols, WebSocketHandler webSocketHandler) {
358                List<String> handlerProtocols = determineHandlerSupportedProtocols(webSocketHandler);
359                for (String protocol : requestedProtocols) {
360                        if (handlerProtocols.contains(protocol.toLowerCase())) {
361                                return protocol;
362                        }
363                        if (this.supportedProtocols.contains(protocol.toLowerCase())) {
364                                return protocol;
365                        }
366                }
367                return null;
368        }
369
370        /**
371         * Determine the sub-protocols supported by the given WebSocketHandler by
372         * checking whether it is an instance of {@link SubProtocolCapable}.
373         * @param handler the handler to check
374         * @return a list of supported protocols, or an empty list if none available
375         */
376        protected final List<String> determineHandlerSupportedProtocols(WebSocketHandler handler) {
377                WebSocketHandler handlerToCheck = WebSocketHandlerDecorator.unwrap(handler);
378                List<String> subProtocols = null;
379                if (handlerToCheck instanceof SubProtocolCapable) {
380                        subProtocols = ((SubProtocolCapable) handlerToCheck).getSubProtocols();
381                }
382                return (subProtocols != null ? subProtocols : Collections.emptyList());
383        }
384
385        /**
386         * Filter the list of requested WebSocket extensions.
387         * <p>As of 4.1, the default implementation of this method filters the list to
388         * leave only extensions that are both requested and supported.
389         * @param request the current request
390         * @param requestedExtensions the list of extensions requested by the client
391         * @param supportedExtensions the list of extensions supported by the server
392         * @return the selected extensions or an empty list
393         */
394        protected List<WebSocketExtension> filterRequestedExtensions(ServerHttpRequest request,
395                        List<WebSocketExtension> requestedExtensions, List<WebSocketExtension> supportedExtensions) {
396
397                List<WebSocketExtension> result = new ArrayList<>(requestedExtensions.size());
398                for (WebSocketExtension extension : requestedExtensions) {
399                        if (supportedExtensions.contains(extension)) {
400                                result.add(extension);
401                        }
402                }
403                return result;
404        }
405
406        /**
407         * A method that can be used to associate a user with the WebSocket session
408         * in the process of being established. The default implementation calls
409         * {@link ServerHttpRequest#getPrincipal()}
410         * <p>Subclasses can provide custom logic for associating a user with a session,
411         * for example for assigning a name to anonymous users (i.e. not fully authenticated).
412         * @param request the handshake request
413         * @param wsHandler the WebSocket handler that will handle messages
414         * @param attributes handshake attributes to pass to the WebSocket session
415         * @return the user for the WebSocket session, or {@code null} if not available
416         */
417        @Nullable
418        protected Principal determineUser(
419                        ServerHttpRequest request, WebSocketHandler wsHandler, Map<String, Object> attributes) {
420
421                return request.getPrincipal();
422        }
423
424}