001/*
002 * Copyright 2002-2020 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.web.socket.server.jetty;
018
019import java.io.IOException;
020import java.lang.reflect.Method;
021import java.security.Principal;
022import java.util.ArrayList;
023import java.util.Collections;
024import java.util.List;
025import java.util.Map;
026import java.util.Set;
027
028import javax.servlet.ServletContext;
029import javax.servlet.http.HttpServletRequest;
030import javax.servlet.http.HttpServletResponse;
031
032import org.eclipse.jetty.websocket.api.WebSocketPolicy;
033import org.eclipse.jetty.websocket.api.extensions.ExtensionConfig;
034import org.eclipse.jetty.websocket.server.HandshakeRFC6455;
035import org.eclipse.jetty.websocket.server.WebSocketServerFactory;
036
037import org.springframework.context.Lifecycle;
038import org.springframework.core.NamedThreadLocal;
039import org.springframework.http.server.ServerHttpRequest;
040import org.springframework.http.server.ServerHttpResponse;
041import org.springframework.http.server.ServletServerHttpRequest;
042import org.springframework.http.server.ServletServerHttpResponse;
043import org.springframework.lang.Nullable;
044import org.springframework.util.Assert;
045import org.springframework.util.ClassUtils;
046import org.springframework.util.CollectionUtils;
047import org.springframework.util.ReflectionUtils;
048import org.springframework.web.context.ServletContextAware;
049import org.springframework.web.socket.WebSocketExtension;
050import org.springframework.web.socket.WebSocketHandler;
051import org.springframework.web.socket.adapter.jetty.JettyWebSocketHandlerAdapter;
052import org.springframework.web.socket.adapter.jetty.JettyWebSocketSession;
053import org.springframework.web.socket.adapter.jetty.WebSocketToJettyExtensionConfigAdapter;
054import org.springframework.web.socket.server.HandshakeFailureException;
055import org.springframework.web.socket.server.RequestUpgradeStrategy;
056
057/**
058 * A {@link RequestUpgradeStrategy} for use with Jetty 9.4. Based on Jetty's
059 * internal {@code org.eclipse.jetty.websocket.server.WebSocketHandler} class.
060 *
061 * @author Phillip Webb
062 * @author Rossen Stoyanchev
063 * @author Brian Clozel
064 * @author Juergen Hoeller
065 * @since 4.0
066 */
067public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, ServletContextAware, Lifecycle {
068
069        private static final ThreadLocal<WebSocketHandlerContainer> containerHolder =
070                        new NamedThreadLocal<>("WebSocketHandlerContainer");
071
072        @Nullable
073        private WebSocketPolicy policy;
074
075        @Nullable
076        private volatile WebSocketServerFactory factory;
077
078        @Nullable
079        private ServletContext servletContext;
080
081        private volatile boolean running = false;
082
083        @Nullable
084        private volatile List<WebSocketExtension> supportedExtensions;
085
086
087        /**
088         * Default constructor that creates {@link WebSocketServerFactory} through
089         * its default constructor thus using a default {@link WebSocketPolicy}.
090         */
091        public JettyRequestUpgradeStrategy() {
092                this.policy = WebSocketPolicy.newServerPolicy();
093        }
094
095        /**
096         * A constructor accepting a {@link WebSocketPolicy} to be used when
097         * creating the {@link WebSocketServerFactory} instance.
098         * @param policy the policy to use
099         * @since 4.3.5
100         */
101        public JettyRequestUpgradeStrategy(WebSocketPolicy policy) {
102                Assert.notNull(policy, "WebSocketPolicy must not be null");
103                this.policy = policy;
104        }
105
106        /**
107         * A constructor accepting a {@link WebSocketServerFactory}.
108         * @param factory the pre-configured factory to use
109         */
110        public JettyRequestUpgradeStrategy(WebSocketServerFactory factory) {
111                Assert.notNull(factory, "WebSocketServerFactory must not be null");
112                this.factory = factory;
113        }
114
115
116        @Override
117        public void setServletContext(ServletContext servletContext) {
118                this.servletContext = servletContext;
119        }
120
121        @Override
122        public void start() {
123                if (!isRunning()) {
124                        this.running = true;
125                        try {
126                                WebSocketServerFactory factory = this.factory;
127                                if (factory == null) {
128                                        Assert.state(this.servletContext != null, "No ServletContext set");
129                                        factory = new WebSocketServerFactory(this.servletContext, this.policy);
130                                        this.factory = factory;
131                                }
132                                factory.setCreator((request, response) -> {
133                                        WebSocketHandlerContainer container = containerHolder.get();
134                                        Assert.state(container != null, "Expected WebSocketHandlerContainer");
135                                        response.setAcceptedSubProtocol(container.getSelectedProtocol());
136                                        response.setExtensions(container.getExtensionConfigs());
137                                        return container.getHandler();
138                                });
139                                factory.start();
140                        }
141                        catch (Throwable ex) {
142                                throw new IllegalStateException("Unable to start Jetty WebSocketServerFactory", ex);
143                        }
144                }
145        }
146
147        @Override
148        public void stop() {
149                if (isRunning()) {
150                        this.running = false;
151                        WebSocketServerFactory factory = this.factory;
152                        if (factory != null) {
153                                try {
154                                        factory.stop();
155                                }
156                                catch (Throwable ex) {
157                                        throw new IllegalStateException("Unable to stop Jetty WebSocketServerFactory", ex);
158                                }
159                        }
160                }
161        }
162
163        @Override
164        public boolean isRunning() {
165                return this.running;
166        }
167
168
169        @Override
170        public String[] getSupportedVersions() {
171                return new String[] { String.valueOf(HandshakeRFC6455.VERSION) };
172        }
173
174        @Override
175        public List<WebSocketExtension> getSupportedExtensions(ServerHttpRequest request) {
176                List<WebSocketExtension> extensions = this.supportedExtensions;
177                if (extensions == null) {
178                        extensions = buildWebSocketExtensions();
179                        this.supportedExtensions = extensions;
180                }
181                return extensions;
182        }
183
184        private List<WebSocketExtension> buildWebSocketExtensions() {
185                Set<String> names = getExtensionNames();
186                List<WebSocketExtension> result = new ArrayList<>(names.size());
187                for (String name : names) {
188                        result.add(new WebSocketExtension(name));
189                }
190                return result;
191        }
192
193        @SuppressWarnings({"unchecked", "deprecation"})
194        private Set<String> getExtensionNames() {
195                WebSocketServerFactory factory = this.factory;
196                Assert.state(factory != null, "No WebSocketServerFactory available");
197                try {
198                        return factory.getAvailableExtensionNames();
199                }
200                catch (IncompatibleClassChangeError ex) {
201                        // Fallback for versions prior to 9.4.21:
202                        // 9.4.20.v20190813: ExtensionFactory (abstract class -> interface)
203                        // 9.4.21.v20190926: ExtensionFactory (interface -> abstract class) + deprecated
204                        Class<?> clazz = org.eclipse.jetty.websocket.api.extensions.ExtensionFactory.class;
205                        Method method = ClassUtils.getMethod(clazz, "getExtensionNames");
206                        Set<String> result = (Set<String>) ReflectionUtils.invokeMethod(method, factory.getExtensionFactory());
207                        return (result != null ? result : Collections.emptySet());
208                }
209        }
210
211        @Override
212        public void upgrade(ServerHttpRequest request, ServerHttpResponse response,
213                        @Nullable String selectedProtocol, List<WebSocketExtension> selectedExtensions, @Nullable Principal user,
214                        WebSocketHandler wsHandler, Map<String, Object> attributes) throws HandshakeFailureException {
215
216                Assert.isInstanceOf(ServletServerHttpRequest.class, request, "ServletServerHttpRequest required");
217                HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest();
218
219                Assert.isInstanceOf(ServletServerHttpResponse.class, response, "ServletServerHttpResponse required");
220                HttpServletResponse servletResponse = ((ServletServerHttpResponse) response).getServletResponse();
221
222                WebSocketServerFactory factory = this.factory;
223                Assert.state(factory != null, "No WebSocketServerFactory available");
224                Assert.isTrue(factory.isUpgradeRequest(servletRequest, servletResponse), "Not a WebSocket handshake");
225
226                JettyWebSocketSession session = new JettyWebSocketSession(attributes, user);
227                JettyWebSocketHandlerAdapter handlerAdapter = new JettyWebSocketHandlerAdapter(wsHandler, session);
228
229                WebSocketHandlerContainer container =
230                                new WebSocketHandlerContainer(handlerAdapter, selectedProtocol, selectedExtensions);
231
232                try {
233                        containerHolder.set(container);
234                        factory.acceptWebSocket(servletRequest, servletResponse);
235                }
236                catch (IOException ex) {
237                        throw new HandshakeFailureException(
238                                        "Response update failed during upgrade to WebSocket: " + request.getURI(), ex);
239                }
240                finally {
241                        containerHolder.remove();
242                }
243        }
244
245
246        private static class WebSocketHandlerContainer {
247
248                private final JettyWebSocketHandlerAdapter handler;
249
250                @Nullable
251                private final String selectedProtocol;
252
253                private final List<ExtensionConfig> extensionConfigs;
254
255                public WebSocketHandlerContainer(JettyWebSocketHandlerAdapter handler,
256                                @Nullable String protocol, List<WebSocketExtension> extensions) {
257
258                        this.handler = handler;
259                        this.selectedProtocol = protocol;
260                        if (CollectionUtils.isEmpty(extensions)) {
261                                this.extensionConfigs = new ArrayList<>(0);
262                        }
263                        else {
264                                this.extensionConfigs = new ArrayList<>(extensions.size());
265                                for (WebSocketExtension extension : extensions) {
266                                        this.extensionConfigs.add(new WebSocketToJettyExtensionConfigAdapter(extension));
267                                }
268                        }
269                }
270
271                public JettyWebSocketHandlerAdapter getHandler() {
272                        return this.handler;
273                }
274
275                @Nullable
276                public String getSelectedProtocol() {
277                        return this.selectedProtocol;
278                }
279
280                public List<ExtensionConfig> getExtensionConfigs() {
281                        return this.extensionConfigs;
282                }
283        }
284
285}