001/* 002 * Copyright 2002-2019 the original author or authors. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * https://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.springframework.web.socket.server.support; 018 019import java.io.IOException; 020import java.util.ArrayList; 021import java.util.HashMap; 022import java.util.List; 023import java.util.Map; 024 025import javax.servlet.ServletContext; 026import javax.servlet.ServletException; 027import javax.servlet.http.HttpServletRequest; 028import javax.servlet.http.HttpServletResponse; 029 030import org.apache.commons.logging.Log; 031import org.apache.commons.logging.LogFactory; 032 033import org.springframework.context.Lifecycle; 034import org.springframework.http.server.ServerHttpRequest; 035import org.springframework.http.server.ServerHttpResponse; 036import org.springframework.http.server.ServletServerHttpRequest; 037import org.springframework.http.server.ServletServerHttpResponse; 038import org.springframework.lang.Nullable; 039import org.springframework.util.Assert; 040import org.springframework.web.HttpRequestHandler; 041import org.springframework.web.context.ServletContextAware; 042import org.springframework.web.socket.WebSocketHandler; 043import org.springframework.web.socket.handler.ExceptionWebSocketHandlerDecorator; 044import org.springframework.web.socket.handler.LoggingWebSocketHandlerDecorator; 045import org.springframework.web.socket.server.HandshakeFailureException; 046import org.springframework.web.socket.server.HandshakeHandler; 047import org.springframework.web.socket.server.HandshakeInterceptor; 048 049/** 050 * A {@link HttpRequestHandler} for processing WebSocket handshake requests. 051 * 052 * <p>This is the main class to use when configuring a server WebSocket at a specific URL. 053 * It is a very thin wrapper around a {@link WebSocketHandler} and a {@link HandshakeHandler}, 054 * also adapting the {@link HttpServletRequest} and {@link HttpServletResponse} to 055 * {@link ServerHttpRequest} and {@link ServerHttpResponse}, respectively. 056 * 057 * @author Rossen Stoyanchev 058 * @since 4.0 059 */ 060public class WebSocketHttpRequestHandler implements HttpRequestHandler, Lifecycle, ServletContextAware { 061 062 private static final Log logger = LogFactory.getLog(WebSocketHttpRequestHandler.class); 063 064 private final WebSocketHandler wsHandler; 065 066 private final HandshakeHandler handshakeHandler; 067 068 private final List<HandshakeInterceptor> interceptors = new ArrayList<>(); 069 070 private volatile boolean running = false; 071 072 073 public WebSocketHttpRequestHandler(WebSocketHandler wsHandler) { 074 this(wsHandler, new DefaultHandshakeHandler()); 075 } 076 077 public WebSocketHttpRequestHandler(WebSocketHandler wsHandler, HandshakeHandler handshakeHandler) { 078 Assert.notNull(wsHandler, "wsHandler must not be null"); 079 Assert.notNull(handshakeHandler, "handshakeHandler must not be null"); 080 this.wsHandler = decorate(wsHandler); 081 this.handshakeHandler = handshakeHandler; 082 } 083 084 /** 085 * Decorate the {@code WebSocketHandler} passed into the constructor. 086 * <p>By default, {@link LoggingWebSocketHandlerDecorator} and 087 * {@link ExceptionWebSocketHandlerDecorator} are added. 088 * @since 5.2.2 089 */ 090 protected WebSocketHandler decorate(WebSocketHandler handler) { 091 return new ExceptionWebSocketHandlerDecorator(new LoggingWebSocketHandlerDecorator(handler)); 092 } 093 094 095 /** 096 * Return the WebSocketHandler. 097 */ 098 public WebSocketHandler getWebSocketHandler() { 099 return this.wsHandler; 100 } 101 102 /** 103 * Return the HandshakeHandler. 104 */ 105 public HandshakeHandler getHandshakeHandler() { 106 return this.handshakeHandler; 107 } 108 109 /** 110 * Configure one or more WebSocket handshake request interceptors. 111 */ 112 public void setHandshakeInterceptors(@Nullable List<HandshakeInterceptor> interceptors) { 113 this.interceptors.clear(); 114 if (interceptors != null) { 115 this.interceptors.addAll(interceptors); 116 } 117 } 118 119 /** 120 * Return the configured WebSocket handshake request interceptors. 121 */ 122 public List<HandshakeInterceptor> getHandshakeInterceptors() { 123 return this.interceptors; 124 } 125 126 @Override 127 public void setServletContext(ServletContext servletContext) { 128 if (this.handshakeHandler instanceof ServletContextAware) { 129 ((ServletContextAware) this.handshakeHandler).setServletContext(servletContext); 130 } 131 } 132 133 134 @Override 135 public void start() { 136 if (!isRunning()) { 137 this.running = true; 138 if (this.handshakeHandler instanceof Lifecycle) { 139 ((Lifecycle) this.handshakeHandler).start(); 140 } 141 } 142 } 143 144 @Override 145 public void stop() { 146 if (isRunning()) { 147 this.running = false; 148 if (this.handshakeHandler instanceof Lifecycle) { 149 ((Lifecycle) this.handshakeHandler).stop(); 150 } 151 } 152 } 153 154 @Override 155 public boolean isRunning() { 156 return this.running; 157 } 158 159 160 @Override 161 public void handleRequest(HttpServletRequest servletRequest, HttpServletResponse servletResponse) 162 throws ServletException, IOException { 163 164 ServerHttpRequest request = new ServletServerHttpRequest(servletRequest); 165 ServerHttpResponse response = new ServletServerHttpResponse(servletResponse); 166 167 HandshakeInterceptorChain chain = new HandshakeInterceptorChain(this.interceptors, this.wsHandler); 168 HandshakeFailureException failure = null; 169 170 try { 171 if (logger.isDebugEnabled()) { 172 logger.debug(servletRequest.getMethod() + " " + servletRequest.getRequestURI()); 173 } 174 Map<String, Object> attributes = new HashMap<>(); 175 if (!chain.applyBeforeHandshake(request, response, attributes)) { 176 return; 177 } 178 this.handshakeHandler.doHandshake(request, response, this.wsHandler, attributes); 179 chain.applyAfterHandshake(request, response, null); 180 } 181 catch (HandshakeFailureException ex) { 182 failure = ex; 183 } 184 catch (Exception ex) { 185 failure = new HandshakeFailureException("Uncaught failure for request " + request.getURI(), ex); 186 } 187 finally { 188 if (failure != null) { 189 chain.applyAfterHandshake(request, response, failure); 190 response.close(); 191 throw failure; 192 } 193 response.close(); 194 } 195 } 196 197}