001/*
002 * Copyright 2002-2020 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.web.reactive.socket.server.upgrade;
018
019import java.util.Collections;
020import java.util.List;
021import java.util.Set;
022import java.util.function.Supplier;
023
024import io.undertow.server.HttpServerExchange;
025import io.undertow.websockets.WebSocketConnectionCallback;
026import io.undertow.websockets.WebSocketProtocolHandshakeHandler;
027import io.undertow.websockets.core.WebSocketChannel;
028import io.undertow.websockets.core.protocol.Handshake;
029import io.undertow.websockets.core.protocol.version13.Hybi13Handshake;
030import io.undertow.websockets.spi.WebSocketHttpExchange;
031import reactor.core.publisher.Mono;
032
033import org.springframework.core.io.buffer.DataBufferFactory;
034import org.springframework.http.server.reactive.AbstractServerHttpRequest;
035import org.springframework.http.server.reactive.ServerHttpRequest;
036import org.springframework.http.server.reactive.ServerHttpRequestDecorator;
037import org.springframework.lang.Nullable;
038import org.springframework.web.reactive.socket.HandshakeInfo;
039import org.springframework.web.reactive.socket.WebSocketHandler;
040import org.springframework.web.reactive.socket.adapter.UndertowWebSocketHandlerAdapter;
041import org.springframework.web.reactive.socket.adapter.UndertowWebSocketSession;
042import org.springframework.web.reactive.socket.server.RequestUpgradeStrategy;
043import org.springframework.web.server.ServerWebExchange;
044
045/**
046 * A {@link RequestUpgradeStrategy} for use with Undertow.
047 *
048 * @author Violeta Georgieva
049 * @author Rossen Stoyanchev
050 * @author Brian Clozel
051 * @since 5.0
052 */
053public class UndertowRequestUpgradeStrategy implements RequestUpgradeStrategy {
054
055        @Override
056        public Mono<Void> upgrade(ServerWebExchange exchange, WebSocketHandler handler,
057                        @Nullable String subProtocol, Supplier<HandshakeInfo> handshakeInfoFactory) {
058
059                HttpServerExchange httpExchange = getNativeRequest(exchange.getRequest());
060
061                Set<String> protocols = (subProtocol != null ? Collections.singleton(subProtocol) : Collections.emptySet());
062                Hybi13Handshake handshake = new Hybi13Handshake(protocols, false);
063                List<Handshake> handshakes = Collections.singletonList(handshake);
064
065                HandshakeInfo handshakeInfo = handshakeInfoFactory.get();
066                DataBufferFactory bufferFactory = exchange.getResponse().bufferFactory();
067
068                // Trigger WebFlux preCommit actions and upgrade
069                return exchange.getResponse().setComplete()
070                                .then(Mono.fromCallable(() -> {
071                                        DefaultCallback callback = new DefaultCallback(handshakeInfo, handler, bufferFactory);
072                                        new WebSocketProtocolHandshakeHandler(handshakes, callback).handleRequest(httpExchange);
073                                        return null;
074                                }));
075        }
076
077        private static HttpServerExchange getNativeRequest(ServerHttpRequest request) {
078                if (request instanceof AbstractServerHttpRequest) {
079                        return ((AbstractServerHttpRequest) request).getNativeRequest();
080                }
081                else if (request instanceof ServerHttpRequestDecorator) {
082                        return getNativeRequest(((ServerHttpRequestDecorator) request).getDelegate());
083                }
084                else {
085                        throw new IllegalArgumentException(
086                                        "Couldn't find HttpServerExchange in " + request.getClass().getName());
087                }
088        }
089
090
091        private class DefaultCallback implements WebSocketConnectionCallback {
092
093                private final HandshakeInfo handshakeInfo;
094
095                private final WebSocketHandler handler;
096
097                private final DataBufferFactory bufferFactory;
098
099                public DefaultCallback(HandshakeInfo handshakeInfo, WebSocketHandler handler, DataBufferFactory bufferFactory) {
100                        this.handshakeInfo = handshakeInfo;
101                        this.handler = handler;
102                        this.bufferFactory = bufferFactory;
103                }
104
105                @Override
106                public void onConnect(WebSocketHttpExchange exchange, WebSocketChannel channel) {
107                        UndertowWebSocketSession session = createSession(channel);
108                        UndertowWebSocketHandlerAdapter adapter = new UndertowWebSocketHandlerAdapter(session);
109
110                        channel.getReceiveSetter().set(adapter);
111                        channel.resumeReceives();
112
113                        this.handler.handle(session)
114                                        .checkpoint(exchange.getRequestURI() + " [UndertowRequestUpgradeStrategy]")
115                                        .subscribe(session);
116                }
117
118                private UndertowWebSocketSession createSession(WebSocketChannel channel) {
119                        return new UndertowWebSocketSession(channel, this.handshakeInfo, this.bufferFactory);
120                }
121        }
122
123}