001/*
002 * Copyright 2002-2018 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.web.socket.server.support;
018
import java.io.IOException;
import java.nio.charset.Charset;
import java.security.Principal;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import org.springframework.context.Lifecycle;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.socket.SubProtocolCapable;
import org.springframework.web.socket.WebSocketExtension;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketHttpHeaders;
import org.springframework.web.socket.handler.WebSocketHandlerDecorator;
import org.springframework.web.socket.server.HandshakeFailureException;
import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.server.RequestUpgradeStrategy;
047
048/**
049 * A base class for {@link HandshakeHandler} implementations, independent from the Servlet API.
050 *
051 * <p>Performs initial validation of the WebSocket handshake request - possibly rejecting it
052 * through the appropriate HTTP status code - while also allowing its subclasses to override
053 * various parts of the negotiation process (e.g. origin validation, sub-protocol negotiation,
054 * extensions negotiation, etc).
055 *
056 * <p>If the negotiation succeeds, the actual upgrade is delegated to a server-specific
057 * {@link org.springframework.web.socket.server.RequestUpgradeStrategy}, which will update
058 * the response as necessary and initialize the WebSocket. Currently supported servers are
059 * Jetty 9.0-9.3, Tomcat 7.0.47+ and 8.x, Undertow 1.0-1.3, GlassFish 4.1+, WebLogic 12.1.3+.
060 *
061 * @author Rossen Stoyanchev
062 * @author Juergen Hoeller
063 * @since 4.2
064 * @see org.springframework.web.socket.server.jetty.JettyRequestUpgradeStrategy
065 * @see org.springframework.web.socket.server.standard.TomcatRequestUpgradeStrategy
066 * @see org.springframework.web.socket.server.standard.UndertowRequestUpgradeStrategy
067 * @see org.springframework.web.socket.server.standard.GlassFishRequestUpgradeStrategy
068 * @see org.springframework.web.socket.server.standard.WebLogicRequestUpgradeStrategy
069 */
070public abstract class AbstractHandshakeHandler implements HandshakeHandler, Lifecycle {
071
072        private static final Charset UTF8_CHARSET = Charset.forName("UTF-8");
073
074
075        private static final boolean jettyWsPresent = ClassUtils.isPresent(
076                        "org.eclipse.jetty.websocket.server.WebSocketServerFactory", AbstractHandshakeHandler.class.getClassLoader());
077
078        private static final boolean tomcatWsPresent = ClassUtils.isPresent(
079                        "org.apache.tomcat.websocket.server.WsHttpUpgradeHandler", AbstractHandshakeHandler.class.getClassLoader());
080
081        private static final boolean undertowWsPresent = ClassUtils.isPresent(
082                        "io.undertow.websockets.jsr.ServerWebSocketContainer", AbstractHandshakeHandler.class.getClassLoader());
083
084        private static final boolean glassfishWsPresent = ClassUtils.isPresent(
085                        "org.glassfish.tyrus.servlet.TyrusHttpUpgradeHandler", AbstractHandshakeHandler.class.getClassLoader());
086
087        private static final boolean weblogicWsPresent = ClassUtils.isPresent(
088                        "weblogic.websocket.tyrus.TyrusServletWriter", AbstractHandshakeHandler.class.getClassLoader());
089
090        private static final boolean websphereWsPresent = ClassUtils.isPresent(
091                        "com.ibm.websphere.wsoc.WsWsocServerContainer", AbstractHandshakeHandler.class.getClassLoader());
092
093
094        protected final Log logger = LogFactory.getLog(getClass());
095
096        private final RequestUpgradeStrategy requestUpgradeStrategy;
097
098        private final List<String> supportedProtocols = new ArrayList<String>();
099
100        private volatile boolean running = false;
101
102
103        /**
104         * Default constructor that auto-detects and instantiates a
105         * {@link RequestUpgradeStrategy} suitable for the runtime container.
106         * @throws IllegalStateException if no {@link RequestUpgradeStrategy} can be found.
107         */
108        protected AbstractHandshakeHandler() {
109                this(initRequestUpgradeStrategy());
110        }
111
112        /**
113         * A constructor that accepts a runtime-specific {@link RequestUpgradeStrategy}.
114         * @param requestUpgradeStrategy the upgrade strategy to use
115         */
116        protected AbstractHandshakeHandler(RequestUpgradeStrategy requestUpgradeStrategy) {
117                Assert.notNull(requestUpgradeStrategy, "RequestUpgradeStrategy must not be null");
118                this.requestUpgradeStrategy = requestUpgradeStrategy;
119        }
120
121
122        private static RequestUpgradeStrategy initRequestUpgradeStrategy() {
123                String className;
124                if (tomcatWsPresent) {
125                        className = "org.springframework.web.socket.server.standard.TomcatRequestUpgradeStrategy";
126                }
127                else if (jettyWsPresent) {
128                        className = "org.springframework.web.socket.server.jetty.JettyRequestUpgradeStrategy";
129                }
130                else if (undertowWsPresent) {
131                        className = "org.springframework.web.socket.server.standard.UndertowRequestUpgradeStrategy";
132                }
133                else if (glassfishWsPresent) {
134                        className = "org.springframework.web.socket.server.standard.GlassFishRequestUpgradeStrategy";
135                }
136                else if (weblogicWsPresent) {
137                        className = "org.springframework.web.socket.server.standard.WebLogicRequestUpgradeStrategy";
138                }
139                else if (websphereWsPresent) {
140                        className = "org.springframework.web.socket.server.standard.WebSphereRequestUpgradeStrategy";
141                }
142                else {
143                        throw new IllegalStateException("No suitable default RequestUpgradeStrategy found");
144                }
145
146                try {
147                        Class<?> clazz = ClassUtils.forName(className, AbstractHandshakeHandler.class.getClassLoader());
148                        return (RequestUpgradeStrategy) clazz.newInstance();
149                }
150                catch (Throwable ex) {
151                        throw new IllegalStateException(
152                                        "Failed to instantiate RequestUpgradeStrategy: " + className, ex);
153                }
154        }
155
156
157        /**
158         * Return the {@link RequestUpgradeStrategy} for WebSocket requests.
159         */
160        public RequestUpgradeStrategy getRequestUpgradeStrategy() {
161                return this.requestUpgradeStrategy;
162        }
163
164        /**
165         * Use this property to configure the list of supported sub-protocols.
166         * The first configured sub-protocol that matches a client-requested sub-protocol
167         * is accepted. If there are no matches the response will not contain a
168         * {@literal Sec-WebSocket-Protocol} header.
169         * <p>Note that if the WebSocketHandler passed in at runtime is an instance of
170         * {@link SubProtocolCapable} then there is not need to explicitly configure
171         * this property. That is certainly the case with the built-in STOMP over
172         * WebSocket support. Therefore this property should be configured explicitly
173         * only if the WebSocketHandler does not implement {@code SubProtocolCapable}.
174         */
175        public void setSupportedProtocols(String... protocols) {
176                this.supportedProtocols.clear();
177                for (String protocol : protocols) {
178                        this.supportedProtocols.add(protocol.toLowerCase());
179                }
180        }
181
182        /**
183         * Return the list of supported sub-protocols.
184         */
185        public String[] getSupportedProtocols() {
186                return StringUtils.toStringArray(this.supportedProtocols);
187        }
188
189
190        @Override
191        public void start() {
192                if (!isRunning()) {
193                        this.running = true;
194                        doStart();
195                }
196        }
197
198        protected void doStart() {
199                if (this.requestUpgradeStrategy instanceof Lifecycle) {
200                        ((Lifecycle) this.requestUpgradeStrategy).start();
201                }
202        }
203
204        @Override
205        public void stop() {
206                if (isRunning()) {
207                        this.running = false;
208                        doStop();
209                }
210        }
211
212        protected void doStop() {
213                if (this.requestUpgradeStrategy instanceof Lifecycle) {
214                        ((Lifecycle) this.requestUpgradeStrategy).stop();
215                }
216        }
217
218        @Override
219        public boolean isRunning() {
220                return this.running;
221        }
222
223
224        @Override
225        public final boolean doHandshake(ServerHttpRequest request, ServerHttpResponse response,
226                        WebSocketHandler wsHandler, Map<String, Object> attributes) throws HandshakeFailureException {
227
228                WebSocketHttpHeaders headers = new WebSocketHttpHeaders(request.getHeaders());
229                if (logger.isTraceEnabled()) {
230                        logger.trace("Processing request " + request.getURI() + " with headers=" + headers);
231                }
232                try {
233                        if (HttpMethod.GET != request.getMethod()) {
234                                response.setStatusCode(HttpStatus.METHOD_NOT_ALLOWED);
235                                response.getHeaders().setAllow(Collections.singleton(HttpMethod.GET));
236                                if (logger.isErrorEnabled()) {
237                                        logger.error("Handshake failed due to unexpected HTTP method: " + request.getMethod());
238                                }
239                                return false;
240                        }
241                        if (!"WebSocket".equalsIgnoreCase(headers.getUpgrade())) {
242                                handleInvalidUpgradeHeader(request, response);
243                                return false;
244                        }
245                        if (!headers.getConnection().contains("Upgrade") && !headers.getConnection().contains("upgrade")) {
246                                handleInvalidConnectHeader(request, response);
247                                return false;
248                        }
249                        if (!isWebSocketVersionSupported(headers)) {
250                                handleWebSocketVersionNotSupported(request, response);
251                                return false;
252                        }
253                        if (!isValidOrigin(request)) {
254                                response.setStatusCode(HttpStatus.FORBIDDEN);
255                                return false;
256                        }
257                        String wsKey = headers.getSecWebSocketKey();
258                        if (wsKey == null) {
259                                if (logger.isErrorEnabled()) {
260                                        logger.error("Missing \"Sec-WebSocket-Key\" header");
261                                }
262                                response.setStatusCode(HttpStatus.BAD_REQUEST);
263                                return false;
264                        }
265                }
266                catch (IOException ex) {
267                        throw new HandshakeFailureException(
268                                        "Response update failed during upgrade to WebSocket: " + request.getURI(), ex);
269                }
270
271                String subProtocol = selectProtocol(headers.getSecWebSocketProtocol(), wsHandler);
272                List<WebSocketExtension> requested = headers.getSecWebSocketExtensions();
273                List<WebSocketExtension> supported = this.requestUpgradeStrategy.getSupportedExtensions(request);
274                List<WebSocketExtension> extensions = filterRequestedExtensions(request, requested, supported);
275                Principal user = determineUser(request, wsHandler, attributes);
276
277                if (logger.isTraceEnabled()) {
278                        logger.trace("Upgrading to WebSocket, subProtocol=" + subProtocol + ", extensions=" + extensions);
279                }
280                this.requestUpgradeStrategy.upgrade(request, response, subProtocol, extensions, user, wsHandler, attributes);
281                return true;
282        }
283
284        protected void handleInvalidUpgradeHeader(ServerHttpRequest request, ServerHttpResponse response) throws IOException {
285                if (logger.isErrorEnabled()) {
286                        logger.error("Handshake failed due to invalid Upgrade header: " + request.getHeaders().getUpgrade());
287                }
288                response.setStatusCode(HttpStatus.BAD_REQUEST);
289                response.getBody().write("Can \"Upgrade\" only to \"WebSocket\".".getBytes(UTF8_CHARSET));
290        }
291
292        protected void handleInvalidConnectHeader(ServerHttpRequest request, ServerHttpResponse response) throws IOException {
293                if (logger.isErrorEnabled()) {
294                        logger.error("Handshake failed due to invalid Connection header " + request.getHeaders().getConnection());
295                }
296                response.setStatusCode(HttpStatus.BAD_REQUEST);
297                response.getBody().write("\"Connection\" must be \"upgrade\".".getBytes(UTF8_CHARSET));
298        }
299
300        protected boolean isWebSocketVersionSupported(WebSocketHttpHeaders httpHeaders) {
301                String version = httpHeaders.getSecWebSocketVersion();
302                String[] supportedVersions = getSupportedVersions();
303                for (String supportedVersion : supportedVersions) {
304                        if (supportedVersion.trim().equals(version)) {
305                                return true;
306                        }
307                }
308                return false;
309        }
310
311        protected String[] getSupportedVersions() {
312                return this.requestUpgradeStrategy.getSupportedVersions();
313        }
314
315        protected void handleWebSocketVersionNotSupported(ServerHttpRequest request, ServerHttpResponse response) {
316                if (logger.isErrorEnabled()) {
317                        String version = request.getHeaders().getFirst("Sec-WebSocket-Version");
318                        logger.error("Handshake failed due to unsupported WebSocket version: " + version +
319                                        ". Supported versions: " + Arrays.toString(getSupportedVersions()));
320                }
321                response.setStatusCode(HttpStatus.UPGRADE_REQUIRED);
322                response.getHeaders().set(WebSocketHttpHeaders.SEC_WEBSOCKET_VERSION,
323                                StringUtils.arrayToCommaDelimitedString(getSupportedVersions()));
324        }
325
326        /**
327         * Return whether the request {@code Origin} header value is valid or not.
328         * By default, all origins as considered as valid. Consider using an
329         * {@link OriginHandshakeInterceptor} for filtering origins if needed.
330         */
331        protected boolean isValidOrigin(ServerHttpRequest request) {
332                return true;
333        }
334
335        /**
336         * Perform the sub-protocol negotiation based on requested and supported sub-protocols.
337         * For the list of supported sub-protocols, this method first checks if the target
338         * WebSocketHandler is a {@link SubProtocolCapable} and then also checks if any
339         * sub-protocols have been explicitly configured with
340         * {@link #setSupportedProtocols(String...)}.
341         * @param requestedProtocols the requested sub-protocols
342         * @param webSocketHandler the WebSocketHandler that will be used
343         * @return the selected protocols or {@code null}
344         * @see #determineHandlerSupportedProtocols(WebSocketHandler)
345         */
346        protected String selectProtocol(List<String> requestedProtocols, WebSocketHandler webSocketHandler) {
347                if (requestedProtocols != null) {
348                        List<String> handlerProtocols = determineHandlerSupportedProtocols(webSocketHandler);
349                        for (String protocol : requestedProtocols) {
350                                if (handlerProtocols.contains(protocol.toLowerCase())) {
351                                        return protocol;
352                                }
353                                if (this.supportedProtocols.contains(protocol.toLowerCase())) {
354                                        return protocol;
355                                }
356                        }
357                }
358                return null;
359        }
360
361        /**
362         * Determine the sub-protocols supported by the given WebSocketHandler by
363         * checking whether it is an instance of {@link SubProtocolCapable}.
364         * @param handler the handler to check
365         * @return a list of supported protocols, or an empty list if none available
366         */
367        protected final List<String> determineHandlerSupportedProtocols(WebSocketHandler handler) {
368                WebSocketHandler handlerToCheck = WebSocketHandlerDecorator.unwrap(handler);
369                List<String> subProtocols = null;
370                if (handlerToCheck instanceof SubProtocolCapable) {
371                        subProtocols = ((SubProtocolCapable) handlerToCheck).getSubProtocols();
372                }
373                return (subProtocols != null ? subProtocols : Collections.<String>emptyList());
374        }
375
376        /**
377         * Filter the list of requested WebSocket extensions.
378         * <p>As of 4.1, the default implementation of this method filters the list to
379         * leave only extensions that are both requested and supported.
380         * @param request the current request
381         * @param requestedExtensions the list of extensions requested by the client
382         * @param supportedExtensions the list of extensions supported by the server
383         * @return the selected extensions or an empty list
384         */
385        protected List<WebSocketExtension> filterRequestedExtensions(ServerHttpRequest request,
386                        List<WebSocketExtension> requestedExtensions, List<WebSocketExtension> supportedExtensions) {
387
388                List<WebSocketExtension> result = new ArrayList<WebSocketExtension>(requestedExtensions.size());
389                for (WebSocketExtension extension : requestedExtensions) {
390                        if (supportedExtensions.contains(extension)) {
391                                result.add(extension);
392                        }
393                }
394                return result;
395        }
396
397        /**
398         * A method that can be used to associate a user with the WebSocket session
399         * in the process of being established. The default implementation calls
400         * {@link ServerHttpRequest#getPrincipal()}
401         * <p>Subclasses can provide custom logic for associating a user with a session,
402         * for example for assigning a name to anonymous users (i.e. not fully authenticated).
403         * @param request the handshake request
404         * @param wsHandler the WebSocket handler that will handle messages
405         * @param attributes handshake attributes to pass to the WebSocket session
406         * @return the user for the WebSocket session, or {@code null} if not available
407         */
408        protected Principal determineUser(
409                        ServerHttpRequest request, WebSocketHandler wsHandler, Map<String, Object> attributes) {
410
411                return request.getPrincipal();
412        }
413
414}