001/* 002 * Copyright 2002-2016 the original author or authors. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * https://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.springframework.web.socket.server.standard; 018 019import java.io.IOException; 020import java.lang.reflect.Constructor; 021import java.lang.reflect.Method; 022import java.util.List; 023import java.util.Map; 024import javax.servlet.AsyncContext; 025import javax.servlet.ServletContext; 026import javax.servlet.ServletException; 027import javax.servlet.ServletRequest; 028import javax.servlet.ServletRequestWrapper; 029import javax.servlet.http.HttpServletRequest; 030import javax.servlet.http.HttpServletResponse; 031import javax.websocket.CloseReason; 032 033import org.glassfish.tyrus.core.TyrusUpgradeResponse; 034import org.glassfish.tyrus.core.Utils; 035import org.glassfish.tyrus.spi.Connection; 036import org.glassfish.tyrus.spi.WebSocketEngine.UpgradeInfo; 037import org.glassfish.tyrus.spi.Writer; 038 039import org.springframework.beans.BeanWrapper; 040import org.springframework.beans.BeanWrapperImpl; 041import org.springframework.util.ReflectionUtils; 042import org.springframework.web.socket.server.HandshakeFailureException; 043 044/** 045 * A WebSocket {@code RequestUpgradeStrategy} for Oracle's WebLogic. 046 * Supports 12.1.3 as well as 12.2.1, as of Spring Framework 4.2.3. 047 * 048 * @author Rossen Stoyanchev 049 * @since 4.1 050 */ 051public class WebLogicRequestUpgradeStrategy extends AbstractTyrusRequestUpgradeStrategy { 052 053 private static final boolean WLS_12_1_3 = isWebLogic1213(); 054 055 private static final TyrusEndpointHelper endpointHelper = 056 (WLS_12_1_3 ? new Tyrus135EndpointHelper() : new Tyrus17EndpointHelper()); 057 058 private static final TyrusMuxableWebSocketHelper webSocketHelper = new TyrusMuxableWebSocketHelper(); 059 060 private static final WebLogicServletWriterHelper servletWriterHelper = new WebLogicServletWriterHelper(); 061 062 private static final Connection.CloseListener noOpCloseListener = new Connection.CloseListener() { 063 @Override 064 public void close(CloseReason reason) { 065 } 066 }; 067 068 069 070 @Override 071 protected TyrusEndpointHelper getEndpointHelper() { 072 return endpointHelper; 073 } 074 075 @Override 076 protected void handleSuccess(HttpServletRequest request, HttpServletResponse response, 077 UpgradeInfo upgradeInfo, TyrusUpgradeResponse upgradeResponse) throws IOException, ServletException { 078 079 response.setStatus(upgradeResponse.getStatus()); 080 for (Map.Entry<String, List<String>> entry : upgradeResponse.getHeaders().entrySet()) { 081 response.addHeader(entry.getKey(), Utils.getHeaderFromList(entry.getValue())); 082 } 083 084 AsyncContext asyncContext = request.startAsync(); 085 asyncContext.setTimeout(-1L); 086 087 Object nativeRequest = getNativeRequest(request); 088 BeanWrapper beanWrapper = new BeanWrapperImpl(nativeRequest); 089 Object httpSocket = beanWrapper.getPropertyValue("connection.connectionHandler.rawConnection"); 090 Object webSocket = webSocketHelper.newInstance(request, httpSocket); 091 webSocketHelper.upgrade(webSocket, httpSocket, request.getServletContext()); 092 093 response.flushBuffer(); 094 095 boolean isProtected = request.getUserPrincipal() != null; 096 Writer servletWriter = servletWriterHelper.newInstance(response, webSocket, isProtected); 097 Connection connection = upgradeInfo.createConnection(servletWriter, noOpCloseListener); 098 new BeanWrapperImpl(webSocket).setPropertyValue("connection", connection); 099 new BeanWrapperImpl(servletWriter).setPropertyValue("connection", connection); 100 webSocketHelper.registerForReadEvent(webSocket); 101 } 102 103 104 private static boolean isWebLogic1213() { 105 try { 106 type("weblogic.websocket.tyrus.TyrusMuxableWebSocket").getDeclaredConstructor( 107 type("weblogic.servlet.internal.MuxableSocketHTTP")); 108 return true; 109 } 110 catch (NoSuchMethodException ex) { 111 return false; 112 } 113 catch (ClassNotFoundException ex) { 114 throw new IllegalStateException("No compatible WebSocket version found", ex); 115 } 116 } 117 118 private static Class<?> type(String className) throws ClassNotFoundException { 119 return WebLogicRequestUpgradeStrategy.class.getClassLoader().loadClass(className); 120 } 121 122 private static Method method(String className, String method, Class<?>... paramTypes) 123 throws ClassNotFoundException, NoSuchMethodException { 124 125 return type(className).getDeclaredMethod(method, paramTypes); 126 } 127 128 private static Object getNativeRequest(ServletRequest request) { 129 while (request instanceof ServletRequestWrapper) { 130 request = ((ServletRequestWrapper) request).getRequest(); 131 } 132 return request; 133 } 134 135 136 /** 137 * Helps to create and invoke {@code weblogic.servlet.internal.MuxableSocketHTTP}. 138 */ 139 private static class TyrusMuxableWebSocketHelper { 140 141 private static final Class<?> type; 142 143 private static final Constructor<?> constructor; 144 145 private static final SubjectHelper subjectHelper; 146 147 private static final Method upgradeMethod; 148 149 private static final Method readEventMethod; 150 151 static { 152 try { 153 type = type("weblogic.websocket.tyrus.TyrusMuxableWebSocket"); 154 155 if (WLS_12_1_3) { 156 constructor = type.getDeclaredConstructor(type("weblogic.servlet.internal.MuxableSocketHTTP")); 157 subjectHelper = null; 158 } 159 else { 160 constructor = type.getDeclaredConstructor( 161 type("weblogic.servlet.internal.MuxableSocketHTTP"), 162 type("weblogic.websocket.tyrus.CoherenceServletFilterService"), 163 type("weblogic.servlet.spi.SubjectHandle")); 164 subjectHelper = new SubjectHelper(); 165 } 166 167 upgradeMethod = type.getMethod("upgrade", type("weblogic.socket.MuxableSocket"), ServletContext.class); 168 readEventMethod = type.getMethod("registerForReadEvent"); 169 } 170 catch (Exception ex) { 171 throw new IllegalStateException("No compatible WebSocket version found", ex); 172 } 173 } 174 175 private Object newInstance(HttpServletRequest request, Object httpSocket) { 176 try { 177 Object[] args = (WLS_12_1_3 ? new Object[] {httpSocket} : 178 new Object[] {httpSocket, null, subjectHelper.getSubject(request)}); 179 return constructor.newInstance(args); 180 } 181 catch (Exception ex) { 182 throw new HandshakeFailureException("Failed to create TyrusMuxableWebSocket", ex); 183 } 184 } 185 186 private void upgrade(Object webSocket, Object httpSocket, ServletContext servletContext) { 187 try { 188 upgradeMethod.invoke(webSocket, httpSocket, servletContext); 189 } 190 catch (Exception ex) { 191 throw new HandshakeFailureException("Failed to upgrade TyrusMuxableWebSocket", ex); 192 } 193 } 194 195 private void registerForReadEvent(Object webSocket) { 196 try { 197 readEventMethod.invoke(webSocket); 198 } 199 catch (Exception ex) { 200 throw new HandshakeFailureException("Failed to register WebSocket for read event", ex); 201 } 202 } 203 } 204 205 206 private static class SubjectHelper { 207 208 private final Method securityContextMethod; 209 210 private final Method currentUserMethod; 211 212 private final Method providerMethod; 213 214 private final Method anonymousSubjectMethod; 215 216 public SubjectHelper() { 217 try { 218 String className = "weblogic.servlet.internal.WebAppServletContext"; 219 securityContextMethod = method(className, "getSecurityContext"); 220 221 className = "weblogic.servlet.security.internal.SecurityModule"; 222 currentUserMethod = method(className, "getCurrentUser", 223 type("weblogic.servlet.security.internal.ServletSecurityContext"), 224 HttpServletRequest.class); 225 226 className = "weblogic.servlet.security.internal.WebAppSecurity"; 227 providerMethod = method(className, "getProvider"); 228 anonymousSubjectMethod = providerMethod.getReturnType().getDeclaredMethod("getAnonymousSubject"); 229 } 230 catch (Exception ex) { 231 throw new IllegalStateException("No compatible WebSocket version found", ex); 232 } 233 } 234 235 public Object getSubject(HttpServletRequest request) { 236 try { 237 ServletContext servletContext = request.getServletContext(); 238 Object securityContext = securityContextMethod.invoke(servletContext); 239 Object subject = currentUserMethod.invoke(null, securityContext, request); 240 if (subject == null) { 241 Object securityProvider = providerMethod.invoke(null); 242 subject = anonymousSubjectMethod.invoke(securityProvider); 243 } 244 return subject; 245 } 246 catch (Exception ex) { 247 throw new HandshakeFailureException("Failed to obtain SubjectHandle", ex); 248 } 249 } 250 } 251 252 253 /** 254 * Helps to create and invoke {@code weblogic.websocket.tyrus.TyrusServletWriter}. 255 */ 256 private static class WebLogicServletWriterHelper { 257 258 private static final Constructor<?> constructor; 259 260 static { 261 try { 262 Class<?> writerType = type("weblogic.websocket.tyrus.TyrusServletWriter"); 263 Class<?> listenerType = type("weblogic.websocket.tyrus.TyrusServletWriter$CloseListener"); 264 Class<?> webSocketType = TyrusMuxableWebSocketHelper.type; 265 Class<HttpServletResponse> responseType = HttpServletResponse.class; 266 267 Class<?>[] argTypes = (WLS_12_1_3 ? 268 new Class<?>[] {webSocketType, responseType, listenerType, boolean.class} : 269 new Class<?>[] {webSocketType, listenerType, boolean.class}); 270 271 constructor = writerType.getDeclaredConstructor(argTypes); 272 ReflectionUtils.makeAccessible(constructor); 273 } 274 catch (Exception ex) { 275 throw new IllegalStateException("No compatible WebSocket version found", ex); 276 } 277 } 278 279 private Writer newInstance(HttpServletResponse response, Object webSocket, boolean isProtected) { 280 try { 281 Object[] args = (WLS_12_1_3 ? 282 new Object[] {webSocket, response, null, isProtected} : 283 new Object[] {webSocket, null, isProtected}); 284 285 return (Writer) constructor.newInstance(args); 286 } 287 catch (Exception ex) { 288 throw new HandshakeFailureException("Failed to create TyrusServletWriter", ex); 289 } 290 } 291 } 292 293}