001/*
002 * Copyright 2002-2018 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.web.socket.server.support;
018
019import java.io.IOException;
020import java.util.ArrayList;
021import java.util.HashMap;
022import java.util.List;
023import java.util.Map;
024import javax.servlet.ServletContext;
025import javax.servlet.ServletException;
026import javax.servlet.http.HttpServletRequest;
027import javax.servlet.http.HttpServletResponse;
028
029import org.apache.commons.logging.Log;
030import org.apache.commons.logging.LogFactory;
031
032import org.springframework.context.Lifecycle;
033import org.springframework.http.server.ServerHttpRequest;
034import org.springframework.http.server.ServerHttpResponse;
035import org.springframework.http.server.ServletServerHttpRequest;
036import org.springframework.http.server.ServletServerHttpResponse;
037import org.springframework.util.Assert;
038import org.springframework.web.HttpRequestHandler;
039import org.springframework.web.context.ServletContextAware;
040import org.springframework.web.socket.WebSocketHandler;
041import org.springframework.web.socket.handler.ExceptionWebSocketHandlerDecorator;
042import org.springframework.web.socket.handler.LoggingWebSocketHandlerDecorator;
043import org.springframework.web.socket.server.HandshakeFailureException;
044import org.springframework.web.socket.server.HandshakeHandler;
045import org.springframework.web.socket.server.HandshakeInterceptor;
046
047/**
048 * A {@link HttpRequestHandler} for processing WebSocket handshake requests.
049 *
050 * <p>This is the main class to use when configuring a server WebSocket at a specific URL.
051 * It is a very thin wrapper around a {@link WebSocketHandler} and a {@link HandshakeHandler},
052 * also adapting the {@link HttpServletRequest} and {@link HttpServletResponse} to
053 * {@link ServerHttpRequest} and {@link ServerHttpResponse}, respectively.
054 *
055 * @author Rossen Stoyanchev
056 * @since 4.0
057 */
058public class WebSocketHttpRequestHandler implements HttpRequestHandler, Lifecycle, ServletContextAware {
059
060        private final Log logger = LogFactory.getLog(WebSocketHttpRequestHandler.class);
061
062        private final WebSocketHandler wsHandler;
063
064        private final HandshakeHandler handshakeHandler;
065
066        private final List<HandshakeInterceptor> interceptors = new ArrayList<HandshakeInterceptor>();
067
068        private volatile boolean running = false;
069
070
071        public WebSocketHttpRequestHandler(WebSocketHandler wsHandler) {
072                this(wsHandler, new DefaultHandshakeHandler());
073        }
074
075        public WebSocketHttpRequestHandler(WebSocketHandler wsHandler, HandshakeHandler handshakeHandler) {
076                Assert.notNull(wsHandler, "wsHandler must not be null");
077                Assert.notNull(handshakeHandler, "handshakeHandler must not be null");
078                this.wsHandler = new ExceptionWebSocketHandlerDecorator(new LoggingWebSocketHandlerDecorator(wsHandler));
079                this.handshakeHandler = handshakeHandler;
080        }
081
082
083        /**
084         * Return the WebSocketHandler.
085         */
086        public WebSocketHandler getWebSocketHandler() {
087                return this.wsHandler;
088        }
089
090        /**
091         * Return the HandshakeHandler.
092         */
093        public HandshakeHandler getHandshakeHandler() {
094                return this.handshakeHandler;
095        }
096
097        /**
098         * Configure one or more WebSocket handshake request interceptors.
099         */
100        public void setHandshakeInterceptors(List<HandshakeInterceptor> interceptors) {
101                this.interceptors.clear();
102                if (interceptors != null) {
103                        this.interceptors.addAll(interceptors);
104                }
105        }
106
107        /**
108         * Return the configured WebSocket handshake request interceptors.
109         */
110        public List<HandshakeInterceptor> getHandshakeInterceptors() {
111                return this.interceptors;
112        }
113
114        @Override
115        public void setServletContext(ServletContext servletContext) {
116                if (this.handshakeHandler instanceof ServletContextAware) {
117                        ((ServletContextAware) this.handshakeHandler).setServletContext(servletContext);
118                }
119        }
120
121
122        @Override
123        public void start() {
124                if (!isRunning()) {
125                        this.running = true;
126                        if (this.handshakeHandler instanceof Lifecycle) {
127                                ((Lifecycle) this.handshakeHandler).start();
128                        }
129                }
130        }
131
132        @Override
133        public void stop() {
134                if (isRunning()) {
135                        this.running = false;
136                        if (this.handshakeHandler instanceof Lifecycle) {
137                                ((Lifecycle) this.handshakeHandler).stop();
138                        }
139                }
140        }
141
142        @Override
143        public boolean isRunning() {
144                return this.running;
145        }
146
147
148        @Override
149        public void handleRequest(HttpServletRequest servletRequest, HttpServletResponse servletResponse)
150                        throws ServletException, IOException {
151
152                ServerHttpRequest request = new ServletServerHttpRequest(servletRequest);
153                ServerHttpResponse response = new ServletServerHttpResponse(servletResponse);
154
155                HandshakeInterceptorChain chain = new HandshakeInterceptorChain(this.interceptors, this.wsHandler);
156                HandshakeFailureException failure = null;
157
158                try {
159                        if (logger.isDebugEnabled()) {
160                                logger.debug(servletRequest.getMethod() + " " + servletRequest.getRequestURI());
161                        }
162                        Map<String, Object> attributes = new HashMap<String, Object>();
163                        if (!chain.applyBeforeHandshake(request, response, attributes)) {
164                                return;
165                        }
166                        this.handshakeHandler.doHandshake(request, response, this.wsHandler, attributes);
167                        chain.applyAfterHandshake(request, response, null);
168                        response.close();
169                }
170                catch (HandshakeFailureException ex) {
171                        failure = ex;
172                }
173                catch (Throwable ex) {
174                        failure = new HandshakeFailureException("Uncaught failure for request " + request.getURI(), ex);
175                }
176                finally {
177                        if (failure != null) {
178                                chain.applyAfterHandshake(request, response, failure);
179                                throw failure;
180                        }
181                }
182        }
183
184}