001/*
002 * Copyright 2002-2020 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.web.reactive.socket.server.upgrade;
018
019import java.util.function.Supplier;
020
021import javax.servlet.ServletContext;
022import javax.servlet.http.HttpServletRequest;
023import javax.servlet.http.HttpServletResponse;
024
025import org.eclipse.jetty.websocket.api.WebSocketPolicy;
026import org.eclipse.jetty.websocket.server.WebSocketServerFactory;
027import reactor.core.publisher.Mono;
028
029import org.springframework.context.Lifecycle;
030import org.springframework.core.NamedThreadLocal;
031import org.springframework.core.io.buffer.DataBufferFactory;
032import org.springframework.http.server.reactive.AbstractServerHttpRequest;
033import org.springframework.http.server.reactive.AbstractServerHttpResponse;
034import org.springframework.http.server.reactive.ServerHttpRequest;
035import org.springframework.http.server.reactive.ServerHttpRequestDecorator;
036import org.springframework.http.server.reactive.ServerHttpResponse;
037import org.springframework.http.server.reactive.ServerHttpResponseDecorator;
038import org.springframework.lang.Nullable;
039import org.springframework.util.Assert;
040import org.springframework.web.reactive.socket.HandshakeInfo;
041import org.springframework.web.reactive.socket.WebSocketHandler;
042import org.springframework.web.reactive.socket.adapter.JettyWebSocketHandlerAdapter;
043import org.springframework.web.reactive.socket.adapter.JettyWebSocketSession;
044import org.springframework.web.reactive.socket.server.RequestUpgradeStrategy;
045import org.springframework.web.server.ServerWebExchange;
046
047/**
048 * A {@link RequestUpgradeStrategy} for use with Jetty.
049 *
050 * @author Violeta Georgieva
051 * @author Rossen Stoyanchev
052 * @since 5.0
053 */
054public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Lifecycle {
055
056        private static final ThreadLocal<WebSocketHandlerContainer> adapterHolder =
057                        new NamedThreadLocal<>("JettyWebSocketHandlerAdapter");
058
059
060        @Nullable
061        private WebSocketPolicy webSocketPolicy;
062
063        @Nullable
064        private WebSocketServerFactory factory;
065
066        @Nullable
067        private volatile ServletContext servletContext;
068
069        private volatile boolean running = false;
070
071        private final Object lifecycleMonitor = new Object();
072
073
074        /**
075         * Configure a {@link WebSocketPolicy} to use to initialize
076         * {@link WebSocketServerFactory}.
077         * @param webSocketPolicy the WebSocket settings
078         */
079        public void setWebSocketPolicy(WebSocketPolicy webSocketPolicy) {
080                this.webSocketPolicy = webSocketPolicy;
081        }
082
083        /**
084         * Return the configured {@link WebSocketPolicy}, if any.
085         */
086        @Nullable
087        public WebSocketPolicy getWebSocketPolicy() {
088                return this.webSocketPolicy;
089        }
090
091
092        @Override
093        public void start() {
094                synchronized (this.lifecycleMonitor) {
095                        ServletContext servletContext = this.servletContext;
096                        if (!isRunning() && servletContext != null) {
097                                try {
098                                        this.factory = (this.webSocketPolicy != null ?
099                                                        new WebSocketServerFactory(servletContext, this.webSocketPolicy) :
100                                                        new WebSocketServerFactory(servletContext));
101                                        this.factory.setCreator((request, response) -> {
102                                                WebSocketHandlerContainer container = adapterHolder.get();
103                                                String protocol = container.getProtocol();
104                                                if (protocol != null) {
105                                                        response.setAcceptedSubProtocol(protocol);
106                                                }
107                                                return container.getAdapter();
108                                        });
109                                        this.factory.start();
110                                        this.running = true;
111                                }
112                                catch (Throwable ex) {
113                                        throw new IllegalStateException("Unable to start WebSocketServerFactory", ex);
114                                }
115                        }
116                }
117        }
118
119        @Override
120        public void stop() {
121                synchronized (this.lifecycleMonitor) {
122                        if (isRunning()) {
123                                if (this.factory != null) {
124                                        try {
125                                                this.factory.stop();
126                                                this.running = false;
127                                        }
128                                        catch (Throwable ex) {
129                                                throw new IllegalStateException("Failed to stop WebSocketServerFactory", ex);
130                                        }
131                                }
132                        }
133                }
134        }
135
136        @Override
137        public boolean isRunning() {
138                return this.running;
139        }
140
141
142        @Override
143        public Mono<Void> upgrade(ServerWebExchange exchange, WebSocketHandler handler,
144                        @Nullable String subProtocol, Supplier<HandshakeInfo> handshakeInfoFactory) {
145
146                ServerHttpRequest request = exchange.getRequest();
147                ServerHttpResponse response = exchange.getResponse();
148
149                HttpServletRequest servletRequest = getNativeRequest(request);
150                HttpServletResponse servletResponse = getNativeResponse(response);
151
152                HandshakeInfo handshakeInfo = handshakeInfoFactory.get();
153                DataBufferFactory factory = response.bufferFactory();
154
155                JettyWebSocketHandlerAdapter adapter = new JettyWebSocketHandlerAdapter(
156                                handler, session -> new JettyWebSocketSession(session, handshakeInfo, factory));
157
158                startLazily(servletRequest);
159
160                Assert.state(this.factory != null, "No WebSocketServerFactory available");
161                boolean isUpgrade = this.factory.isUpgradeRequest(servletRequest, servletResponse);
162                Assert.isTrue(isUpgrade, "Not a WebSocket handshake");
163
164                // Trigger WebFlux preCommit actions and upgrade
165                return exchange.getResponse().setComplete()
166                                .then(Mono.fromCallable(() -> {
167                                        try {
168                                                adapterHolder.set(new WebSocketHandlerContainer(adapter, subProtocol));
169                                                this.factory.acceptWebSocket(servletRequest, servletResponse);
170                                        }
171                                        finally {
172                                                adapterHolder.remove();
173                                        }
174                                        return null;
175                                }));
176        }
177
178        private static HttpServletRequest getNativeRequest(ServerHttpRequest request) {
179                if (request instanceof AbstractServerHttpRequest) {
180                        return ((AbstractServerHttpRequest) request).getNativeRequest();
181                }
182                else if (request instanceof ServerHttpRequestDecorator) {
183                        return getNativeRequest(((ServerHttpRequestDecorator) request).getDelegate());
184                }
185                else {
186                        throw new IllegalArgumentException(
187                                        "Couldn't find HttpServletRequest in " + request.getClass().getName());
188                }
189        }
190
191        private static HttpServletResponse getNativeResponse(ServerHttpResponse response) {
192                if (response instanceof AbstractServerHttpResponse) {
193                        return ((AbstractServerHttpResponse) response).getNativeResponse();
194                }
195                else if (response instanceof ServerHttpResponseDecorator) {
196                        return getNativeResponse(((ServerHttpResponseDecorator) response).getDelegate());
197                }
198                else {
199                        throw new IllegalArgumentException(
200                                        "Couldn't find HttpServletResponse in " + response.getClass().getName());
201                }
202        }
203
204        private void startLazily(HttpServletRequest request) {
205                if (isRunning()) {
206                        return;
207                }
208                synchronized (this.lifecycleMonitor) {
209                        if (!isRunning()) {
210                                this.servletContext = request.getServletContext();
211                                start();
212                        }
213                }
214        }
215
216
217        private static class WebSocketHandlerContainer {
218
219                private final JettyWebSocketHandlerAdapter adapter;
220
221                @Nullable
222                private final String protocol;
223
224                public WebSocketHandlerContainer(JettyWebSocketHandlerAdapter adapter, @Nullable String protocol) {
225                        this.adapter = adapter;
226                        this.protocol = protocol;
227                }
228
229                public JettyWebSocketHandlerAdapter getAdapter() {
230                        return this.adapter;
231                }
232
233                @Nullable
234                public String getProtocol() {
235                        return this.protocol;
236                }
237        }
238
239}