001/* 002 * Copyright 2002-2020 the original author or authors. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * https://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.springframework.web.reactive.socket.server.upgrade; 018 019import java.util.Collections; 020import java.util.function.Supplier; 021 022import javax.servlet.http.HttpServletRequest; 023import javax.servlet.http.HttpServletResponse; 024import javax.websocket.Endpoint; 025import javax.websocket.server.ServerContainer; 026 027import org.apache.tomcat.websocket.server.WsServerContainer; 028import reactor.core.publisher.Mono; 029 030import org.springframework.core.io.buffer.DataBufferFactory; 031import org.springframework.http.server.reactive.AbstractServerHttpRequest; 032import org.springframework.http.server.reactive.AbstractServerHttpResponse; 033import org.springframework.http.server.reactive.ServerHttpRequest; 034import org.springframework.http.server.reactive.ServerHttpRequestDecorator; 035import org.springframework.http.server.reactive.ServerHttpResponse; 036import org.springframework.http.server.reactive.ServerHttpResponseDecorator; 037import org.springframework.lang.Nullable; 038import org.springframework.util.Assert; 039import org.springframework.web.reactive.socket.HandshakeInfo; 040import org.springframework.web.reactive.socket.WebSocketHandler; 041import org.springframework.web.reactive.socket.adapter.StandardWebSocketHandlerAdapter; 042import 
org.springframework.web.reactive.socket.adapter.TomcatWebSocketSession;
import org.springframework.web.reactive.socket.server.RequestUpgradeStrategy;
import org.springframework.web.server.ServerWebExchange;

/**
 * A {@link RequestUpgradeStrategy} for use with Tomcat.
 *
 * <p>The Tomcat {@link WsServerContainer} is looked up from the
 * {@code ServletContext} on the first upgrade, at which point any options
 * configured on this strategy are applied to it.
 *
 * @author Violeta Georgieva
 * @author Rossen Stoyanchev
 * @since 5.0
 */
public class TomcatRequestUpgradeStrategy implements RequestUpgradeStrategy {

	private static final String SERVER_CONTAINER_ATTR = "javax.websocket.server.ServerContainer";


	@Nullable
	private Long asyncSendTimeout;

	@Nullable
	private Long maxSessionIdleTimeout;

	@Nullable
	private Integer maxTextMessageBufferSize;

	@Nullable
	private Integer maxBinaryMessageBufferSize;

	/**
	 * The Tomcat container, resolved and configured on first use.
	 * <p>{@code volatile} because upgrades may run on different threads. The field
	 * is only assigned after {@link #initServerContainer} has applied the configured
	 * options, so a thread that sees a non-null value sees a configured container.
	 */
	@Nullable
	private volatile WsServerContainer serverContainer;


	/**
	 * Exposes the underlying config option on
	 * {@link javax.websocket.server.ServerContainer#setAsyncSendTimeout(long)}.
	 * @param timeoutInMillis the timeout in milliseconds
	 */
	public void setAsyncSendTimeout(Long timeoutInMillis) {
		this.asyncSendTimeout = timeoutInMillis;
	}

	/**
	 * Return the configured async send timeout, or {@code null} if not set.
	 */
	@Nullable
	public Long getAsyncSendTimeout() {
		return this.asyncSendTimeout;
	}

	/**
	 * Exposes the underlying config option on
	 * {@link javax.websocket.server.ServerContainer#setDefaultMaxSessionIdleTimeout(long)}.
	 * @param timeoutInMillis the timeout in milliseconds
	 */
	public void setMaxSessionIdleTimeout(Long timeoutInMillis) {
		this.maxSessionIdleTimeout = timeoutInMillis;
	}

	/**
	 * Return the configured max session idle timeout, or {@code null} if not set.
	 */
	@Nullable
	public Long getMaxSessionIdleTimeout() {
		return this.maxSessionIdleTimeout;
	}

	/**
	 * Exposes the underlying config option on
	 * {@link javax.websocket.server.ServerContainer#setDefaultMaxTextMessageBufferSize(int)}.
	 * @param bufferSize the buffer size
	 */
	public void setMaxTextMessageBufferSize(Integer bufferSize) {
		this.maxTextMessageBufferSize = bufferSize;
	}

	/**
	 * Return the configured max text message buffer size, or {@code null} if not set.
	 */
	@Nullable
	public Integer getMaxTextMessageBufferSize() {
		return this.maxTextMessageBufferSize;
	}

	/**
	 * Exposes the underlying config option on
	 * {@link javax.websocket.server.ServerContainer#setDefaultMaxBinaryMessageBufferSize(int)}.
	 * @param bufferSize the buffer size
	 */
	public void setMaxBinaryMessageBufferSize(Integer bufferSize) {
		this.maxBinaryMessageBufferSize = bufferSize;
	}

	/**
	 * Return the configured max binary message buffer size, or {@code null} if not set.
	 */
	@Nullable
	public Integer getMaxBinaryMessageBufferSize() {
		return this.maxBinaryMessageBufferSize;
	}


