001/* 002 * Copyright 2002-2017 the original author or authors. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * https://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.springframework.web.socket.server.standard; 018 019import java.io.IOException; 020import java.lang.reflect.Constructor; 021import java.lang.reflect.Method; 022import java.net.URI; 023import java.util.ArrayList; 024import java.util.Collections; 025import java.util.List; 026import java.util.Random; 027import javax.servlet.ServletException; 028import javax.servlet.http.HttpServletRequest; 029import javax.servlet.http.HttpServletResponse; 030import javax.websocket.DeploymentException; 031import javax.websocket.Endpoint; 032import javax.websocket.EndpointConfig; 033import javax.websocket.Extension; 034import javax.websocket.WebSocketContainer; 035 036import org.glassfish.tyrus.core.ComponentProviderService; 037import org.glassfish.tyrus.core.RequestContext; 038import org.glassfish.tyrus.core.TyrusEndpoint; 039import org.glassfish.tyrus.core.TyrusEndpointWrapper; 040import org.glassfish.tyrus.core.TyrusUpgradeResponse; 041import org.glassfish.tyrus.core.TyrusWebSocketEngine; 042import org.glassfish.tyrus.core.Version; 043import org.glassfish.tyrus.core.WebSocketApplication; 044import org.glassfish.tyrus.server.TyrusServerContainer; 045import org.glassfish.tyrus.spi.WebSocketEngine.UpgradeInfo; 046 047import org.springframework.beans.DirectFieldAccessor; 048import org.springframework.http.HttpHeaders; 049import org.springframework.http.server.ServerHttpRequest; 050import org.springframework.http.server.ServerHttpResponse; 051import org.springframework.util.ReflectionUtils; 052import org.springframework.util.StringUtils; 053import org.springframework.web.socket.WebSocketExtension; 054import org.springframework.web.socket.server.HandshakeFailureException; 055 056import static org.glassfish.tyrus.spi.WebSocketEngine.UpgradeStatus.*; 057 058/** 059 * A base class for {@code RequestUpgradeStrategy} implementations on top of 060 * JSR-356 based servers which include Tyrus as their WebSocket engine. 061 * 062 * <p>Works with Tyrus 1.3.5 (WebLogic 12.1.3), Tyrus 1.7 (GlassFish 4.1.0), 063 * Tyrus 1.11 (WebLogic 12.2.1), and Tyrus 1.12 (GlassFish 4.1.1). 064 * 065 * @author Rossen Stoyanchev 066 * @author Brian Clozel 067 * @since 4.1 068 * @see <a href="https://tyrus.java.net/">Project Tyrus</a> 069 */ 070public abstract class AbstractTyrusRequestUpgradeStrategy extends AbstractStandardUpgradeStrategy { 071 072 private static final Random random = new Random(); 073 074 private final ComponentProviderService componentProvider = ComponentProviderService.create(); 075 076 077 @Override 078 public String[] getSupportedVersions() { 079 return StringUtils.tokenizeToStringArray(Version.getSupportedWireProtocolVersions(), ","); 080 } 081 082 @Override 083 protected List<WebSocketExtension> getInstalledExtensions(WebSocketContainer container) { 084 try { 085 return super.getInstalledExtensions(container); 086 } 087 catch (UnsupportedOperationException ex) { 088 return new ArrayList<WebSocketExtension>(0); 089 } 090 } 091 092 @Override 093 public void upgradeInternal(ServerHttpRequest request, ServerHttpResponse response, 094 String selectedProtocol, List<Extension> extensions, Endpoint endpoint) 095 throws HandshakeFailureException { 096 097 HttpServletRequest servletRequest = getHttpServletRequest(request); 098 HttpServletResponse servletResponse = getHttpServletResponse(response); 099 100 TyrusServerContainer serverContainer = (TyrusServerContainer) getContainer(servletRequest); 101 TyrusWebSocketEngine engine = (TyrusWebSocketEngine) serverContainer.getWebSocketEngine(); 102 Object tyrusEndpoint = null; 103 boolean success; 104 105 try { 106 // Shouldn't matter for processing but must be unique 107 String path = "/" + random.nextLong(); 108 tyrusEndpoint = createTyrusEndpoint(endpoint, path, selectedProtocol, extensions, serverContainer, engine); 109 getEndpointHelper().register(engine, tyrusEndpoint); 110 111 HttpHeaders headers = request.getHeaders(); 112 RequestContext requestContext = createRequestContext(servletRequest, path, headers); 113 TyrusUpgradeResponse upgradeResponse = new TyrusUpgradeResponse(); 114 UpgradeInfo upgradeInfo = engine.upgrade(requestContext, upgradeResponse); 115 success = SUCCESS.equals(upgradeInfo.getStatus()); 116 if (success) { 117 if (logger.isTraceEnabled()) { 118 logger.trace("Successful request upgrade: " + upgradeResponse.getHeaders()); 119 } 120 handleSuccess(servletRequest, servletResponse, upgradeInfo, upgradeResponse); 121 } 122 } 123 catch (Exception ex) { 124 unregisterTyrusEndpoint(engine, tyrusEndpoint); 125 throw new HandshakeFailureException("Error during handshake: " + request.getURI(), ex); 126 } 127 128 unregisterTyrusEndpoint(engine, tyrusEndpoint); 129 if (!success) { 130 throw new HandshakeFailureException("Unexpected handshake failure: " + request.getURI()); 131 } 132 } 133 134 private Object createTyrusEndpoint(Endpoint endpoint, String endpointPath, String protocol, 135 List<Extension> extensions, WebSocketContainer container, TyrusWebSocketEngine engine) 136 throws DeploymentException { 137 138 ServerEndpointRegistration endpointConfig = new ServerEndpointRegistration(endpointPath, endpoint); 139 endpointConfig.setSubprotocols(Collections.singletonList(protocol)); 140 endpointConfig.setExtensions(extensions); 141 return getEndpointHelper().createdEndpoint(endpointConfig, this.componentProvider, container, engine); 142 } 143 144 private RequestContext createRequestContext(HttpServletRequest request, String endpointPath, HttpHeaders headers) { 145 RequestContext context = 146 RequestContext.Builder.create() 147 .requestURI(URI.create(endpointPath)) 148 .userPrincipal(request.getUserPrincipal()) 149 .secure(request.isSecure()) 150 // .remoteAddr(request.getRemoteAddr()) # Not available in 1.3.5 151 .build(); 152 for (String header : headers.keySet()) { 153 context.getHeaders().put(header, headers.get(header)); 154 } 155 return context; 156 } 157 158 private void unregisterTyrusEndpoint(TyrusWebSocketEngine engine, Object tyrusEndpoint) { 159 if (tyrusEndpoint != null) { 160 try { 161 getEndpointHelper().unregister(engine, tyrusEndpoint); 162 } 163 catch (Throwable ex) { 164 // ignore 165 } 166 } 167 } 168 169 protected abstract TyrusEndpointHelper getEndpointHelper(); 170 171 protected abstract void handleSuccess(HttpServletRequest request, HttpServletResponse response, 172 UpgradeInfo upgradeInfo, TyrusUpgradeResponse upgradeResponse) throws IOException, ServletException; 173 174 175 /** 176 * Helps with the creation, registration, and un-registration of endpoints. 177 */ 178 protected interface TyrusEndpointHelper { 179 180 Object createdEndpoint(ServerEndpointRegistration registration, ComponentProviderService provider, 181 WebSocketContainer container, TyrusWebSocketEngine engine) throws DeploymentException; 182 183 void register(TyrusWebSocketEngine engine, Object endpoint); 184 185 void unregister(TyrusWebSocketEngine engine, Object endpoint); 186 } 187 188 189 protected static class Tyrus17EndpointHelper implements TyrusEndpointHelper { 190 191 private static final Constructor<?> constructor; 192 193 private static boolean constructorWithBooleanArgument; 194 195 private static final Method registerMethod; 196 197 private static final Method unregisterMethod; 198 199 static { 200 try { 201 constructor = getEndpointConstructor(); 202 int parameterCount = constructor.getParameterTypes().length; 203 constructorWithBooleanArgument = (parameterCount == 10); 204 if (!constructorWithBooleanArgument && parameterCount != 9) { 205 throw new IllegalStateException("Expected TyrusEndpointWrapper constructor with 9 or 10 arguments"); 206 } 207 registerMethod = TyrusWebSocketEngine.class.getDeclaredMethod("register", TyrusEndpointWrapper.class); 208 unregisterMethod = TyrusWebSocketEngine.class.getDeclaredMethod("unregister", TyrusEndpointWrapper.class); 209 ReflectionUtils.makeAccessible(registerMethod); 210 } 211 catch (Exception ex) { 212 throw new IllegalStateException("No compatible Tyrus version found", ex); 213 } 214 } 215 216 private static Constructor<?> getEndpointConstructor() { 217 for (Constructor<?> current : TyrusEndpointWrapper.class.getConstructors()) { 218 Class<?>[] types = current.getParameterTypes(); 219 if (Endpoint.class == types[0] && EndpointConfig.class == types[1]) { 220 return current; 221 } 222 } 223 throw new IllegalStateException("No compatible Tyrus version found"); 224 } 225 226 227 @Override 228 public Object createdEndpoint(ServerEndpointRegistration registration, ComponentProviderService provider, 229 WebSocketContainer container, TyrusWebSocketEngine engine) throws DeploymentException { 230 231 DirectFieldAccessor accessor = new DirectFieldAccessor(engine); 232 Object sessionListener = accessor.getPropertyValue("sessionListener"); 233 Object clusterContext = accessor.getPropertyValue("clusterContext"); 234 try { 235 if (constructorWithBooleanArgument) { 236 // Tyrus 1.11+ 237 return constructor.newInstance(registration.getEndpoint(), registration, provider, container, 238 "/", registration.getConfigurator(), sessionListener, clusterContext, null, Boolean.TRUE); 239 } 240 else { 241 return constructor.newInstance(registration.getEndpoint(), registration, provider, container, 242 "/", registration.getConfigurator(), sessionListener, clusterContext, null); 243 } 244 } 245 catch (Exception ex) { 246 throw new HandshakeFailureException("Failed to register " + registration, ex); 247 } 248 } 249 250 @Override 251 public void register(TyrusWebSocketEngine engine, Object endpoint) { 252 try { 253 registerMethod.invoke(engine, endpoint); 254 } 255 catch (Exception ex) { 256 throw new HandshakeFailureException("Failed to register " + endpoint, ex); 257 } 258 } 259 260 @Override 261 public void unregister(TyrusWebSocketEngine engine, Object endpoint) { 262 try { 263 unregisterMethod.invoke(engine, endpoint); 264 } 265 catch (Exception ex) { 266 throw new HandshakeFailureException("Failed to unregister " + endpoint, ex); 267 } 268 } 269 } 270 271 272 protected static class Tyrus135EndpointHelper implements TyrusEndpointHelper { 273 274 private static final Method registerMethod; 275 276 static { 277 try { 278 registerMethod = TyrusWebSocketEngine.class.getDeclaredMethod("register", WebSocketApplication.class); 279 ReflectionUtils.makeAccessible(registerMethod); 280 } 281 catch (Exception ex) { 282 throw new IllegalStateException("No compatible Tyrus version found", ex); 283 } 284 } 285 286 @Override 287 public Object createdEndpoint(ServerEndpointRegistration registration, ComponentProviderService provider, 288 WebSocketContainer container, TyrusWebSocketEngine engine) throws DeploymentException { 289 290 TyrusEndpointWrapper endpointWrapper = new TyrusEndpointWrapper(registration.getEndpoint(), 291 registration, provider, container, "/", registration.getConfigurator()); 292 293 return new TyrusEndpoint(endpointWrapper); 294 } 295 296 @Override 297 public void register(TyrusWebSocketEngine engine, Object endpoint) { 298 try { 299 registerMethod.invoke(engine, endpoint); 300 } 301 catch (Exception ex) { 302 throw new HandshakeFailureException("Failed to register " + endpoint, ex); 303 } 304 } 305 306 @Override 307 public void unregister(TyrusWebSocketEngine engine, Object endpoint) { 308 engine.unregister((TyrusEndpoint) endpoint); 309 } 310 } 311 312}