001/*
002 * Copyright 2002-2018 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.web.reactive.socket.server.support;
018
019import java.net.InetSocketAddress;
020import java.net.URI;
021import java.security.Principal;
022import java.util.Collections;
023import java.util.List;
024import java.util.Map;
025import java.util.function.Predicate;
026import java.util.stream.Collectors;
027
028import org.apache.commons.logging.Log;
029import org.apache.commons.logging.LogFactory;
030import reactor.core.publisher.Mono;
031
032import org.springframework.context.Lifecycle;
033import org.springframework.http.HttpHeaders;
034import org.springframework.http.HttpMethod;
035import org.springframework.http.server.reactive.ServerHttpRequest;
036import org.springframework.lang.Nullable;
037import org.springframework.util.Assert;
038import org.springframework.util.ClassUtils;
039import org.springframework.util.ReflectionUtils;
040import org.springframework.util.StringUtils;
041import org.springframework.web.reactive.socket.HandshakeInfo;
042import org.springframework.web.reactive.socket.WebSocketHandler;
043import org.springframework.web.reactive.socket.server.RequestUpgradeStrategy;
044import org.springframework.web.reactive.socket.server.WebSocketService;
045import org.springframework.web.server.MethodNotAllowedException;
046import org.springframework.web.server.ServerWebExchange;
047import org.springframework.web.server.ServerWebInputException;
048
049/**
050 * {@code WebSocketService} implementation that handles a WebSocket HTTP
051 * handshake request by delegating to a {@link RequestUpgradeStrategy} which
052 * is either auto-detected (no-arg constructor) from the classpath but can
053 * also be explicitly configured.
054 *
055 * @author Rossen Stoyanchev
056 * @since 5.0
057 */
058public class HandshakeWebSocketService implements WebSocketService, Lifecycle {
059
060        private static final String SEC_WEBSOCKET_KEY = "Sec-WebSocket-Key";
061
062        private static final String SEC_WEBSOCKET_PROTOCOL = "Sec-WebSocket-Protocol";
063
064        private static final Mono<Map<String, Object>> EMPTY_ATTRIBUTES = Mono.just(Collections.emptyMap());
065
066
067        private static final boolean tomcatPresent;
068
069        private static final boolean jettyPresent;
070
071        private static final boolean undertowPresent;
072
073        private static final boolean reactorNettyPresent;
074
075        static {
076                ClassLoader classLoader = HandshakeWebSocketService.class.getClassLoader();
077                tomcatPresent = ClassUtils.isPresent("org.apache.tomcat.websocket.server.WsHttpUpgradeHandler", classLoader);
078                jettyPresent = ClassUtils.isPresent("org.eclipse.jetty.websocket.server.WebSocketServerFactory", classLoader);
079                undertowPresent = ClassUtils.isPresent("io.undertow.websockets.WebSocketProtocolHandshakeHandler", classLoader);
080                reactorNettyPresent = ClassUtils.isPresent("reactor.netty.http.server.HttpServerResponse", classLoader);
081        }
082
083
084        protected static final Log logger = LogFactory.getLog(HandshakeWebSocketService.class);
085
086
087        private final RequestUpgradeStrategy upgradeStrategy;
088
089        @Nullable
090        private Predicate<String> sessionAttributePredicate;
091
092        private volatile boolean running = false;
093
094
095        /**
096         * Default constructor automatic, classpath detection based discovery of the
097         * {@link RequestUpgradeStrategy} to use.
098         */
099        public HandshakeWebSocketService() {
100                this(initUpgradeStrategy());
101        }
102
103        /**
104         * Alternative constructor with the {@link RequestUpgradeStrategy} to use.
105         * @param upgradeStrategy the strategy to use
106         */
107        public HandshakeWebSocketService(RequestUpgradeStrategy upgradeStrategy) {
108                Assert.notNull(upgradeStrategy, "RequestUpgradeStrategy is required");
109                this.upgradeStrategy = upgradeStrategy;
110        }
111
112        private static RequestUpgradeStrategy initUpgradeStrategy() {
113                String className;
114                if (tomcatPresent) {
115                        className = "TomcatRequestUpgradeStrategy";
116                }
117                else if (jettyPresent) {
118                        className = "JettyRequestUpgradeStrategy";
119                }
120                else if (undertowPresent) {
121                        className = "UndertowRequestUpgradeStrategy";
122                }
123                else if (reactorNettyPresent) {
124                        // As late as possible (Reactor Netty commonly used for WebClient)
125                        className = "ReactorNettyRequestUpgradeStrategy";
126                }
127                else {
128                        throw new IllegalStateException("No suitable default RequestUpgradeStrategy found");
129                }
130
131                try {
132                        className = "org.springframework.web.reactive.socket.server.upgrade." + className;
133                        Class<?> clazz = ClassUtils.forName(className, HandshakeWebSocketService.class.getClassLoader());
134                        return (RequestUpgradeStrategy) ReflectionUtils.accessibleConstructor(clazz).newInstance();
135                }
136                catch (Throwable ex) {
137                        throw new IllegalStateException(
138                                        "Failed to instantiate RequestUpgradeStrategy: " + className, ex);
139                }
140        }
141
142
143        /**
144         * Return the {@link RequestUpgradeStrategy} for WebSocket requests.
145         */
146        public RequestUpgradeStrategy getUpgradeStrategy() {
147                return this.upgradeStrategy;
148        }
149
150        /**
151         * Configure a predicate to use to extract
152         * {@link org.springframework.web.server.WebSession WebSession} attributes
153         * and use them to initialize the WebSocket session with.
154         * <p>By default this is not set in which case no attributes are passed.
155         * @param predicate the predicate
156         * @since 5.1
157         */
158        public void setSessionAttributePredicate(@Nullable Predicate<String> predicate) {
159                this.sessionAttributePredicate = predicate;
160        }
161
162        /**
163         * Return the configured predicate for initialization WebSocket session
164         * attributes from {@code WebSession} attributes.
165         * @since 5.1
166         */
167        @Nullable
168        public Predicate<String> getSessionAttributePredicate() {
169                return this.sessionAttributePredicate;
170        }
171
172
173        @Override
174        public void start() {
175                if (!isRunning()) {
176                        this.running = true;
177                        doStart();
178                }
179        }
180
181        protected void doStart() {
182                if (getUpgradeStrategy() instanceof Lifecycle) {
183                        ((Lifecycle) getUpgradeStrategy()).start();
184                }
185        }
186
187        @Override
188        public void stop() {
189                if (isRunning()) {
190                        this.running = false;
191                        doStop();
192                }
193        }
194
195        protected void doStop() {
196                if (getUpgradeStrategy() instanceof Lifecycle) {
197                        ((Lifecycle) getUpgradeStrategy()).stop();
198                }
199        }
200
201        @Override
202        public boolean isRunning() {
203                return this.running;
204        }
205
206
207        @Override
208        public Mono<Void> handleRequest(ServerWebExchange exchange, WebSocketHandler handler) {
209                ServerHttpRequest request = exchange.getRequest();
210                HttpMethod method = request.getMethod();
211                HttpHeaders headers = request.getHeaders();
212
213                if (HttpMethod.GET != method) {
214                        return Mono.error(new MethodNotAllowedException(
215                                        request.getMethodValue(), Collections.singleton(HttpMethod.GET)));
216                }
217
218                if (!"WebSocket".equalsIgnoreCase(headers.getUpgrade())) {
219                        return handleBadRequest(exchange, "Invalid 'Upgrade' header: " + headers);
220                }
221
222                List<String> connectionValue = headers.getConnection();
223                if (!connectionValue.contains("Upgrade") && !connectionValue.contains("upgrade")) {
224                        return handleBadRequest(exchange, "Invalid 'Connection' header: " + headers);
225                }
226
227                String key = headers.getFirst(SEC_WEBSOCKET_KEY);
228                if (key == null) {
229                        return handleBadRequest(exchange, "Missing \"Sec-WebSocket-Key\" header");
230                }
231
232                String protocol = selectProtocol(headers, handler);
233
234                return initAttributes(exchange).flatMap(attributes ->
235                                this.upgradeStrategy.upgrade(exchange, handler, protocol,
236                                                () -> createHandshakeInfo(exchange, request, protocol, attributes))
237                );
238        }
239
240        private Mono<Void> handleBadRequest(ServerWebExchange exchange, String reason) {
241                if (logger.isDebugEnabled()) {
242                        logger.debug(exchange.getLogPrefix() + reason);
243                }
244                return Mono.error(new ServerWebInputException(reason));
245        }
246
247        @Nullable
248        private String selectProtocol(HttpHeaders headers, WebSocketHandler handler) {
249                String protocolHeader = headers.getFirst(SEC_WEBSOCKET_PROTOCOL);
250                if (protocolHeader != null) {
251                        List<String> supportedProtocols = handler.getSubProtocols();
252                        for (String protocol : StringUtils.commaDelimitedListToStringArray(protocolHeader)) {
253                                if (supportedProtocols.contains(protocol)) {
254                                        return protocol;
255                                }
256                        }
257                }
258                return null;
259        }
260
261        private Mono<Map<String, Object>> initAttributes(ServerWebExchange exchange) {
262                if (this.sessionAttributePredicate == null) {
263                        return EMPTY_ATTRIBUTES;
264                }
265                return exchange.getSession().map(session ->
266                                session.getAttributes().entrySet().stream()
267                                                .filter(entry -> this.sessionAttributePredicate.test(entry.getKey()))
268                                                .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)));
269        }
270
271        private HandshakeInfo createHandshakeInfo(ServerWebExchange exchange, ServerHttpRequest request,
272                        @Nullable String protocol, Map<String, Object> attributes) {
273
274                URI uri = request.getURI();
275                // Copy request headers, as they might be pooled and recycled by
276                // the server implementation once the handshake HTTP exchange is done.
277                HttpHeaders headers = new HttpHeaders();
278                headers.addAll(request.getHeaders());
279                Mono<Principal> principal = exchange.getPrincipal();
280                String logPrefix = exchange.getLogPrefix();
281                InetSocketAddress remoteAddress = request.getRemoteAddress();
282                return new HandshakeInfo(uri, headers, principal, protocol, remoteAddress, attributes, logPrefix);
283        }
284
285}