	/**
	 * Upgrade the request to a WebSocket session through Tomcat's {@link WsServerContainer}.
	 * <p>The WebFlux response is completed first, so that its pre-commit actions run.
	 * Only then is the native servlet upgrade performed.
	 * @param exchange the current exchange; its request and response must wrap,
	 * possibly through decorators, the native servlet request and response
	 * @param handler the handler for the WebSocket session
	 * @param subProtocol the selected sub-protocol, or {@code null} if none
	 * @param handshakeInfoFactory supplier of the handshake information for the session
	 * @return a {@code Mono} that completes once the native upgrade has been performed
	 * @throws IllegalArgumentException if the native servlet request or response cannot be found
	 */
	@Override
	public Mono<Void> upgrade(ServerWebExchange exchange, WebSocketHandler handler,
			@Nullable String subProtocol, Supplier<HandshakeInfo> handshakeInfoFactory) {

		ServerHttpRequest request = exchange.getRequest();
		ServerHttpResponse response = exchange.getResponse();

		HttpServletRequest servletRequest = getNativeRequest(request);
		HttpServletResponse servletResponse = getNativeResponse(response);

		HandshakeInfo handshakeInfo = handshakeInfoFactory.get();
		DataBufferFactory bufferFactory = response.bufferFactory();

		Endpoint endpoint = new StandardWebSocketHandlerAdapter(
				handler, session -> new TomcatWebSocketSession(session, handshakeInfo, bufferFactory));

		String requestURI = servletRequest.getRequestURI();
		DefaultServerEndpointConfig config = new DefaultServerEndpointConfig(requestURI, endpoint);
		config.setSubprotocols(subProtocol != null ?
				Collections.singletonList(subProtocol) : Collections.emptyList());

		// Trigger WebFlux preCommit actions and upgrade
		return response.setComplete()
				.then(Mono.fromCallable(() -> {
					WsServerContainer container = getContainer(servletRequest);
					container.doUpgrade(servletRequest, servletResponse, config, Collections.emptyMap());
					// A null result from fromCallable yields an empty Mono, i.e. plain completion
					return null;
				}));
	}

	/**
	 * Unwrap decorators until the native {@link HttpServletRequest} is reached.
	 * @throws IllegalArgumentException if no native request can be found
	 */
	private static HttpServletRequest getNativeRequest(ServerHttpRequest request) {
		if (request instanceof AbstractServerHttpRequest) {
			return ((AbstractServerHttpRequest) request).getNativeRequest();
		}
		else if (request instanceof ServerHttpRequestDecorator) {
			return getNativeRequest(((ServerHttpRequestDecorator) request).getDelegate());
		}
		else {
			throw new IllegalArgumentException(
					"Couldn't find HttpServletRequest in " + request.getClass().getName());
		}
	}

	/**
	 * Unwrap decorators until the native {@link HttpServletResponse} is reached.
	 * @throws IllegalArgumentException if no native response can be found
	 */
	private static HttpServletResponse getNativeResponse(ServerHttpResponse response) {
		if (response instanceof AbstractServerHttpResponse) {
			return ((AbstractServerHttpResponse) response).getNativeResponse();
		}
		else if (response instanceof ServerHttpResponseDecorator) {
			return getNativeResponse(((ServerHttpResponseDecorator) response).getDelegate());
		}
		else {
			throw new IllegalArgumentException(
					"Couldn't find HttpServletResponse in " + response.getClass().getName());
		}
	}

	/**
	 * Return the Tomcat container, resolving and configuring it on first use.
	 * <p>Concurrent first-time callers may each configure the container. This only
	 * repeats the same setter calls with the same values.
	 * @throws IllegalStateException if the ServletContext attribute is missing
	 * or is not a {@link WsServerContainer}
	 */
	private WsServerContainer getContainer(HttpServletRequest request) {
		WsServerContainer container = this.serverContainer;
		if (container == null) {
			Object attribute = request.getServletContext().getAttribute(SERVER_CONTAINER_ATTR);
			Assert.state(attribute instanceof WsServerContainer,
					"ServletContext attribute 'javax.websocket.server.ServerContainer' not found.");
			container = (WsServerContainer) attribute;
			// Configure before publishing, so no other thread upgrades with an unconfigured container
			initServerContainer(container);
			this.serverContainer = container;
		}
		return container;
	}

	/**
	 * Apply every option that was explicitly set on this strategy to the given container.
	 * Options left {@code null} are not applied.
	 */
	private void initServerContainer(ServerContainer serverContainer) {
		if (this.asyncSendTimeout != null) {
			serverContainer.setAsyncSendTimeout(this.asyncSendTimeout);
		}
		if (this.maxSessionIdleTimeout != null) {
			serverContainer.setDefaultMaxSessionIdleTimeout(this.maxSessionIdleTimeout);
		}
		if (this.maxTextMessageBufferSize != null) {
			serverContainer.setDefaultMaxTextMessageBufferSize(this.maxTextMessageBufferSize);
		}
		if (this.maxBinaryMessageBufferSize != null) {
			serverContainer.setDefaultMaxBinaryMessageBufferSize(this.maxBinaryMessageBufferSize);
		}
	}

}