001/*
002 * Copyright 2002-2018 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.web.socket.server.standard;
018
019import java.io.IOException;
020import java.lang.reflect.Constructor;
021import java.lang.reflect.Method;
022
023import javax.servlet.AsyncContext;
024import javax.servlet.ServletContext;
025import javax.servlet.ServletException;
026import javax.servlet.ServletRequest;
027import javax.servlet.ServletRequestWrapper;
028import javax.servlet.http.HttpServletRequest;
029import javax.servlet.http.HttpServletResponse;
030
031import org.glassfish.tyrus.core.TyrusUpgradeResponse;
032import org.glassfish.tyrus.core.Utils;
033import org.glassfish.tyrus.spi.Connection;
034import org.glassfish.tyrus.spi.WebSocketEngine.UpgradeInfo;
035import org.glassfish.tyrus.spi.Writer;
036
037import org.springframework.beans.BeanWrapper;
038import org.springframework.beans.BeanWrapperImpl;
039import org.springframework.lang.Nullable;
040import org.springframework.util.ReflectionUtils;
041import org.springframework.web.socket.server.HandshakeFailureException;
042
043/**
044 * A WebSocket {@code RequestUpgradeStrategy} for Oracle's WebLogic.
045 * Supports 12.1.3 as well as 12.2.1, as of Spring Framework 4.2.3.
046 *
047 * @author Rossen Stoyanchev
048 * @author Juergen Hoeller
049 * @since 4.1
050 */
051public class WebLogicRequestUpgradeStrategy extends AbstractTyrusRequestUpgradeStrategy {
052
053        private static final TyrusMuxableWebSocketHelper webSocketHelper = new TyrusMuxableWebSocketHelper();
054
055        private static final WebLogicServletWriterHelper servletWriterHelper = new WebLogicServletWriterHelper();
056
057        private static final Connection.CloseListener noOpCloseListener = (reason -> {});
058
059
060        @Override
061        protected void handleSuccess(HttpServletRequest request, HttpServletResponse response,
062                        UpgradeInfo upgradeInfo, TyrusUpgradeResponse upgradeResponse) throws IOException, ServletException {
063
064                response.setStatus(upgradeResponse.getStatus());
065                upgradeResponse.getHeaders().forEach((key, value) -> response.addHeader(key, Utils.getHeaderFromList(value)));
066
067                AsyncContext asyncContext = request.startAsync();
068                asyncContext.setTimeout(-1L);
069
070                Object nativeRequest = getNativeRequest(request);
071                BeanWrapper beanWrapper = new BeanWrapperImpl(nativeRequest);
072                Object httpSocket = beanWrapper.getPropertyValue("connection.connectionHandler.rawConnection");
073                Object webSocket = webSocketHelper.newInstance(request, httpSocket);
074                webSocketHelper.upgrade(webSocket, httpSocket, request.getServletContext());
075
076                response.flushBuffer();
077
078                boolean isProtected = request.getUserPrincipal() != null;
079                Writer servletWriter = servletWriterHelper.newInstance(webSocket, isProtected);
080                Connection connection = upgradeInfo.createConnection(servletWriter, noOpCloseListener);
081                new BeanWrapperImpl(webSocket).setPropertyValue("connection", connection);
082                new BeanWrapperImpl(servletWriter).setPropertyValue("connection", connection);
083                webSocketHelper.registerForReadEvent(webSocket);
084        }
085
086
087        private static Class<?> type(String className) throws ClassNotFoundException {
088                return WebLogicRequestUpgradeStrategy.class.getClassLoader().loadClass(className);
089        }
090
091        private static Method method(String className, String method, Class<?>... paramTypes)
092                        throws ClassNotFoundException, NoSuchMethodException {
093
094                return type(className).getDeclaredMethod(method, paramTypes);
095        }
096
097        private static Object getNativeRequest(ServletRequest request) {
098                while (request instanceof ServletRequestWrapper) {
099                        request = ((ServletRequestWrapper) request).getRequest();
100                }
101                return request;
102        }
103
104
105        /**
106         * Helps to create and invoke {@code weblogic.servlet.internal.MuxableSocketHTTP}.
107         */
108        private static class TyrusMuxableWebSocketHelper {
109
110                private static final Class<?> type;
111
112                private static final Constructor<?> constructor;
113
114                private static final SubjectHelper subjectHelper;
115
116                private static final Method upgradeMethod;
117
118                private static final Method readEventMethod;
119
120                static {
121                        try {
122                                type = type("weblogic.websocket.tyrus.TyrusMuxableWebSocket");
123
124                                constructor = type.getDeclaredConstructor(
125                                                type("weblogic.servlet.internal.MuxableSocketHTTP"),
126                                                type("weblogic.websocket.tyrus.CoherenceServletFilterService"),
127                                                type("weblogic.servlet.spi.SubjectHandle"));
128                                subjectHelper = new SubjectHelper();
129
130                                upgradeMethod = type.getMethod("upgrade", type("weblogic.socket.MuxableSocket"), ServletContext.class);
131                                readEventMethod = type.getMethod("registerForReadEvent");
132                        }
133                        catch (Exception ex) {
134                                throw new IllegalStateException("No compatible WebSocket version found", ex);
135                        }
136                }
137
138                private Object newInstance(HttpServletRequest request, @Nullable Object httpSocket) {
139                        try {
140                                Object[] args = new Object[] {httpSocket, null, subjectHelper.getSubject(request)};
141                                return constructor.newInstance(args);
142                        }
143                        catch (Exception ex) {
144                                throw new HandshakeFailureException("Failed to create TyrusMuxableWebSocket", ex);
145                        }
146                }
147
148                private void upgrade(Object webSocket, @Nullable Object httpSocket, ServletContext servletContext) {
149                        try {
150                                upgradeMethod.invoke(webSocket, httpSocket, servletContext);
151                        }
152                        catch (Exception ex) {
153                                throw new HandshakeFailureException("Failed to upgrade TyrusMuxableWebSocket", ex);
154                        }
155                }
156
157                private void registerForReadEvent(Object webSocket) {
158                        try {
159                                readEventMethod.invoke(webSocket);
160                        }
161                        catch (Exception ex) {
162                                throw new HandshakeFailureException("Failed to register WebSocket for read event", ex);
163                        }
164                }
165        }
166
167
168        private static class SubjectHelper {
169
170                private final Method securityContextMethod;
171
172                private final Method currentUserMethod;
173
174                private final Method providerMethod;
175
176                private final Method anonymousSubjectMethod;
177
178                public SubjectHelper() {
179                        try {
180                                String className = "weblogic.servlet.internal.WebAppServletContext";
181                                this.securityContextMethod = method(className, "getSecurityContext");
182
183                                className = "weblogic.servlet.security.internal.SecurityModule";
184                                this.currentUserMethod = method(className, "getCurrentUser",
185                                                type("weblogic.servlet.security.internal.ServletSecurityContext"),
186                                                HttpServletRequest.class);
187
188                                className = "weblogic.servlet.security.internal.WebAppSecurity";
189                                this.providerMethod = method(className, "getProvider");
190                                this.anonymousSubjectMethod = this.providerMethod.getReturnType().getDeclaredMethod("getAnonymousSubject");
191                        }
192                        catch (Exception ex) {
193                                throw new IllegalStateException("No compatible WebSocket version found", ex);
194                        }
195                }
196
197                public Object getSubject(HttpServletRequest request) {
198                        try {
199                                ServletContext servletContext = request.getServletContext();
200                                Object securityContext = this.securityContextMethod.invoke(servletContext);
201                                Object subject = this.currentUserMethod.invoke(null, securityContext, request);
202                                if (subject == null) {
203                                        Object securityProvider = this.providerMethod.invoke(null);
204                                        subject = this.anonymousSubjectMethod.invoke(securityProvider);
205                                }
206                                return subject;
207                        }
208                        catch (Exception ex) {
209                                throw new HandshakeFailureException("Failed to obtain SubjectHandle", ex);
210                        }
211                }
212        }
213
214
215        /**
216         * Helps to create and invoke {@code weblogic.websocket.tyrus.TyrusServletWriter}.
217         */
218        private static class WebLogicServletWriterHelper {
219
220                private static final Constructor<?> constructor;
221
222                static {
223                        try {
224                                Class<?> writerType = type("weblogic.websocket.tyrus.TyrusServletWriter");
225                                Class<?> listenerType = type("weblogic.websocket.tyrus.TyrusServletWriter$CloseListener");
226                                Class<?> webSocketType = TyrusMuxableWebSocketHelper.type;
227                                constructor = writerType.getDeclaredConstructor(webSocketType, listenerType, boolean.class);
228                                ReflectionUtils.makeAccessible(constructor);
229                        }
230                        catch (Exception ex) {
231                                throw new IllegalStateException("No compatible WebSocket version found", ex);
232                        }
233                }
234
235                private Writer newInstance(Object webSocket, boolean isProtected) {
236                        try {
237                                return (Writer) constructor.newInstance(webSocket, null, isProtected);
238                        }
239                        catch (Exception ex) {
240                                throw new HandshakeFailureException("Failed to create TyrusServletWriter", ex);
241                        }
242                }
243        }
244
245}