001/*
002 * Copyright 2002-2016 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.web.socket.server.standard;
018
019import java.io.IOException;
020import java.lang.reflect.Constructor;
021import java.lang.reflect.Method;
022import java.util.List;
023import java.util.Map;
024import javax.servlet.AsyncContext;
025import javax.servlet.ServletContext;
026import javax.servlet.ServletException;
027import javax.servlet.ServletRequest;
028import javax.servlet.ServletRequestWrapper;
029import javax.servlet.http.HttpServletRequest;
030import javax.servlet.http.HttpServletResponse;
031import javax.websocket.CloseReason;
032
033import org.glassfish.tyrus.core.TyrusUpgradeResponse;
034import org.glassfish.tyrus.core.Utils;
035import org.glassfish.tyrus.spi.Connection;
036import org.glassfish.tyrus.spi.WebSocketEngine.UpgradeInfo;
037import org.glassfish.tyrus.spi.Writer;
038
039import org.springframework.beans.BeanWrapper;
040import org.springframework.beans.BeanWrapperImpl;
041import org.springframework.util.ReflectionUtils;
042import org.springframework.web.socket.server.HandshakeFailureException;
043
044/**
045 * A WebSocket {@code RequestUpgradeStrategy} for Oracle's WebLogic.
046 * Supports 12.1.3 as well as 12.2.1, as of Spring Framework 4.2.3.
047 *
048 * @author Rossen Stoyanchev
049 * @since 4.1
050 */
051public class WebLogicRequestUpgradeStrategy extends AbstractTyrusRequestUpgradeStrategy {
052
053        private static final boolean WLS_12_1_3 = isWebLogic1213();
054
055        private static final TyrusEndpointHelper endpointHelper =
056                        (WLS_12_1_3 ? new Tyrus135EndpointHelper() : new Tyrus17EndpointHelper());
057
058        private static final TyrusMuxableWebSocketHelper webSocketHelper = new TyrusMuxableWebSocketHelper();
059
060        private static final WebLogicServletWriterHelper servletWriterHelper = new WebLogicServletWriterHelper();
061
062        private static final Connection.CloseListener noOpCloseListener = new Connection.CloseListener() {
063                @Override
064                public void close(CloseReason reason) {
065                }
066        };
067
068
069
070        @Override
071        protected TyrusEndpointHelper getEndpointHelper() {
072                return endpointHelper;
073        }
074
075        @Override
076        protected void handleSuccess(HttpServletRequest request, HttpServletResponse response,
077                        UpgradeInfo upgradeInfo, TyrusUpgradeResponse upgradeResponse) throws IOException, ServletException {
078
079                response.setStatus(upgradeResponse.getStatus());
080                for (Map.Entry<String, List<String>> entry : upgradeResponse.getHeaders().entrySet()) {
081                        response.addHeader(entry.getKey(), Utils.getHeaderFromList(entry.getValue()));
082                }
083
084                AsyncContext asyncContext = request.startAsync();
085                asyncContext.setTimeout(-1L);
086
087                Object nativeRequest = getNativeRequest(request);
088                BeanWrapper beanWrapper = new BeanWrapperImpl(nativeRequest);
089                Object httpSocket = beanWrapper.getPropertyValue("connection.connectionHandler.rawConnection");
090                Object webSocket = webSocketHelper.newInstance(request, httpSocket);
091                webSocketHelper.upgrade(webSocket, httpSocket, request.getServletContext());
092
093                response.flushBuffer();
094
095                boolean isProtected = request.getUserPrincipal() != null;
096                Writer servletWriter = servletWriterHelper.newInstance(response, webSocket, isProtected);
097                Connection connection = upgradeInfo.createConnection(servletWriter, noOpCloseListener);
098                new BeanWrapperImpl(webSocket).setPropertyValue("connection", connection);
099                new BeanWrapperImpl(servletWriter).setPropertyValue("connection", connection);
100                webSocketHelper.registerForReadEvent(webSocket);
101        }
102
103
104        private static boolean isWebLogic1213() {
105                try {
106                        type("weblogic.websocket.tyrus.TyrusMuxableWebSocket").getDeclaredConstructor(
107                                        type("weblogic.servlet.internal.MuxableSocketHTTP"));
108                        return true;
109                }
110                catch (NoSuchMethodException ex) {
111                        return false;
112                }
113                catch (ClassNotFoundException ex) {
114                        throw new IllegalStateException("No compatible WebSocket version found", ex);
115                }
116        }
117
118        private static Class<?> type(String className) throws ClassNotFoundException {
119                return WebLogicRequestUpgradeStrategy.class.getClassLoader().loadClass(className);
120        }
121
122        private static Method method(String className, String method, Class<?>... paramTypes)
123                        throws ClassNotFoundException, NoSuchMethodException {
124
125                return type(className).getDeclaredMethod(method, paramTypes);
126        }
127
128        private static Object getNativeRequest(ServletRequest request) {
129                while (request instanceof ServletRequestWrapper) {
130                        request = ((ServletRequestWrapper) request).getRequest();
131                }
132                return request;
133        }
134
135
136        /**
137         * Helps to create and invoke {@code weblogic.servlet.internal.MuxableSocketHTTP}.
138         */
139        private static class TyrusMuxableWebSocketHelper {
140
141                private static final Class<?> type;
142
143                private static final Constructor<?> constructor;
144
145                private static final SubjectHelper subjectHelper;
146
147                private static final Method upgradeMethod;
148
149                private static final Method readEventMethod;
150
151                static {
152                        try {
153                                type = type("weblogic.websocket.tyrus.TyrusMuxableWebSocket");
154
155                                if (WLS_12_1_3) {
156                                        constructor = type.getDeclaredConstructor(type("weblogic.servlet.internal.MuxableSocketHTTP"));
157                                        subjectHelper = null;
158                                }
159                                else {
160                                        constructor = type.getDeclaredConstructor(
161                                                        type("weblogic.servlet.internal.MuxableSocketHTTP"),
162                                                        type("weblogic.websocket.tyrus.CoherenceServletFilterService"),
163                                                        type("weblogic.servlet.spi.SubjectHandle"));
164                                        subjectHelper = new SubjectHelper();
165                                }
166
167                                upgradeMethod = type.getMethod("upgrade", type("weblogic.socket.MuxableSocket"), ServletContext.class);
168                                readEventMethod = type.getMethod("registerForReadEvent");
169                        }
170                        catch (Exception ex) {
171                                throw new IllegalStateException("No compatible WebSocket version found", ex);
172                        }
173                }
174
175                private Object newInstance(HttpServletRequest request, Object httpSocket) {
176                        try {
177                                Object[] args = (WLS_12_1_3 ? new Object[] {httpSocket} :
178                                                new Object[] {httpSocket, null, subjectHelper.getSubject(request)});
179                                return constructor.newInstance(args);
180                        }
181                        catch (Exception ex) {
182                                throw new HandshakeFailureException("Failed to create TyrusMuxableWebSocket", ex);
183                        }
184                }
185
186                private void upgrade(Object webSocket, Object httpSocket, ServletContext servletContext) {
187                        try {
188                                upgradeMethod.invoke(webSocket, httpSocket, servletContext);
189                        }
190                        catch (Exception ex) {
191                                throw new HandshakeFailureException("Failed to upgrade TyrusMuxableWebSocket", ex);
192                        }
193                }
194
195                private void registerForReadEvent(Object webSocket) {
196                        try {
197                                readEventMethod.invoke(webSocket);
198                        }
199                        catch (Exception ex) {
200                                throw new HandshakeFailureException("Failed to register WebSocket for read event", ex);
201                        }
202                }
203        }
204
205
206        private static class SubjectHelper {
207
208                private final Method securityContextMethod;
209
210                private final Method currentUserMethod;
211
212                private final Method providerMethod;
213
214                private final Method anonymousSubjectMethod;
215
216                public SubjectHelper() {
217                        try {
218                                String className = "weblogic.servlet.internal.WebAppServletContext";
219                                securityContextMethod = method(className, "getSecurityContext");
220
221                                className = "weblogic.servlet.security.internal.SecurityModule";
222                                currentUserMethod = method(className, "getCurrentUser",
223                                                type("weblogic.servlet.security.internal.ServletSecurityContext"),
224                                                HttpServletRequest.class);
225
226                                className = "weblogic.servlet.security.internal.WebAppSecurity";
227                                providerMethod = method(className, "getProvider");
228                                anonymousSubjectMethod = providerMethod.getReturnType().getDeclaredMethod("getAnonymousSubject");
229                        }
230                        catch (Exception ex) {
231                                throw new IllegalStateException("No compatible WebSocket version found", ex);
232                        }
233                }
234
235                public Object getSubject(HttpServletRequest request) {
236                        try {
237                                ServletContext servletContext = request.getServletContext();
238                                Object securityContext = securityContextMethod.invoke(servletContext);
239                                Object subject = currentUserMethod.invoke(null, securityContext, request);
240                                if (subject == null) {
241                                        Object securityProvider = providerMethod.invoke(null);
242                                        subject = anonymousSubjectMethod.invoke(securityProvider);
243                                }
244                                return subject;
245                        }
246                        catch (Exception ex) {
247                                throw new HandshakeFailureException("Failed to obtain SubjectHandle", ex);
248                        }
249                }
250        }
251
252
253        /**
254         * Helps to create and invoke {@code weblogic.websocket.tyrus.TyrusServletWriter}.
255         */
256        private static class WebLogicServletWriterHelper {
257
258                private static final Constructor<?> constructor;
259
260                static {
261                        try {
262                                Class<?> writerType = type("weblogic.websocket.tyrus.TyrusServletWriter");
263                                Class<?> listenerType = type("weblogic.websocket.tyrus.TyrusServletWriter$CloseListener");
264                                Class<?> webSocketType = TyrusMuxableWebSocketHelper.type;
265                                Class<HttpServletResponse> responseType = HttpServletResponse.class;
266
267                                Class<?>[] argTypes = (WLS_12_1_3 ?
268                                                new Class<?>[] {webSocketType, responseType, listenerType, boolean.class} :
269                                                new Class<?>[] {webSocketType, listenerType, boolean.class});
270
271                                constructor = writerType.getDeclaredConstructor(argTypes);
272                                ReflectionUtils.makeAccessible(constructor);
273                        }
274                        catch (Exception ex) {
275                                throw new IllegalStateException("No compatible WebSocket version found", ex);
276                        }
277                }
278
279                private Writer newInstance(HttpServletResponse response, Object webSocket, boolean isProtected) {
280                        try {
281                                Object[] args = (WLS_12_1_3 ?
282                                                new Object[] {webSocket, response, null, isProtected} :
283                                                new Object[] {webSocket, null, isProtected});
284
285                                return (Writer) constructor.newInstance(args);
286                        }
287                        catch (Exception ex) {
288                                throw new HandshakeFailureException("Failed to create TyrusServletWriter", ex);
289                        }
290                }
291        }
292
293}