001/*
002 * Copyright 2002-2017 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.web.socket.server.jetty;
018
019import java.io.IOException;
020import java.security.Principal;
021import java.util.ArrayList;
022import java.util.List;
023import java.util.Map;
024import java.util.Set;
025import javax.servlet.ServletContext;
026import javax.servlet.http.HttpServletRequest;
027import javax.servlet.http.HttpServletResponse;
028
029import org.eclipse.jetty.websocket.api.WebSocketPolicy;
030import org.eclipse.jetty.websocket.api.extensions.ExtensionConfig;
031import org.eclipse.jetty.websocket.server.HandshakeRFC6455;
032import org.eclipse.jetty.websocket.server.WebSocketServerFactory;
033import org.eclipse.jetty.websocket.servlet.ServletUpgradeRequest;
034import org.eclipse.jetty.websocket.servlet.ServletUpgradeResponse;
035import org.eclipse.jetty.websocket.servlet.WebSocketCreator;
036
037import org.springframework.context.Lifecycle;
038import org.springframework.core.NamedThreadLocal;
039import org.springframework.http.server.ServerHttpRequest;
040import org.springframework.http.server.ServerHttpResponse;
041import org.springframework.http.server.ServletServerHttpRequest;
042import org.springframework.http.server.ServletServerHttpResponse;
043import org.springframework.util.Assert;
044import org.springframework.util.ClassUtils;
045import org.springframework.util.CollectionUtils;
046import org.springframework.web.context.ServletContextAware;
047import org.springframework.web.socket.WebSocketExtension;
048import org.springframework.web.socket.WebSocketHandler;
049import org.springframework.web.socket.adapter.jetty.JettyWebSocketHandlerAdapter;
050import org.springframework.web.socket.adapter.jetty.JettyWebSocketSession;
051import org.springframework.web.socket.adapter.jetty.WebSocketToJettyExtensionConfigAdapter;
052import org.springframework.web.socket.server.HandshakeFailureException;
053import org.springframework.web.socket.server.RequestUpgradeStrategy;
054
055/**
056 * A {@link RequestUpgradeStrategy} for use with Jetty 9.1-9.4. Based on
057 * Jetty's internal {@code org.eclipse.jetty.websocket.server.WebSocketHandler} class.
058 *
059 * @author Phillip Webb
060 * @author Rossen Stoyanchev
061 * @author Brian Clozel
062 * @author Juergen Hoeller
063 * @since 4.0
064 */
065public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, ServletContextAware, Lifecycle {
066
067        private static final ThreadLocal<WebSocketHandlerContainer> containerHolder =
068                        new NamedThreadLocal<WebSocketHandlerContainer>("WebSocketHandlerContainer");
069
070
071        // Configurable factory adapter due to Jetty 9.3.15+ API differences:
072        // using WebSocketServerFactory(ServletContext) as a version indicator
073        private final WebSocketServerFactoryAdapter factoryAdapter =
074                        (ClassUtils.hasConstructor(WebSocketServerFactory.class, ServletContext.class) ?
075                                        new ModernJettyWebSocketServerFactoryAdapter() : new LegacyJettyWebSocketServerFactoryAdapter());
076
077        private ServletContext servletContext;
078
079        private volatile boolean running = false;
080
081        private volatile List<WebSocketExtension> supportedExtensions;
082
083
084        /**
085         * Default constructor that creates {@link WebSocketServerFactory} through
086         * its default constructor thus using a default {@link WebSocketPolicy}.
087         */
088        public JettyRequestUpgradeStrategy() {
089                this.factoryAdapter.setPolicy(WebSocketPolicy.newServerPolicy());
090        }
091
092        /**
093         * A constructor accepting a {@link WebSocketPolicy} to be used when
094         * creating the {@link WebSocketServerFactory} instance.
095         * @param policy the policy to use
096         * @since 4.3.5
097         */
098        public JettyRequestUpgradeStrategy(WebSocketPolicy policy) {
099                Assert.notNull(policy, "WebSocketPolicy must not be null");
100                this.factoryAdapter.setPolicy(policy);
101        }
102
103        /**
104         * A constructor accepting a {@link WebSocketServerFactory}.
105         * @param factory the pre-configured factory to use
106         */
107        public JettyRequestUpgradeStrategy(WebSocketServerFactory factory) {
108                Assert.notNull(factory, "WebSocketServerFactory must not be null");
109                this.factoryAdapter.setFactory(factory);
110        }
111
112
113        @Override
114        public void setServletContext(ServletContext servletContext) {
115                this.servletContext = servletContext;
116        }
117
118        @Override
119        public void start() {
120                if (!isRunning()) {
121                        this.running = true;
122                        try {
123                                this.factoryAdapter.start();
124                        }
125                        catch (Throwable ex) {
126                                throw new IllegalStateException("Unable to start Jetty WebSocketServerFactory", ex);
127                        }
128                }
129        }
130
131        @Override
132        public void stop() {
133                if (isRunning()) {
134                        this.running = false;
135                        try {
136                                this.factoryAdapter.stop();
137                        }
138                        catch (Throwable ex) {
139                                throw new IllegalStateException("Unable to stop Jetty WebSocketServerFactory", ex);
140                        }
141                }
142        }
143
144        @Override
145        public boolean isRunning() {
146                return this.running;
147        }
148
149
150        @Override
151        public String[] getSupportedVersions() {
152                return new String[] { String.valueOf(HandshakeRFC6455.VERSION) };
153        }
154
155        @Override
156        public List<WebSocketExtension> getSupportedExtensions(ServerHttpRequest request) {
157                if (this.supportedExtensions == null) {
158                        this.supportedExtensions = buildWebSocketExtensions();
159                }
160                return this.supportedExtensions;
161        }
162
163        private List<WebSocketExtension> buildWebSocketExtensions() {
164                Set<String> names = this.factoryAdapter.getFactory().getExtensionFactory().getExtensionNames();
165                List<WebSocketExtension> result = new ArrayList<WebSocketExtension>(names.size());
166                for (String name : names) {
167                        result.add(new WebSocketExtension(name));
168                }
169                return result;
170        }
171
172        @Override
173        public void upgrade(ServerHttpRequest request, ServerHttpResponse response,
174                        String selectedProtocol, List<WebSocketExtension> selectedExtensions, Principal user,
175                        WebSocketHandler wsHandler, Map<String, Object> attributes) throws HandshakeFailureException {
176
177                Assert.isInstanceOf(ServletServerHttpRequest.class, request, "ServletServerHttpRequest required");
178                HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest();
179
180                Assert.isInstanceOf(ServletServerHttpResponse.class, response, "ServletServerHttpResponse required");
181                HttpServletResponse servletResponse = ((ServletServerHttpResponse) response).getServletResponse();
182
183                Assert.isTrue(this.factoryAdapter.getFactory().isUpgradeRequest(servletRequest, servletResponse),
184                                "Not a WebSocket handshake");
185
186                JettyWebSocketSession session = new JettyWebSocketSession(attributes, user);
187                JettyWebSocketHandlerAdapter handlerAdapter = new JettyWebSocketHandlerAdapter(wsHandler, session);
188
189                WebSocketHandlerContainer container =
190                                new WebSocketHandlerContainer(handlerAdapter, selectedProtocol, selectedExtensions);
191
192                try {
193                        containerHolder.set(container);
194                        this.factoryAdapter.getFactory().acceptWebSocket(servletRequest, servletResponse);
195                }
196                catch (IOException ex) {
197                        throw new HandshakeFailureException(
198                                        "Response update failed during upgrade to WebSocket: " + request.getURI(), ex);
199                }
200                finally {
201                        containerHolder.remove();
202                }
203        }
204
205
206        private static class WebSocketHandlerContainer {
207
208                private final JettyWebSocketHandlerAdapter handler;
209
210                private final String selectedProtocol;
211
212                private final List<ExtensionConfig> extensionConfigs;
213
214                public WebSocketHandlerContainer(
215                                JettyWebSocketHandlerAdapter handler, String protocol, List<WebSocketExtension> extensions) {
216
217                        this.handler = handler;
218                        this.selectedProtocol = protocol;
219                        if (CollectionUtils.isEmpty(extensions)) {
220                                this.extensionConfigs = new ArrayList<ExtensionConfig>(0);
221                        }
222                        else {
223                                this.extensionConfigs = new ArrayList<ExtensionConfig>(extensions.size());
224                                for (WebSocketExtension extension : extensions) {
225                                        this.extensionConfigs.add(new WebSocketToJettyExtensionConfigAdapter(extension));
226                                }
227                        }
228                }
229
230                public JettyWebSocketHandlerAdapter getHandler() {
231                        return this.handler;
232                }
233
234                public String getSelectedProtocol() {
235                        return this.selectedProtocol;
236                }
237
238                public List<ExtensionConfig> getExtensionConfigs() {
239                        return this.extensionConfigs;
240                }
241        }
242
243
244        private static abstract class WebSocketServerFactoryAdapter {
245
246                private WebSocketPolicy policy;
247
248                private WebSocketServerFactory factory;
249
250                public void setPolicy(WebSocketPolicy policy) {
251                        this.policy = policy;
252                }
253
254                public void setFactory(WebSocketServerFactory factory) {
255                        this.factory = factory;
256                }
257
258                public WebSocketServerFactory getFactory() {
259                        return this.factory;
260                }
261
262                public void start() throws Exception {
263                        if (this.factory == null) {
264                                this.factory = createFactory(this.policy);
265                        }
266                        this.factory.setCreator(new WebSocketCreator() {
267                                @Override
268                                public Object createWebSocket(ServletUpgradeRequest request, ServletUpgradeResponse response) {
269                                        WebSocketHandlerContainer container = containerHolder.get();
270                                        Assert.state(container != null, "Expected WebSocketHandlerContainer");
271                                        response.setAcceptedSubProtocol(container.getSelectedProtocol());
272                                        response.setExtensions(container.getExtensionConfigs());
273                                        return container.getHandler();
274                                }
275                        });
276                        startFactory(this.factory);
277                }
278
279                public void stop() throws Exception {
280                        if (this.factory != null) {
281                                stopFactory(this.factory);
282                        }
283                }
284
285                protected abstract WebSocketServerFactory createFactory(WebSocketPolicy policy) throws Exception;
286
287                protected abstract void startFactory(WebSocketServerFactory factory) throws Exception;
288
289                protected abstract void stopFactory(WebSocketServerFactory factory) throws Exception;
290        }
291
292
293        // Jetty 9.3.15+
294        private class ModernJettyWebSocketServerFactoryAdapter extends WebSocketServerFactoryAdapter {
295
296                @Override
297                protected WebSocketServerFactory createFactory(WebSocketPolicy policy) throws Exception {
298                        return new WebSocketServerFactory(servletContext, policy);
299                }
300
301                @Override
302                protected void startFactory(WebSocketServerFactory factory) throws Exception {
303                        factory.start();
304                }
305
306                @Override
307                protected void stopFactory(WebSocketServerFactory factory) throws Exception {
308                        factory.stop();
309                }
310        }
311
312
313        // Jetty <9.3.15
314        private class LegacyJettyWebSocketServerFactoryAdapter extends WebSocketServerFactoryAdapter {
315
316                @Override
317                protected WebSocketServerFactory createFactory(WebSocketPolicy policy) throws Exception {
318                        return WebSocketServerFactory.class.getConstructor(WebSocketPolicy.class).newInstance(policy);
319                }
320
321                @Override
322                protected void startFactory(WebSocketServerFactory factory) throws Exception {
323                        try {
324                                WebSocketServerFactory.class.getMethod("init", ServletContext.class).invoke(factory, servletContext);
325                        }
326                        catch (NoSuchMethodException ex) {
327                                // Jetty 9.1/9.2
328                                WebSocketServerFactory.class.getMethod("init").invoke(factory);
329                        }
330                }
331
332                @Override
333                protected void stopFactory(WebSocketServerFactory factory) throws Exception {
334                        WebSocketServerFactory.class.getMethod("cleanup").invoke(factory);
335                }
336        }
337
338}