001/*
002 * Copyright 2002-2020 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.web.reactive.socket.server.upgrade;
018
019import java.util.Collections;
020import java.util.function.Supplier;
021
022import javax.servlet.http.HttpServletRequest;
023import javax.servlet.http.HttpServletResponse;
024import javax.websocket.Endpoint;
025import javax.websocket.server.ServerContainer;
026
027import org.apache.tomcat.websocket.server.WsServerContainer;
028import reactor.core.publisher.Mono;
029
030import org.springframework.core.io.buffer.DataBufferFactory;
031import org.springframework.http.server.reactive.AbstractServerHttpRequest;
032import org.springframework.http.server.reactive.AbstractServerHttpResponse;
033import org.springframework.http.server.reactive.ServerHttpRequest;
034import org.springframework.http.server.reactive.ServerHttpRequestDecorator;
035import org.springframework.http.server.reactive.ServerHttpResponse;
036import org.springframework.http.server.reactive.ServerHttpResponseDecorator;
037import org.springframework.lang.Nullable;
038import org.springframework.util.Assert;
039import org.springframework.web.reactive.socket.HandshakeInfo;
040import org.springframework.web.reactive.socket.WebSocketHandler;
041import org.springframework.web.reactive.socket.adapter.StandardWebSocketHandlerAdapter;
042import org.springframework.web.reactive.socket.adapter.TomcatWebSocketSession;
043import org.springframework.web.reactive.socket.server.RequestUpgradeStrategy;
044import org.springframework.web.server.ServerWebExchange;
045
046/**
047 * A {@link RequestUpgradeStrategy} for use with Tomcat.
048 *
049 * @author Violeta Georgieva
050 * @author Rossen Stoyanchev
051 * @since 5.0
052 */
053public class TomcatRequestUpgradeStrategy implements RequestUpgradeStrategy {
054
055        private static final String SERVER_CONTAINER_ATTR = "javax.websocket.server.ServerContainer";
056
057
058        @Nullable
059        private Long asyncSendTimeout;
060
061        @Nullable
062        private Long maxSessionIdleTimeout;
063
064        @Nullable
065        private Integer maxTextMessageBufferSize;
066
067        @Nullable
068        private Integer maxBinaryMessageBufferSize;
069
070        @Nullable
071        private WsServerContainer serverContainer;
072
073
074        /**
075         * Exposes the underlying config option on
076         * {@link javax.websocket.server.ServerContainer#setAsyncSendTimeout(long)}.
077         */
078        public void setAsyncSendTimeout(Long timeoutInMillis) {
079                this.asyncSendTimeout = timeoutInMillis;
080        }
081
082        @Nullable
083        public Long getAsyncSendTimeout() {
084                return this.asyncSendTimeout;
085        }
086
087        /**
088         * Exposes the underlying config option on
089         * {@link javax.websocket.server.ServerContainer#setDefaultMaxSessionIdleTimeout(long)}.
090         */
091        public void setMaxSessionIdleTimeout(Long timeoutInMillis) {
092                this.maxSessionIdleTimeout = timeoutInMillis;
093        }
094
095        @Nullable
096        public Long getMaxSessionIdleTimeout() {
097                return this.maxSessionIdleTimeout;
098        }
099
100        /**
101         * Exposes the underlying config option on
102         * {@link javax.websocket.server.ServerContainer#setDefaultMaxTextMessageBufferSize(int)}.
103         */
104        public void setMaxTextMessageBufferSize(Integer bufferSize) {
105                this.maxTextMessageBufferSize = bufferSize;
106        }
107
108        @Nullable
109        public Integer getMaxTextMessageBufferSize() {
110                return this.maxTextMessageBufferSize;
111        }
112
113        /**
114         * Exposes the underlying config option on
115         * {@link javax.websocket.server.ServerContainer#setDefaultMaxBinaryMessageBufferSize(int)}.
116         */
117        public void setMaxBinaryMessageBufferSize(Integer bufferSize) {
118                this.maxBinaryMessageBufferSize = bufferSize;
119        }
120
121        @Nullable
122        public Integer getMaxBinaryMessageBufferSize() {
123                return this.maxBinaryMessageBufferSize;
124        }
125
126
127        @Override
128        public Mono<Void> upgrade(ServerWebExchange exchange, WebSocketHandler handler,
129                        @Nullable String subProtocol, Supplier<HandshakeInfo> handshakeInfoFactory){
130
131                ServerHttpRequest request = exchange.getRequest();
132                ServerHttpResponse response = exchange.getResponse();
133
134                HttpServletRequest servletRequest = getNativeRequest(request);
135                HttpServletResponse servletResponse = getNativeResponse(response);
136
137                HandshakeInfo handshakeInfo = handshakeInfoFactory.get();
138                DataBufferFactory bufferFactory = response.bufferFactory();
139
140                Endpoint endpoint = new StandardWebSocketHandlerAdapter(
141                                handler, session -> new TomcatWebSocketSession(session, handshakeInfo, bufferFactory));
142
143                String requestURI = servletRequest.getRequestURI();
144                DefaultServerEndpointConfig config = new DefaultServerEndpointConfig(requestURI, endpoint);
145                config.setSubprotocols(subProtocol != null ?
146                                Collections.singletonList(subProtocol) : Collections.emptyList());
147
148                // Trigger WebFlux preCommit actions and upgrade
149                return exchange.getResponse().setComplete()
150                                .then(Mono.fromCallable(() -> {
151                                        WsServerContainer container = getContainer(servletRequest);
152                                        container.doUpgrade(servletRequest, servletResponse, config, Collections.emptyMap());
153                                        return null;
154                                }));
155        }
156
157        private static HttpServletRequest getNativeRequest(ServerHttpRequest request) {
158                if (request instanceof AbstractServerHttpRequest) {
159                        return ((AbstractServerHttpRequest) request).getNativeRequest();
160                }
161                else if (request instanceof ServerHttpRequestDecorator) {
162                        return getNativeRequest(((ServerHttpRequestDecorator) request).getDelegate());
163                }
164                else {
165                        throw new IllegalArgumentException(
166                                        "Couldn't find HttpServletRequest in " + request.getClass().getName());
167                }
168        }
169
170        private static HttpServletResponse getNativeResponse(ServerHttpResponse response) {
171                if (response instanceof AbstractServerHttpResponse) {
172                        return ((AbstractServerHttpResponse) response).getNativeResponse();
173                }
174                else if (response instanceof ServerHttpResponseDecorator) {
175                        return getNativeResponse(((ServerHttpResponseDecorator) response).getDelegate());
176                }
177                else {
178                        throw new IllegalArgumentException(
179                                        "Couldn't find HttpServletResponse in " + response.getClass().getName());
180                }
181        }
182
183        private WsServerContainer getContainer(HttpServletRequest request) {
184                if (this.serverContainer == null) {
185                        Object container = request.getServletContext().getAttribute(SERVER_CONTAINER_ATTR);
186                        Assert.state(container instanceof WsServerContainer,
187                                        "ServletContext attribute 'javax.websocket.server.ServerContainer' not found.");
188                        this.serverContainer = (WsServerContainer) container;
189                        initServerContainer(this.serverContainer);
190                }
191                return this.serverContainer;
192        }
193
194        private void initServerContainer(ServerContainer serverContainer) {
195                if (this.asyncSendTimeout != null) {
196                        serverContainer.setAsyncSendTimeout(this.asyncSendTimeout);
197                }
198                if (this.maxSessionIdleTimeout != null) {
199                        serverContainer.setDefaultMaxSessionIdleTimeout(this.maxSessionIdleTimeout);
200                }
201                if (this.maxTextMessageBufferSize != null) {
202                        serverContainer.setDefaultMaxTextMessageBufferSize(this.maxTextMessageBufferSize);
203                }
204                if (this.maxBinaryMessageBufferSize != null) {
205                        serverContainer.setDefaultMaxBinaryMessageBufferSize(this.maxBinaryMessageBufferSize);
206                }
207        }
208
209}