001/* 002 * Copyright 2002-2018 the original author or authors. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * https://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.springframework.web.socket.server.standard; 018 019import java.io.IOException; 020import java.lang.reflect.Constructor; 021import java.lang.reflect.Method; 022import java.net.URI; 023import java.util.ArrayList; 024import java.util.Collections; 025import java.util.List; 026import java.util.Random; 027 028import javax.servlet.ServletException; 029import javax.servlet.http.HttpServletRequest; 030import javax.servlet.http.HttpServletResponse; 031import javax.websocket.DeploymentException; 032import javax.websocket.Endpoint; 033import javax.websocket.EndpointConfig; 034import javax.websocket.Extension; 035import javax.websocket.WebSocketContainer; 036 037import org.glassfish.tyrus.core.ComponentProviderService; 038import org.glassfish.tyrus.core.RequestContext; 039import org.glassfish.tyrus.core.TyrusEndpointWrapper; 040import org.glassfish.tyrus.core.TyrusUpgradeResponse; 041import org.glassfish.tyrus.core.TyrusWebSocketEngine; 042import org.glassfish.tyrus.core.Version; 043import org.glassfish.tyrus.server.TyrusServerContainer; 044import org.glassfish.tyrus.spi.WebSocketEngine.UpgradeInfo; 045 046import org.springframework.beans.DirectFieldAccessor; 047import org.springframework.http.HttpHeaders; 048import org.springframework.http.server.ServerHttpRequest; 049import org.springframework.http.server.ServerHttpResponse; 050import org.springframework.lang.Nullable; 051import org.springframework.util.ReflectionUtils; 052import org.springframework.util.StringUtils; 053import org.springframework.web.socket.WebSocketExtension; 054import org.springframework.web.socket.server.HandshakeFailureException; 055 056import static org.glassfish.tyrus.spi.WebSocketEngine.UpgradeStatus.SUCCESS; 057 058/** 059 * A base class for {@code RequestUpgradeStrategy} implementations on top of 060 * JSR-356 based servers which include Tyrus as their WebSocket engine. 061 * 062 * <p>Works with Tyrus 1.11 (WebLogic 12.2.1) and Tyrus 1.12 (GlassFish 4.1.1). 063 * 064 * @author Rossen Stoyanchev 065 * @author Brian Clozel 066 * @author Juergen Hoeller 067 * @since 4.1 068 * @see <a href="https://tyrus.java.net/">Project Tyrus</a> 069 */ 070public abstract class AbstractTyrusRequestUpgradeStrategy extends AbstractStandardUpgradeStrategy { 071 072 private static final Random random = new Random(); 073 074 private static final Constructor<?> constructor; 075 076 private static boolean constructorWithBooleanArgument; 077 078 private static final Method registerMethod; 079 080 private static final Method unRegisterMethod; 081 082 static { 083 try { 084 constructor = getEndpointConstructor(); 085 int parameterCount = constructor.getParameterCount(); 086 constructorWithBooleanArgument = (parameterCount == 10); 087 if (!constructorWithBooleanArgument && parameterCount != 9) { 088 throw new IllegalStateException("Expected TyrusEndpointWrapper constructor with 9 or 10 arguments"); 089 } 090 registerMethod = TyrusWebSocketEngine.class.getDeclaredMethod("register", TyrusEndpointWrapper.class); 091 unRegisterMethod = TyrusWebSocketEngine.class.getDeclaredMethod("unregister", TyrusEndpointWrapper.class); 092 ReflectionUtils.makeAccessible(registerMethod); 093 } 094 catch (Exception ex) { 095 throw new IllegalStateException("No compatible Tyrus version found", ex); 096 } 097 } 098 099 private static Constructor<?> getEndpointConstructor() { 100 for (Constructor<?> current : TyrusEndpointWrapper.class.getConstructors()) { 101 Class<?>[] types = current.getParameterTypes(); 102 if (Endpoint.class == types[0] && EndpointConfig.class == types[1]) { 103 return current; 104 } 105 } 106 throw new IllegalStateException("No compatible Tyrus version found"); 107 } 108 109 110 private final ComponentProviderService componentProvider = ComponentProviderService.create(); 111 112 113 @Override 114 public String[] getSupportedVersions() { 115 return StringUtils.tokenizeToStringArray(Version.getSupportedWireProtocolVersions(), ","); 116 } 117 118 @Override 119 protected List<WebSocketExtension> getInstalledExtensions(WebSocketContainer container) { 120 try { 121 return super.getInstalledExtensions(container); 122 } 123 catch (UnsupportedOperationException ex) { 124 return new ArrayList<>(0); 125 } 126 } 127 128 @Override 129 public void upgradeInternal(ServerHttpRequest request, ServerHttpResponse response, 130 @Nullable String selectedProtocol, List<Extension> extensions, Endpoint endpoint) 131 throws HandshakeFailureException { 132 133 HttpServletRequest servletRequest = getHttpServletRequest(request); 134 HttpServletResponse servletResponse = getHttpServletResponse(response); 135 136 TyrusServerContainer serverContainer = (TyrusServerContainer) getContainer(servletRequest); 137 TyrusWebSocketEngine engine = (TyrusWebSocketEngine) serverContainer.getWebSocketEngine(); 138 Object tyrusEndpoint = null; 139 boolean success; 140 141 try { 142 // Shouldn't matter for processing but must be unique 143 String path = "/" + random.nextLong(); 144 tyrusEndpoint = createTyrusEndpoint(endpoint, path, selectedProtocol, extensions, serverContainer, engine); 145 register(engine, tyrusEndpoint); 146 147 HttpHeaders headers = request.getHeaders(); 148 RequestContext requestContext = createRequestContext(servletRequest, path, headers); 149 TyrusUpgradeResponse upgradeResponse = new TyrusUpgradeResponse(); 150 UpgradeInfo upgradeInfo = engine.upgrade(requestContext, upgradeResponse); 151 success = SUCCESS.equals(upgradeInfo.getStatus()); 152 if (success) { 153 if (logger.isTraceEnabled()) { 154 logger.trace("Successful request upgrade: " + upgradeResponse.getHeaders()); 155 } 156 handleSuccess(servletRequest, servletResponse, upgradeInfo, upgradeResponse); 157 } 158 } 159 catch (Exception ex) { 160 unregisterTyrusEndpoint(engine, tyrusEndpoint); 161 throw new HandshakeFailureException("Error during handshake: " + request.getURI(), ex); 162 } 163 164 unregisterTyrusEndpoint(engine, tyrusEndpoint); 165 if (!success) { 166 throw new HandshakeFailureException("Unexpected handshake failure: " + request.getURI()); 167 } 168 } 169 170 private Object createTyrusEndpoint(Endpoint endpoint, String endpointPath, @Nullable String protocol, 171 List<Extension> extensions, WebSocketContainer container, TyrusWebSocketEngine engine) 172 throws DeploymentException { 173 174 ServerEndpointRegistration endpointConfig = new ServerEndpointRegistration(endpointPath, endpoint); 175 endpointConfig.setSubprotocols(Collections.singletonList(protocol)); 176 endpointConfig.setExtensions(extensions); 177 return createEndpoint(endpointConfig, this.componentProvider, container, engine); 178 } 179 180 private RequestContext createRequestContext(HttpServletRequest request, String endpointPath, HttpHeaders headers) { 181 RequestContext context = 182 RequestContext.Builder.create() 183 .requestURI(URI.create(endpointPath)) 184 .userPrincipal(request.getUserPrincipal()) 185 .secure(request.isSecure()) 186 .remoteAddr(request.getRemoteAddr()) 187 .build(); 188 headers.forEach((header, value) -> context.getHeaders().put(header, value)); 189 return context; 190 } 191 192 private void unregisterTyrusEndpoint(TyrusWebSocketEngine engine, @Nullable Object tyrusEndpoint) { 193 if (tyrusEndpoint != null) { 194 try { 195 unregister(engine, tyrusEndpoint); 196 } 197 catch (Throwable ex) { 198 // ignore 199 } 200 } 201 } 202 203 private Object createEndpoint(ServerEndpointRegistration registration, ComponentProviderService provider, 204 WebSocketContainer container, TyrusWebSocketEngine engine) throws DeploymentException { 205 206 DirectFieldAccessor accessor = new DirectFieldAccessor(engine); 207 Object sessionListener = accessor.getPropertyValue("sessionListener"); 208 Object clusterContext = accessor.getPropertyValue("clusterContext"); 209 try { 210 if (constructorWithBooleanArgument) { 211 // Tyrus 1.11+ 212 return constructor.newInstance(registration.getEndpoint(), registration, provider, container, 213 "/", registration.getConfigurator(), sessionListener, clusterContext, null, Boolean.TRUE); 214 } 215 else { 216 return constructor.newInstance(registration.getEndpoint(), registration, provider, container, 217 "/", registration.getConfigurator(), sessionListener, clusterContext, null); 218 } 219 } 220 catch (Exception ex) { 221 throw new HandshakeFailureException("Failed to register " + registration, ex); 222 } 223 } 224 225 private void register(TyrusWebSocketEngine engine, Object endpoint) { 226 try { 227 registerMethod.invoke(engine, endpoint); 228 } 229 catch (Exception ex) { 230 throw new HandshakeFailureException("Failed to register " + endpoint, ex); 231 } 232 } 233 234 private void unregister(TyrusWebSocketEngine engine, Object endpoint) { 235 try { 236 unRegisterMethod.invoke(engine, endpoint); 237 } 238 catch (Exception ex) { 239 throw new HandshakeFailureException("Failed to unregister " + endpoint, ex); 240 } 241 } 242 243 244 protected abstract void handleSuccess(HttpServletRequest request, HttpServletResponse response, 245 UpgradeInfo upgradeInfo, TyrusUpgradeResponse upgradeResponse) throws IOException, ServletException; 246 247}