001/*
002 * Copyright 2002-2017 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.web.socket.server.standard;
018
import java.io.IOException;
import java.lang.reflect.Constructor;
import java.lang.reflect.Method;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Random;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.websocket.DeploymentException;
import javax.websocket.Endpoint;
import javax.websocket.EndpointConfig;
import javax.websocket.Extension;
import javax.websocket.WebSocketContainer;

import org.glassfish.tyrus.core.ComponentProviderService;
import org.glassfish.tyrus.core.RequestContext;
import org.glassfish.tyrus.core.TyrusEndpoint;
import org.glassfish.tyrus.core.TyrusEndpointWrapper;
import org.glassfish.tyrus.core.TyrusUpgradeResponse;
import org.glassfish.tyrus.core.TyrusWebSocketEngine;
import org.glassfish.tyrus.core.Version;
import org.glassfish.tyrus.core.WebSocketApplication;
import org.glassfish.tyrus.server.TyrusServerContainer;
import org.glassfish.tyrus.spi.WebSocketEngine.UpgradeInfo;

import org.springframework.beans.DirectFieldAccessor;
import org.springframework.http.HttpHeaders;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.util.ReflectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.socket.WebSocketExtension;
import org.springframework.web.socket.server.HandshakeFailureException;

import static org.glassfish.tyrus.spi.WebSocketEngine.UpgradeStatus.*;
057
058/**
059 * A base class for {@code RequestUpgradeStrategy} implementations on top of
060 * JSR-356 based servers which include Tyrus as their WebSocket engine.
061 *
062 * <p>Works with Tyrus 1.3.5 (WebLogic 12.1.3), Tyrus 1.7 (GlassFish 4.1.0),
063 * Tyrus 1.11 (WebLogic 12.2.1), and Tyrus 1.12 (GlassFish 4.1.1).
064 *
065 * @author Rossen Stoyanchev
066 * @author Brian Clozel
067 * @since 4.1
068 * @see <a href="https://tyrus.java.net/">Project Tyrus</a>
069 */
070public abstract class AbstractTyrusRequestUpgradeStrategy extends AbstractStandardUpgradeStrategy {
071
072        private static final Random random = new Random();
073
074        private final ComponentProviderService componentProvider = ComponentProviderService.create();
075
076
077        @Override
078        public String[] getSupportedVersions() {
079                return StringUtils.tokenizeToStringArray(Version.getSupportedWireProtocolVersions(), ",");
080        }
081
082        @Override
083        protected List<WebSocketExtension> getInstalledExtensions(WebSocketContainer container) {
084                try {
085                        return super.getInstalledExtensions(container);
086                }
087                catch (UnsupportedOperationException ex) {
088                        return new ArrayList<WebSocketExtension>(0);
089                }
090        }
091
092        @Override
093        public void upgradeInternal(ServerHttpRequest request, ServerHttpResponse response,
094                        String selectedProtocol, List<Extension> extensions, Endpoint endpoint)
095                        throws HandshakeFailureException {
096
097                HttpServletRequest servletRequest = getHttpServletRequest(request);
098                HttpServletResponse servletResponse = getHttpServletResponse(response);
099
100                TyrusServerContainer serverContainer = (TyrusServerContainer) getContainer(servletRequest);
101                TyrusWebSocketEngine engine = (TyrusWebSocketEngine) serverContainer.getWebSocketEngine();
102                Object tyrusEndpoint = null;
103                boolean success;
104
105                try {
106                        // Shouldn't matter for processing but must be unique
107                        String path = "/" + random.nextLong();
108                        tyrusEndpoint = createTyrusEndpoint(endpoint, path, selectedProtocol, extensions, serverContainer, engine);
109                        getEndpointHelper().register(engine, tyrusEndpoint);
110
111                        HttpHeaders headers = request.getHeaders();
112                        RequestContext requestContext = createRequestContext(servletRequest, path, headers);
113                        TyrusUpgradeResponse upgradeResponse = new TyrusUpgradeResponse();
114                        UpgradeInfo upgradeInfo = engine.upgrade(requestContext, upgradeResponse);
115                        success = SUCCESS.equals(upgradeInfo.getStatus());
116                        if (success) {
117                                if (logger.isTraceEnabled()) {
118                                        logger.trace("Successful request upgrade: " + upgradeResponse.getHeaders());
119                                }
120                                handleSuccess(servletRequest, servletResponse, upgradeInfo, upgradeResponse);
121                        }
122                }
123                catch (Exception ex) {
124                        unregisterTyrusEndpoint(engine, tyrusEndpoint);
125                        throw new HandshakeFailureException("Error during handshake: " + request.getURI(), ex);
126                }
127
128                unregisterTyrusEndpoint(engine, tyrusEndpoint);
129                if (!success) {
130                        throw new HandshakeFailureException("Unexpected handshake failure: " + request.getURI());
131                }
132        }
133
134        private Object createTyrusEndpoint(Endpoint endpoint, String endpointPath, String protocol,
135                        List<Extension> extensions, WebSocketContainer container, TyrusWebSocketEngine engine)
136                        throws DeploymentException {
137
138                ServerEndpointRegistration endpointConfig = new ServerEndpointRegistration(endpointPath, endpoint);
139                endpointConfig.setSubprotocols(Collections.singletonList(protocol));
140                endpointConfig.setExtensions(extensions);
141                return getEndpointHelper().createdEndpoint(endpointConfig, this.componentProvider, container, engine);
142        }
143
144        private RequestContext createRequestContext(HttpServletRequest request, String endpointPath, HttpHeaders headers) {
145                RequestContext context =
146                                RequestContext.Builder.create()
147                                                .requestURI(URI.create(endpointPath))
148                                                .userPrincipal(request.getUserPrincipal())
149                                                .secure(request.isSecure())
150                                        //      .remoteAddr(request.getRemoteAddr())  # Not available in 1.3.5
151                                                .build();
152                for (String header : headers.keySet()) {
153                        context.getHeaders().put(header, headers.get(header));
154                }
155                return context;
156        }
157
158        private void unregisterTyrusEndpoint(TyrusWebSocketEngine engine, Object tyrusEndpoint) {
159                if (tyrusEndpoint != null) {
160                        try {
161                                getEndpointHelper().unregister(engine, tyrusEndpoint);
162                        }
163                        catch (Throwable ex) {
164                                // ignore
165                        }
166                }
167        }
168
169        protected abstract TyrusEndpointHelper getEndpointHelper();
170
171        protected abstract void handleSuccess(HttpServletRequest request, HttpServletResponse response,
172                        UpgradeInfo upgradeInfo, TyrusUpgradeResponse upgradeResponse) throws IOException, ServletException;
173
174
175        /**
176         * Helps with the creation, registration, and un-registration of endpoints.
177         */
178        protected interface TyrusEndpointHelper {
179
180                Object createdEndpoint(ServerEndpointRegistration registration, ComponentProviderService provider,
181                                WebSocketContainer container, TyrusWebSocketEngine engine) throws DeploymentException;
182
183                void register(TyrusWebSocketEngine engine, Object endpoint);
184
185                void unregister(TyrusWebSocketEngine engine, Object endpoint);
186        }
187
188
189        protected static class Tyrus17EndpointHelper implements TyrusEndpointHelper {
190
191                private static final Constructor<?> constructor;
192
193                private static boolean constructorWithBooleanArgument;
194
195                private static final Method registerMethod;
196
197                private static final Method unregisterMethod;
198
199                static {
200                        try {
201                                constructor = getEndpointConstructor();
202                                int parameterCount = constructor.getParameterTypes().length;
203                                constructorWithBooleanArgument = (parameterCount == 10);
204                                if (!constructorWithBooleanArgument && parameterCount != 9) {
205                                        throw new IllegalStateException("Expected TyrusEndpointWrapper constructor with 9 or 10 arguments");
206                                }
207                                registerMethod = TyrusWebSocketEngine.class.getDeclaredMethod("register", TyrusEndpointWrapper.class);
208                                unregisterMethod = TyrusWebSocketEngine.class.getDeclaredMethod("unregister", TyrusEndpointWrapper.class);
209                                ReflectionUtils.makeAccessible(registerMethod);
210                        }
211                        catch (Exception ex) {
212                                throw new IllegalStateException("No compatible Tyrus version found", ex);
213                        }
214                }
215
216                private static Constructor<?> getEndpointConstructor() {
217                        for (Constructor<?> current : TyrusEndpointWrapper.class.getConstructors()) {
218                                Class<?>[] types = current.getParameterTypes();
219                                if (Endpoint.class == types[0] && EndpointConfig.class == types[1]) {
220                                        return current;
221                                }
222                        }
223                        throw new IllegalStateException("No compatible Tyrus version found");
224                }
225
226
227                @Override
228                public Object createdEndpoint(ServerEndpointRegistration registration, ComponentProviderService provider,
229                                WebSocketContainer container, TyrusWebSocketEngine engine) throws DeploymentException {
230
231                        DirectFieldAccessor accessor = new DirectFieldAccessor(engine);
232                        Object sessionListener = accessor.getPropertyValue("sessionListener");
233                        Object clusterContext = accessor.getPropertyValue("clusterContext");
234                        try {
235                                if (constructorWithBooleanArgument) {
236                                        // Tyrus 1.11+
237                                        return constructor.newInstance(registration.getEndpoint(), registration, provider, container,
238                                                        "/", registration.getConfigurator(), sessionListener, clusterContext, null, Boolean.TRUE);
239                                }
240                                else {
241                                        return constructor.newInstance(registration.getEndpoint(), registration, provider, container,
242                                                        "/", registration.getConfigurator(), sessionListener, clusterContext, null);
243                                }
244                        }
245                        catch (Exception ex) {
246                                throw new HandshakeFailureException("Failed to register " + registration, ex);
247                        }
248                }
249
250                @Override
251                public void register(TyrusWebSocketEngine engine, Object endpoint) {
252                        try {
253                                registerMethod.invoke(engine, endpoint);
254                        }
255                        catch (Exception ex) {
256                                throw new HandshakeFailureException("Failed to register " + endpoint, ex);
257                        }
258                }
259
260                @Override
261                public void unregister(TyrusWebSocketEngine engine, Object endpoint) {
262                        try {
263                                unregisterMethod.invoke(engine, endpoint);
264                        }
265                        catch (Exception ex) {
266                                throw new HandshakeFailureException("Failed to unregister " + endpoint, ex);
267                        }
268                }
269        }
270
271
272        protected static class Tyrus135EndpointHelper implements TyrusEndpointHelper {
273
274                private static final Method registerMethod;
275
276                static {
277                        try {
278                                registerMethod = TyrusWebSocketEngine.class.getDeclaredMethod("register", WebSocketApplication.class);
279                                ReflectionUtils.makeAccessible(registerMethod);
280                        }
281                        catch (Exception ex) {
282                                throw new IllegalStateException("No compatible Tyrus version found", ex);
283                        }
284                }
285
286                @Override
287                public Object createdEndpoint(ServerEndpointRegistration registration, ComponentProviderService provider,
288                                WebSocketContainer container, TyrusWebSocketEngine engine) throws DeploymentException {
289
290                        TyrusEndpointWrapper endpointWrapper = new TyrusEndpointWrapper(registration.getEndpoint(),
291                                        registration, provider, container, "/", registration.getConfigurator());
292
293                        return new TyrusEndpoint(endpointWrapper);
294                }
295
296                @Override
297                public void register(TyrusWebSocketEngine engine, Object endpoint) {
298                        try {
299                                registerMethod.invoke(engine, endpoint);
300                        }
301                        catch (Exception ex) {
302                                throw new HandshakeFailureException("Failed to register " + endpoint, ex);
303                        }
304                }
305
306                @Override
307                public void unregister(TyrusWebSocketEngine engine, Object endpoint) {
308                        engine.unregister((TyrusEndpoint) endpoint);
309                }
310        }
311
312}