001/*
002 * Copyright 2002-2019 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.web.socket.server.support;
018
019import java.io.IOException;
020import java.util.ArrayList;
021import java.util.HashMap;
022import java.util.List;
023import java.util.Map;
024
025import javax.servlet.ServletContext;
026import javax.servlet.ServletException;
027import javax.servlet.http.HttpServletRequest;
028import javax.servlet.http.HttpServletResponse;
029
030import org.apache.commons.logging.Log;
031import org.apache.commons.logging.LogFactory;
032
033import org.springframework.context.Lifecycle;
034import org.springframework.http.server.ServerHttpRequest;
035import org.springframework.http.server.ServerHttpResponse;
036import org.springframework.http.server.ServletServerHttpRequest;
037import org.springframework.http.server.ServletServerHttpResponse;
038import org.springframework.lang.Nullable;
039import org.springframework.util.Assert;
040import org.springframework.web.HttpRequestHandler;
041import org.springframework.web.context.ServletContextAware;
042import org.springframework.web.socket.WebSocketHandler;
043import org.springframework.web.socket.handler.ExceptionWebSocketHandlerDecorator;
044import org.springframework.web.socket.handler.LoggingWebSocketHandlerDecorator;
045import org.springframework.web.socket.server.HandshakeFailureException;
046import org.springframework.web.socket.server.HandshakeHandler;
047import org.springframework.web.socket.server.HandshakeInterceptor;
048
049/**
050 * A {@link HttpRequestHandler} for processing WebSocket handshake requests.
051 *
052 * <p>This is the main class to use when configuring a server WebSocket at a specific URL.
053 * It is a very thin wrapper around a {@link WebSocketHandler} and a {@link HandshakeHandler},
054 * also adapting the {@link HttpServletRequest} and {@link HttpServletResponse} to
055 * {@link ServerHttpRequest} and {@link ServerHttpResponse}, respectively.
056 *
057 * @author Rossen Stoyanchev
058 * @since 4.0
059 */
060public class WebSocketHttpRequestHandler implements HttpRequestHandler, Lifecycle, ServletContextAware {
061
062        private static final Log logger = LogFactory.getLog(WebSocketHttpRequestHandler.class);
063
064        private final WebSocketHandler wsHandler;
065
066        private final HandshakeHandler handshakeHandler;
067
068        private final List<HandshakeInterceptor> interceptors = new ArrayList<>();
069
070        private volatile boolean running = false;
071
072
073        public WebSocketHttpRequestHandler(WebSocketHandler wsHandler) {
074                this(wsHandler, new DefaultHandshakeHandler());
075        }
076
077        public WebSocketHttpRequestHandler(WebSocketHandler wsHandler, HandshakeHandler handshakeHandler) {
078                Assert.notNull(wsHandler, "wsHandler must not be null");
079                Assert.notNull(handshakeHandler, "handshakeHandler must not be null");
080                this.wsHandler = decorate(wsHandler);
081                this.handshakeHandler = handshakeHandler;
082        }
083
084        /**
085         * Decorate the {@code WebSocketHandler} passed into the constructor.
086         * <p>By default, {@link LoggingWebSocketHandlerDecorator} and
087         * {@link ExceptionWebSocketHandlerDecorator} are added.
088         * @since 5.2.2
089         */
090        protected WebSocketHandler decorate(WebSocketHandler handler) {
091                return new ExceptionWebSocketHandlerDecorator(new LoggingWebSocketHandlerDecorator(handler));
092        }
093
094
095        /**
096         * Return the WebSocketHandler.
097         */
098        public WebSocketHandler getWebSocketHandler() {
099                return this.wsHandler;
100        }
101
102        /**
103         * Return the HandshakeHandler.
104         */
105        public HandshakeHandler getHandshakeHandler() {
106                return this.handshakeHandler;
107        }
108
109        /**
110         * Configure one or more WebSocket handshake request interceptors.
111         */
112        public void setHandshakeInterceptors(@Nullable List<HandshakeInterceptor> interceptors) {
113                this.interceptors.clear();
114                if (interceptors != null) {
115                        this.interceptors.addAll(interceptors);
116                }
117        }
118
119        /**
120         * Return the configured WebSocket handshake request interceptors.
121         */
122        public List<HandshakeInterceptor> getHandshakeInterceptors() {
123                return this.interceptors;
124        }
125
126        @Override
127        public void setServletContext(ServletContext servletContext) {
128                if (this.handshakeHandler instanceof ServletContextAware) {
129                        ((ServletContextAware) this.handshakeHandler).setServletContext(servletContext);
130                }
131        }
132
133
134        @Override
135        public void start() {
136                if (!isRunning()) {
137                        this.running = true;
138                        if (this.handshakeHandler instanceof Lifecycle) {
139                                ((Lifecycle) this.handshakeHandler).start();
140                        }
141                }
142        }
143
144        @Override
145        public void stop() {
146                if (isRunning()) {
147                        this.running = false;
148                        if (this.handshakeHandler instanceof Lifecycle) {
149                                ((Lifecycle) this.handshakeHandler).stop();
150                        }
151                }
152        }
153
154        @Override
155        public boolean isRunning() {
156                return this.running;
157        }
158
159
160        @Override
161        public void handleRequest(HttpServletRequest servletRequest, HttpServletResponse servletResponse)
162                        throws ServletException, IOException {
163
164                ServerHttpRequest request = new ServletServerHttpRequest(servletRequest);
165                ServerHttpResponse response = new ServletServerHttpResponse(servletResponse);
166
167                HandshakeInterceptorChain chain = new HandshakeInterceptorChain(this.interceptors, this.wsHandler);
168                HandshakeFailureException failure = null;
169
170                try {
171                        if (logger.isDebugEnabled()) {
172                                logger.debug(servletRequest.getMethod() + " " + servletRequest.getRequestURI());
173                        }
174                        Map<String, Object> attributes = new HashMap<>();
175                        if (!chain.applyBeforeHandshake(request, response, attributes)) {
176                                return;
177                        }
178                        this.handshakeHandler.doHandshake(request, response, this.wsHandler, attributes);
179                        chain.applyAfterHandshake(request, response, null);
180                }
181                catch (HandshakeFailureException ex) {
182                        failure = ex;
183                }
184                catch (Exception ex) {
185                        failure = new HandshakeFailureException("Uncaught failure for request " + request.getURI(), ex);
186                }
187                finally {
188                        if (failure != null) {
189                                chain.applyAfterHandshake(request, response, failure);
190                                response.close();
191                                throw failure;
192                        }
193                        response.close();
194                }
195        }
196
197}