001/* 002 * Copyright 2002-2017 the original author or authors. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * https://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.springframework.web.socket.server.jetty; 018 019import java.io.IOException; 020import java.security.Principal; 021import java.util.ArrayList; 022import java.util.List; 023import java.util.Map; 024import java.util.Set; 025import javax.servlet.ServletContext; 026import javax.servlet.http.HttpServletRequest; 027import javax.servlet.http.HttpServletResponse; 028 029import org.eclipse.jetty.websocket.api.WebSocketPolicy; 030import org.eclipse.jetty.websocket.api.extensions.ExtensionConfig; 031import org.eclipse.jetty.websocket.server.HandshakeRFC6455; 032import org.eclipse.jetty.websocket.server.WebSocketServerFactory; 033import org.eclipse.jetty.websocket.servlet.ServletUpgradeRequest; 034import org.eclipse.jetty.websocket.servlet.ServletUpgradeResponse; 035import org.eclipse.jetty.websocket.servlet.WebSocketCreator; 036 037import org.springframework.context.Lifecycle; 038import org.springframework.core.NamedThreadLocal; 039import org.springframework.http.server.ServerHttpRequest; 040import org.springframework.http.server.ServerHttpResponse; 041import org.springframework.http.server.ServletServerHttpRequest; 042import org.springframework.http.server.ServletServerHttpResponse; 043import org.springframework.util.Assert; 044import org.springframework.util.ClassUtils; 045import org.springframework.util.CollectionUtils; 046import org.springframework.web.context.ServletContextAware; 047import org.springframework.web.socket.WebSocketExtension; 048import org.springframework.web.socket.WebSocketHandler; 049import org.springframework.web.socket.adapter.jetty.JettyWebSocketHandlerAdapter; 050import org.springframework.web.socket.adapter.jetty.JettyWebSocketSession; 051import org.springframework.web.socket.adapter.jetty.WebSocketToJettyExtensionConfigAdapter; 052import org.springframework.web.socket.server.HandshakeFailureException; 053import org.springframework.web.socket.server.RequestUpgradeStrategy; 054 055/** 056 * A {@link RequestUpgradeStrategy} for use with Jetty 9.1-9.4. Based on 057 * Jetty's internal {@code org.eclipse.jetty.websocket.server.WebSocketHandler} class. 058 * 059 * @author Phillip Webb 060 * @author Rossen Stoyanchev 061 * @author Brian Clozel 062 * @author Juergen Hoeller 063 * @since 4.0 064 */ 065public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, ServletContextAware, Lifecycle { 066 067 private static final ThreadLocal<WebSocketHandlerContainer> containerHolder = 068 new NamedThreadLocal<WebSocketHandlerContainer>("WebSocketHandlerContainer"); 069 070 071 // Configurable factory adapter due to Jetty 9.3.15+ API differences: 072 // using WebSocketServerFactory(ServletContext) as a version indicator 073 private final WebSocketServerFactoryAdapter factoryAdapter = 074 (ClassUtils.hasConstructor(WebSocketServerFactory.class, ServletContext.class) ? 075 new ModernJettyWebSocketServerFactoryAdapter() : new LegacyJettyWebSocketServerFactoryAdapter()); 076 077 private ServletContext servletContext; 078 079 private volatile boolean running = false; 080 081 private volatile List<WebSocketExtension> supportedExtensions; 082 083 084 /** 085 * Default constructor that creates {@link WebSocketServerFactory} through 086 * its default constructor thus using a default {@link WebSocketPolicy}. 087 */ 088 public JettyRequestUpgradeStrategy() { 089 this.factoryAdapter.setPolicy(WebSocketPolicy.newServerPolicy()); 090 } 091 092 /** 093 * A constructor accepting a {@link WebSocketPolicy} to be used when 094 * creating the {@link WebSocketServerFactory} instance. 095 * @param policy the policy to use 096 * @since 4.3.5 097 */ 098 public JettyRequestUpgradeStrategy(WebSocketPolicy policy) { 099 Assert.notNull(policy, "WebSocketPolicy must not be null"); 100 this.factoryAdapter.setPolicy(policy); 101 } 102 103 /** 104 * A constructor accepting a {@link WebSocketServerFactory}. 105 * @param factory the pre-configured factory to use 106 */ 107 public JettyRequestUpgradeStrategy(WebSocketServerFactory factory) { 108 Assert.notNull(factory, "WebSocketServerFactory must not be null"); 109 this.factoryAdapter.setFactory(factory); 110 } 111 112 113 @Override 114 public void setServletContext(ServletContext servletContext) { 115 this.servletContext = servletContext; 116 } 117 118 @Override 119 public void start() { 120 if (!isRunning()) { 121 this.running = true; 122 try { 123 this.factoryAdapter.start(); 124 } 125 catch (Throwable ex) { 126 throw new IllegalStateException("Unable to start Jetty WebSocketServerFactory", ex); 127 } 128 } 129 } 130 131 @Override 132 public void stop() { 133 if (isRunning()) { 134 this.running = false; 135 try { 136 this.factoryAdapter.stop(); 137 } 138 catch (Throwable ex) { 139 throw new IllegalStateException("Unable to stop Jetty WebSocketServerFactory", ex); 140 } 141 } 142 } 143 144 @Override 145 public boolean isRunning() { 146 return this.running; 147 } 148 149 150 @Override 151 public String[] getSupportedVersions() { 152 return new String[] { String.valueOf(HandshakeRFC6455.VERSION) }; 153 } 154 155 @Override 156 public List<WebSocketExtension> getSupportedExtensions(ServerHttpRequest request) { 157 if (this.supportedExtensions == null) { 158 this.supportedExtensions = buildWebSocketExtensions(); 159 } 160 return this.supportedExtensions; 161 } 162 163 private List<WebSocketExtension> buildWebSocketExtensions() { 164 Set<String> names = this.factoryAdapter.getFactory().getExtensionFactory().getExtensionNames(); 165 List<WebSocketExtension> result = new ArrayList<WebSocketExtension>(names.size()); 166 for (String name : names) { 167 result.add(new WebSocketExtension(name)); 168 } 169 return result; 170 } 171 172 @Override 173 public void upgrade(ServerHttpRequest request, ServerHttpResponse response, 174 String selectedProtocol, List<WebSocketExtension> selectedExtensions, Principal user, 175 WebSocketHandler wsHandler, Map<String, Object> attributes) throws HandshakeFailureException { 176 177 Assert.isInstanceOf(ServletServerHttpRequest.class, request, "ServletServerHttpRequest required"); 178 HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest(); 179 180 Assert.isInstanceOf(ServletServerHttpResponse.class, response, "ServletServerHttpResponse required"); 181 HttpServletResponse servletResponse = ((ServletServerHttpResponse) response).getServletResponse(); 182 183 Assert.isTrue(this.factoryAdapter.getFactory().isUpgradeRequest(servletRequest, servletResponse), 184 "Not a WebSocket handshake"); 185 186 JettyWebSocketSession session = new JettyWebSocketSession(attributes, user); 187 JettyWebSocketHandlerAdapter handlerAdapter = new JettyWebSocketHandlerAdapter(wsHandler, session); 188 189 WebSocketHandlerContainer container = 190 new WebSocketHandlerContainer(handlerAdapter, selectedProtocol, selectedExtensions); 191 192 try { 193 containerHolder.set(container); 194 this.factoryAdapter.getFactory().acceptWebSocket(servletRequest, servletResponse); 195 } 196 catch (IOException ex) { 197 throw new HandshakeFailureException( 198 "Response update failed during upgrade to WebSocket: " + request.getURI(), ex); 199 } 200 finally { 201 containerHolder.remove(); 202 } 203 } 204 205 206 private static class WebSocketHandlerContainer { 207 208 private final JettyWebSocketHandlerAdapter handler; 209 210 private final String selectedProtocol; 211 212 private final List<ExtensionConfig> extensionConfigs; 213 214 public WebSocketHandlerContainer( 215 JettyWebSocketHandlerAdapter handler, String protocol, List<WebSocketExtension> extensions) { 216 217 this.handler = handler; 218 this.selectedProtocol = protocol; 219 if (CollectionUtils.isEmpty(extensions)) { 220 this.extensionConfigs = new ArrayList<ExtensionConfig>(0); 221 } 222 else { 223 this.extensionConfigs = new ArrayList<ExtensionConfig>(extensions.size()); 224 for (WebSocketExtension extension : extensions) { 225 this.extensionConfigs.add(new WebSocketToJettyExtensionConfigAdapter(extension)); 226 } 227 } 228 } 229 230 public JettyWebSocketHandlerAdapter getHandler() { 231 return this.handler; 232 } 233 234 public String getSelectedProtocol() { 235 return this.selectedProtocol; 236 } 237 238 public List<ExtensionConfig> getExtensionConfigs() { 239 return this.extensionConfigs; 240 } 241 } 242 243 244 private static abstract class WebSocketServerFactoryAdapter { 245 246 private WebSocketPolicy policy; 247 248 private WebSocketServerFactory factory; 249 250 public void setPolicy(WebSocketPolicy policy) { 251 this.policy = policy; 252 } 253 254 public void setFactory(WebSocketServerFactory factory) { 255 this.factory = factory; 256 } 257 258 public WebSocketServerFactory getFactory() { 259 return this.factory; 260 } 261 262 public void start() throws Exception { 263 if (this.factory == null) { 264 this.factory = createFactory(this.policy); 265 } 266 this.factory.setCreator(new WebSocketCreator() { 267 @Override 268 public Object createWebSocket(ServletUpgradeRequest request, ServletUpgradeResponse response) { 269 WebSocketHandlerContainer container = containerHolder.get(); 270 Assert.state(container != null, "Expected WebSocketHandlerContainer"); 271 response.setAcceptedSubProtocol(container.getSelectedProtocol()); 272 response.setExtensions(container.getExtensionConfigs()); 273 return container.getHandler(); 274 } 275 }); 276 startFactory(this.factory); 277 } 278 279 public void stop() throws Exception { 280 if (this.factory != null) { 281 stopFactory(this.factory); 282 } 283 } 284 285 protected abstract WebSocketServerFactory createFactory(WebSocketPolicy policy) throws Exception; 286 287 protected abstract void startFactory(WebSocketServerFactory factory) throws Exception; 288 289 protected abstract void stopFactory(WebSocketServerFactory factory) throws Exception; 290 } 291 292 293 // Jetty 9.3.15+ 294 private class ModernJettyWebSocketServerFactoryAdapter extends WebSocketServerFactoryAdapter { 295 296 @Override 297 protected WebSocketServerFactory createFactory(WebSocketPolicy policy) throws Exception { 298 return new WebSocketServerFactory(servletContext, policy); 299 } 300 301 @Override 302 protected void startFactory(WebSocketServerFactory factory) throws Exception { 303 factory.start(); 304 } 305 306 @Override 307 protected void stopFactory(WebSocketServerFactory factory) throws Exception { 308 factory.stop(); 309 } 310 } 311 312 313 // Jetty <9.3.15 314 private class LegacyJettyWebSocketServerFactoryAdapter extends WebSocketServerFactoryAdapter { 315 316 @Override 317 protected WebSocketServerFactory createFactory(WebSocketPolicy policy) throws Exception { 318 return WebSocketServerFactory.class.getConstructor(WebSocketPolicy.class).newInstance(policy); 319 } 320 321 @Override 322 protected void startFactory(WebSocketServerFactory factory) throws Exception { 323 try { 324 WebSocketServerFactory.class.getMethod("init", ServletContext.class).invoke(factory, servletContext); 325 } 326 catch (NoSuchMethodException ex) { 327 // Jetty 9.1/9.2 328 WebSocketServerFactory.class.getMethod("init").invoke(factory); 329 } 330 } 331 332 @Override 333 protected void stopFactory(WebSocketServerFactory factory) throws Exception { 334 WebSocketServerFactory.class.getMethod("cleanup").invoke(factory); 335 } 336 } 337 338}