001/* 002 * Copyright 2002-2020 the original author or authors. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * https://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.springframework.web.reactive.socket.server.upgrade; 018 019import java.util.function.Supplier; 020 021import javax.servlet.ServletContext; 022import javax.servlet.http.HttpServletRequest; 023import javax.servlet.http.HttpServletResponse; 024 025import org.eclipse.jetty.websocket.api.WebSocketPolicy; 026import org.eclipse.jetty.websocket.server.WebSocketServerFactory; 027import reactor.core.publisher.Mono; 028 029import org.springframework.context.Lifecycle; 030import org.springframework.core.NamedThreadLocal; 031import org.springframework.core.io.buffer.DataBufferFactory; 032import org.springframework.http.server.reactive.AbstractServerHttpRequest; 033import org.springframework.http.server.reactive.AbstractServerHttpResponse; 034import org.springframework.http.server.reactive.ServerHttpRequest; 035import org.springframework.http.server.reactive.ServerHttpRequestDecorator; 036import org.springframework.http.server.reactive.ServerHttpResponse; 037import org.springframework.http.server.reactive.ServerHttpResponseDecorator; 038import org.springframework.lang.Nullable; 039import org.springframework.util.Assert; 040import org.springframework.web.reactive.socket.HandshakeInfo; 041import org.springframework.web.reactive.socket.WebSocketHandler; 042import org.springframework.web.reactive.socket.adapter.JettyWebSocketHandlerAdapter; 043import org.springframework.web.reactive.socket.adapter.JettyWebSocketSession; 044import org.springframework.web.reactive.socket.server.RequestUpgradeStrategy; 045import org.springframework.web.server.ServerWebExchange; 046 047/** 048 * A {@link RequestUpgradeStrategy} for use with Jetty. 049 * 050 * @author Violeta Georgieva 051 * @author Rossen Stoyanchev 052 * @since 5.0 053 */ 054public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Lifecycle { 055 056 private static final ThreadLocal<WebSocketHandlerContainer> adapterHolder = 057 new NamedThreadLocal<>("JettyWebSocketHandlerAdapter"); 058 059 060 @Nullable 061 private WebSocketPolicy webSocketPolicy; 062 063 @Nullable 064 private WebSocketServerFactory factory; 065 066 @Nullable 067 private volatile ServletContext servletContext; 068 069 private volatile boolean running = false; 070 071 private final Object lifecycleMonitor = new Object(); 072 073 074 /** 075 * Configure a {@link WebSocketPolicy} to use to initialize 076 * {@link WebSocketServerFactory}. 077 * @param webSocketPolicy the WebSocket settings 078 */ 079 public void setWebSocketPolicy(WebSocketPolicy webSocketPolicy) { 080 this.webSocketPolicy = webSocketPolicy; 081 } 082 083 /** 084 * Return the configured {@link WebSocketPolicy}, if any. 085 */ 086 @Nullable 087 public WebSocketPolicy getWebSocketPolicy() { 088 return this.webSocketPolicy; 089 } 090 091 092 @Override 093 public void start() { 094 synchronized (this.lifecycleMonitor) { 095 ServletContext servletContext = this.servletContext; 096 if (!isRunning() && servletContext != null) { 097 try { 098 this.factory = (this.webSocketPolicy != null ? 099 new WebSocketServerFactory(servletContext, this.webSocketPolicy) : 100 new WebSocketServerFactory(servletContext)); 101 this.factory.setCreator((request, response) -> { 102 WebSocketHandlerContainer container = adapterHolder.get(); 103 String protocol = container.getProtocol(); 104 if (protocol != null) { 105 response.setAcceptedSubProtocol(protocol); 106 } 107 return container.getAdapter(); 108 }); 109 this.factory.start(); 110 this.running = true; 111 } 112 catch (Throwable ex) { 113 throw new IllegalStateException("Unable to start WebSocketServerFactory", ex); 114 } 115 } 116 } 117 } 118 119 @Override 120 public void stop() { 121 synchronized (this.lifecycleMonitor) { 122 if (isRunning()) { 123 if (this.factory != null) { 124 try { 125 this.factory.stop(); 126 this.running = false; 127 } 128 catch (Throwable ex) { 129 throw new IllegalStateException("Failed to stop WebSocketServerFactory", ex); 130 } 131 } 132 } 133 } 134 } 135 136 @Override 137 public boolean isRunning() { 138 return this.running; 139 } 140 141 142 @Override 143 public Mono<Void> upgrade(ServerWebExchange exchange, WebSocketHandler handler, 144 @Nullable String subProtocol, Supplier<HandshakeInfo> handshakeInfoFactory) { 145 146 ServerHttpRequest request = exchange.getRequest(); 147 ServerHttpResponse response = exchange.getResponse(); 148 149 HttpServletRequest servletRequest = getNativeRequest(request); 150 HttpServletResponse servletResponse = getNativeResponse(response); 151 152 HandshakeInfo handshakeInfo = handshakeInfoFactory.get(); 153 DataBufferFactory factory = response.bufferFactory(); 154 155 JettyWebSocketHandlerAdapter adapter = new JettyWebSocketHandlerAdapter( 156 handler, session -> new JettyWebSocketSession(session, handshakeInfo, factory)); 157 158 startLazily(servletRequest); 159 160 Assert.state(this.factory != null, "No WebSocketServerFactory available"); 161 boolean isUpgrade = this.factory.isUpgradeRequest(servletRequest, servletResponse); 162 Assert.isTrue(isUpgrade, "Not a WebSocket handshake"); 163 164 // Trigger WebFlux preCommit actions and upgrade 165 return exchange.getResponse().setComplete() 166 .then(Mono.fromCallable(() -> { 167 try { 168 adapterHolder.set(new WebSocketHandlerContainer(adapter, subProtocol)); 169 this.factory.acceptWebSocket(servletRequest, servletResponse); 170 } 171 finally { 172 adapterHolder.remove(); 173 } 174 return null; 175 })); 176 } 177 178 private static HttpServletRequest getNativeRequest(ServerHttpRequest request) { 179 if (request instanceof AbstractServerHttpRequest) { 180 return ((AbstractServerHttpRequest) request).getNativeRequest(); 181 } 182 else if (request instanceof ServerHttpRequestDecorator) { 183 return getNativeRequest(((ServerHttpRequestDecorator) request).getDelegate()); 184 } 185 else { 186 throw new IllegalArgumentException( 187 "Couldn't find HttpServletRequest in " + request.getClass().getName()); 188 } 189 } 190 191 private static HttpServletResponse getNativeResponse(ServerHttpResponse response) { 192 if (response instanceof AbstractServerHttpResponse) { 193 return ((AbstractServerHttpResponse) response).getNativeResponse(); 194 } 195 else if (response instanceof ServerHttpResponseDecorator) { 196 return getNativeResponse(((ServerHttpResponseDecorator) response).getDelegate()); 197 } 198 else { 199 throw new IllegalArgumentException( 200 "Couldn't find HttpServletResponse in " + response.getClass().getName()); 201 } 202 } 203 204 private void startLazily(HttpServletRequest request) { 205 if (isRunning()) { 206 return; 207 } 208 synchronized (this.lifecycleMonitor) { 209 if (!isRunning()) { 210 this.servletContext = request.getServletContext(); 211 start(); 212 } 213 } 214 } 215 216 217 private static class WebSocketHandlerContainer { 218 219 private final JettyWebSocketHandlerAdapter adapter; 220 221 @Nullable 222 private final String protocol; 223 224 public WebSocketHandlerContainer(JettyWebSocketHandlerAdapter adapter, @Nullable String protocol) { 225 this.adapter = adapter; 226 this.protocol = protocol; 227 } 228 229 public JettyWebSocketHandlerAdapter getAdapter() { 230 return this.adapter; 231 } 232 233 @Nullable 234 public String getProtocol() { 235 return this.protocol; 236 } 237 } 238 239}