001/*
002 * Copyright 2002-2018 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.web.socket.server.standard;
018
019import java.io.IOException;
020import java.lang.reflect.Constructor;
021import java.lang.reflect.Method;
022import java.net.URI;
023import java.util.ArrayList;
024import java.util.Collections;
025import java.util.List;
026import java.util.Random;
027
028import javax.servlet.ServletException;
029import javax.servlet.http.HttpServletRequest;
030import javax.servlet.http.HttpServletResponse;
031import javax.websocket.DeploymentException;
032import javax.websocket.Endpoint;
033import javax.websocket.EndpointConfig;
034import javax.websocket.Extension;
035import javax.websocket.WebSocketContainer;
036
037import org.glassfish.tyrus.core.ComponentProviderService;
038import org.glassfish.tyrus.core.RequestContext;
039import org.glassfish.tyrus.core.TyrusEndpointWrapper;
040import org.glassfish.tyrus.core.TyrusUpgradeResponse;
041import org.glassfish.tyrus.core.TyrusWebSocketEngine;
042import org.glassfish.tyrus.core.Version;
043import org.glassfish.tyrus.server.TyrusServerContainer;
044import org.glassfish.tyrus.spi.WebSocketEngine.UpgradeInfo;
045
046import org.springframework.beans.DirectFieldAccessor;
047import org.springframework.http.HttpHeaders;
048import org.springframework.http.server.ServerHttpRequest;
049import org.springframework.http.server.ServerHttpResponse;
050import org.springframework.lang.Nullable;
051import org.springframework.util.ReflectionUtils;
052import org.springframework.util.StringUtils;
053import org.springframework.web.socket.WebSocketExtension;
054import org.springframework.web.socket.server.HandshakeFailureException;
055
056import static org.glassfish.tyrus.spi.WebSocketEngine.UpgradeStatus.SUCCESS;
057
058/**
059 * A base class for {@code RequestUpgradeStrategy} implementations on top of
060 * JSR-356 based servers which include Tyrus as their WebSocket engine.
061 *
062 * <p>Works with Tyrus 1.11 (WebLogic 12.2.1) and Tyrus 1.12 (GlassFish 4.1.1).
063 *
064 * @author Rossen Stoyanchev
065 * @author Brian Clozel
066 * @author Juergen Hoeller
067 * @since 4.1
068 * @see <a href="https://tyrus.java.net/">Project Tyrus</a>
069 */
070public abstract class AbstractTyrusRequestUpgradeStrategy extends AbstractStandardUpgradeStrategy {
071
072        private static final Random random = new Random();
073
074        private static final Constructor<?> constructor;
075
076        private static boolean constructorWithBooleanArgument;
077
078        private static final Method registerMethod;
079
080        private static final Method unRegisterMethod;
081
082        static {
083                try {
084                        constructor = getEndpointConstructor();
085                        int parameterCount = constructor.getParameterCount();
086                        constructorWithBooleanArgument = (parameterCount == 10);
087                        if (!constructorWithBooleanArgument && parameterCount != 9) {
088                                throw new IllegalStateException("Expected TyrusEndpointWrapper constructor with 9 or 10 arguments");
089                        }
090                        registerMethod = TyrusWebSocketEngine.class.getDeclaredMethod("register", TyrusEndpointWrapper.class);
091                        unRegisterMethod = TyrusWebSocketEngine.class.getDeclaredMethod("unregister", TyrusEndpointWrapper.class);
092                        ReflectionUtils.makeAccessible(registerMethod);
093                }
094                catch (Exception ex) {
095                        throw new IllegalStateException("No compatible Tyrus version found", ex);
096                }
097        }
098
099        private static Constructor<?> getEndpointConstructor() {
100                for (Constructor<?> current : TyrusEndpointWrapper.class.getConstructors()) {
101                        Class<?>[] types = current.getParameterTypes();
102                        if (Endpoint.class == types[0] && EndpointConfig.class == types[1]) {
103                                return current;
104                        }
105                }
106                throw new IllegalStateException("No compatible Tyrus version found");
107        }
108
109
110        private final ComponentProviderService componentProvider = ComponentProviderService.create();
111
112
113        @Override
114        public String[] getSupportedVersions() {
115                return StringUtils.tokenizeToStringArray(Version.getSupportedWireProtocolVersions(), ",");
116        }
117
118        @Override
119        protected List<WebSocketExtension> getInstalledExtensions(WebSocketContainer container) {
120                try {
121                        return super.getInstalledExtensions(container);
122                }
123                catch (UnsupportedOperationException ex) {
124                        return new ArrayList<>(0);
125                }
126        }
127
128        @Override
129        public void upgradeInternal(ServerHttpRequest request, ServerHttpResponse response,
130                        @Nullable String selectedProtocol, List<Extension> extensions, Endpoint endpoint)
131                        throws HandshakeFailureException {
132
133                HttpServletRequest servletRequest = getHttpServletRequest(request);
134                HttpServletResponse servletResponse = getHttpServletResponse(response);
135
136                TyrusServerContainer serverContainer = (TyrusServerContainer) getContainer(servletRequest);
137                TyrusWebSocketEngine engine = (TyrusWebSocketEngine) serverContainer.getWebSocketEngine();
138                Object tyrusEndpoint = null;
139                boolean success;
140
141                try {
142                        // Shouldn't matter for processing but must be unique
143                        String path = "/" + random.nextLong();
144                        tyrusEndpoint = createTyrusEndpoint(endpoint, path, selectedProtocol, extensions, serverContainer, engine);
145                        register(engine, tyrusEndpoint);
146
147                        HttpHeaders headers = request.getHeaders();
148                        RequestContext requestContext = createRequestContext(servletRequest, path, headers);
149                        TyrusUpgradeResponse upgradeResponse = new TyrusUpgradeResponse();
150                        UpgradeInfo upgradeInfo = engine.upgrade(requestContext, upgradeResponse);
151                        success = SUCCESS.equals(upgradeInfo.getStatus());
152                        if (success) {
153                                if (logger.isTraceEnabled()) {
154                                        logger.trace("Successful request upgrade: " + upgradeResponse.getHeaders());
155                                }
156                                handleSuccess(servletRequest, servletResponse, upgradeInfo, upgradeResponse);
157                        }
158                }
159                catch (Exception ex) {
160                        unregisterTyrusEndpoint(engine, tyrusEndpoint);
161                        throw new HandshakeFailureException("Error during handshake: " + request.getURI(), ex);
162                }
163
164                unregisterTyrusEndpoint(engine, tyrusEndpoint);
165                if (!success) {
166                        throw new HandshakeFailureException("Unexpected handshake failure: " + request.getURI());
167                }
168        }
169
170        private Object createTyrusEndpoint(Endpoint endpoint, String endpointPath, @Nullable String protocol,
171                        List<Extension> extensions, WebSocketContainer container, TyrusWebSocketEngine engine)
172                        throws DeploymentException {
173
174                ServerEndpointRegistration endpointConfig = new ServerEndpointRegistration(endpointPath, endpoint);
175                endpointConfig.setSubprotocols(Collections.singletonList(protocol));
176                endpointConfig.setExtensions(extensions);
177                return createEndpoint(endpointConfig, this.componentProvider, container, engine);
178        }
179
180        private RequestContext createRequestContext(HttpServletRequest request, String endpointPath, HttpHeaders headers) {
181                RequestContext context =
182                                RequestContext.Builder.create()
183                                                .requestURI(URI.create(endpointPath))
184                                                .userPrincipal(request.getUserPrincipal())
185                                                .secure(request.isSecure())
186                                                .remoteAddr(request.getRemoteAddr())
187                                                .build();
188                headers.forEach((header, value) -> context.getHeaders().put(header, value));
189                return context;
190        }
191
192        private void unregisterTyrusEndpoint(TyrusWebSocketEngine engine, @Nullable Object tyrusEndpoint) {
193                if (tyrusEndpoint != null) {
194                        try {
195                                unregister(engine, tyrusEndpoint);
196                        }
197                        catch (Throwable ex) {
198                                // ignore
199                        }
200                }
201        }
202
203        private Object createEndpoint(ServerEndpointRegistration registration, ComponentProviderService provider,
204                        WebSocketContainer container, TyrusWebSocketEngine engine) throws DeploymentException {
205
206                DirectFieldAccessor accessor = new DirectFieldAccessor(engine);
207                Object sessionListener = accessor.getPropertyValue("sessionListener");
208                Object clusterContext = accessor.getPropertyValue("clusterContext");
209                try {
210                        if (constructorWithBooleanArgument) {
211                                // Tyrus 1.11+
212                                return constructor.newInstance(registration.getEndpoint(), registration, provider, container,
213                                                "/", registration.getConfigurator(), sessionListener, clusterContext, null, Boolean.TRUE);
214                        }
215                        else {
216                                return constructor.newInstance(registration.getEndpoint(), registration, provider, container,
217                                                "/", registration.getConfigurator(), sessionListener, clusterContext, null);
218                        }
219                }
220                catch (Exception ex) {
221                        throw new HandshakeFailureException("Failed to register " + registration, ex);
222                }
223        }
224
225        private void register(TyrusWebSocketEngine engine, Object endpoint) {
226                try {
227                        registerMethod.invoke(engine, endpoint);
228                }
229                catch (Exception ex) {
230                        throw new HandshakeFailureException("Failed to register " + endpoint, ex);
231                }
232        }
233
234        private void unregister(TyrusWebSocketEngine engine, Object endpoint) {
235                try {
236                        unRegisterMethod.invoke(engine, endpoint);
237                }
238                catch (Exception ex) {
239                        throw new HandshakeFailureException("Failed to unregister " + endpoint, ex);
240                }
241        }
242
243
244        protected abstract void handleSuccess(HttpServletRequest request, HttpServletResponse response,
245                        UpgradeInfo upgradeInfo, TyrusUpgradeResponse upgradeResponse) throws IOException, ServletException;
246
247}