001/* 002 * Copyright 2002-2018 the original author or authors. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * https://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.springframework.http.server.reactive; 018 019import java.io.IOException; 020import java.net.URISyntaxException; 021import java.util.Collection; 022import java.util.concurrent.atomic.AtomicBoolean; 023 024import javax.servlet.AsyncContext; 025import javax.servlet.AsyncEvent; 026import javax.servlet.AsyncListener; 027import javax.servlet.DispatcherType; 028import javax.servlet.Servlet; 029import javax.servlet.ServletConfig; 030import javax.servlet.ServletException; 031import javax.servlet.ServletRegistration; 032import javax.servlet.ServletRequest; 033import javax.servlet.ServletResponse; 034import javax.servlet.http.HttpServlet; 035import javax.servlet.http.HttpServletRequest; 036import javax.servlet.http.HttpServletResponse; 037 038import org.apache.commons.logging.Log; 039import org.reactivestreams.Subscriber; 040import org.reactivestreams.Subscription; 041 042import org.springframework.core.io.buffer.DataBufferFactory; 043import org.springframework.core.io.buffer.DefaultDataBufferFactory; 044import org.springframework.http.HttpLogging; 045import org.springframework.http.HttpMethod; 046import org.springframework.lang.Nullable; 047import org.springframework.util.Assert; 048 049/** 050 * Adapt {@link HttpHandler} to an {@link HttpServlet} using Servlet Async support 051 * and Servlet 3.1 non-blocking I/O. 052 * 053 * @author Arjen Poutsma 054 * @author Rossen Stoyanchev 055 * @since 5.0 056 * @see org.springframework.web.server.adapter.AbstractReactiveWebInitializer 057 */ 058public class ServletHttpHandlerAdapter implements Servlet { 059 060 private static final Log logger = HttpLogging.forLogName(ServletHttpHandlerAdapter.class); 061 062 private static final int DEFAULT_BUFFER_SIZE = 8192; 063 064 private static final String WRITE_ERROR_ATTRIBUTE_NAME = ServletHttpHandlerAdapter.class.getName() + ".ERROR"; 065 066 067 private final HttpHandler httpHandler; 068 069 private int bufferSize = DEFAULT_BUFFER_SIZE; 070 071 @Nullable 072 private String servletPath; 073 074 private DataBufferFactory dataBufferFactory = new DefaultDataBufferFactory(false); 075 076 077 public ServletHttpHandlerAdapter(HttpHandler httpHandler) { 078 Assert.notNull(httpHandler, "HttpHandler must not be null"); 079 this.httpHandler = httpHandler; 080 } 081 082 083 /** 084 * Set the size of the input buffer used for reading in bytes. 085 * <p>By default this is set to 8192. 086 */ 087 public void setBufferSize(int bufferSize) { 088 Assert.isTrue(bufferSize > 0, "Buffer size must be larger than zero"); 089 this.bufferSize = bufferSize; 090 } 091 092 /** 093 * Return the configured input buffer size. 094 */ 095 public int getBufferSize() { 096 return this.bufferSize; 097 } 098 099 /** 100 * Return the Servlet path under which the Servlet is deployed by checking 101 * the Servlet registration from {@link #init(ServletConfig)}. 102 * @return the path, or an empty string if the Servlet is deployed without 103 * a prefix (i.e. "/" or "/*"), or {@code null} if this method is invoked 104 * before the {@link #init(ServletConfig)} Servlet container callback. 105 */ 106 @Nullable 107 public String getServletPath() { 108 return this.servletPath; 109 } 110 111 public void setDataBufferFactory(DataBufferFactory dataBufferFactory) { 112 Assert.notNull(dataBufferFactory, "DataBufferFactory must not be null"); 113 this.dataBufferFactory = dataBufferFactory; 114 } 115 116 public DataBufferFactory getDataBufferFactory() { 117 return this.dataBufferFactory; 118 } 119 120 121 // Servlet methods... 122 123 @Override 124 public void init(ServletConfig config) { 125 this.servletPath = getServletPath(config); 126 } 127 128 private String getServletPath(ServletConfig config) { 129 String name = config.getServletName(); 130 ServletRegistration registration = config.getServletContext().getServletRegistration(name); 131 if (registration == null) { 132 throw new IllegalStateException("ServletRegistration not found for Servlet '" + name + "'"); 133 } 134 135 Collection<String> mappings = registration.getMappings(); 136 if (mappings.size() == 1) { 137 String mapping = mappings.iterator().next(); 138 if (mapping.equals("/")) { 139 return ""; 140 } 141 if (mapping.endsWith("/*")) { 142 String path = mapping.substring(0, mapping.length() - 2); 143 if (!path.isEmpty() && logger.isDebugEnabled()) { 144 logger.debug("Found servlet mapping prefix '" + path + "' for '" + name + "'"); 145 } 146 return path; 147 } 148 } 149 150 throw new IllegalArgumentException("Expected a single Servlet mapping: " + 151 "either the default Servlet mapping (i.e. '/'), " + 152 "or a path based mapping (e.g. '/*', '/foo/*'). " + 153 "Actual mappings: " + mappings + " for Servlet '" + name + "'"); 154 } 155 156 157 @Override 158 public void service(ServletRequest request, ServletResponse response) throws ServletException, IOException { 159 // Check for existing error attribute first 160 if (DispatcherType.ASYNC.equals(request.getDispatcherType())) { 161 Throwable ex = (Throwable) request.getAttribute(WRITE_ERROR_ATTRIBUTE_NAME); 162 throw new ServletException("Failed to create response content", ex); 163 } 164 165 // Start async before Read/WriteListener registration 166 AsyncContext asyncContext = request.startAsync(); 167 asyncContext.setTimeout(-1); 168 169 ServletServerHttpRequest httpRequest; 170 try { 171 httpRequest = createRequest(((HttpServletRequest) request), asyncContext); 172 } 173 catch (URISyntaxException ex) { 174 if (logger.isDebugEnabled()) { 175 logger.debug("Failed to get request URL: " + ex.getMessage()); 176 } 177 ((HttpServletResponse) response).setStatus(400); 178 asyncContext.complete(); 179 return; 180 } 181 182 ServerHttpResponse httpResponse = createResponse(((HttpServletResponse) response), asyncContext, httpRequest); 183 if (httpRequest.getMethod() == HttpMethod.HEAD) { 184 httpResponse = new HttpHeadResponseDecorator(httpResponse); 185 } 186 187 AtomicBoolean isCompleted = new AtomicBoolean(); 188 HandlerResultAsyncListener listener = new HandlerResultAsyncListener(isCompleted, httpRequest); 189 asyncContext.addListener(listener); 190 191 HandlerResultSubscriber subscriber = new HandlerResultSubscriber(asyncContext, isCompleted, httpRequest); 192 this.httpHandler.handle(httpRequest, httpResponse).subscribe(subscriber); 193 } 194 195 protected ServletServerHttpRequest createRequest(HttpServletRequest request, AsyncContext context) 196 throws IOException, URISyntaxException { 197 198 Assert.notNull(this.servletPath, "Servlet path is not initialized"); 199 return new ServletServerHttpRequest( 200 request, context, this.servletPath, getDataBufferFactory(), getBufferSize()); 201 } 202 203 protected ServletServerHttpResponse createResponse(HttpServletResponse response, 204 AsyncContext context, ServletServerHttpRequest request) throws IOException { 205 206 return new ServletServerHttpResponse(response, context, getDataBufferFactory(), getBufferSize(), request); 207 } 208 209 @Override 210 public String getServletInfo() { 211 return ""; 212 } 213 214 @Override 215 @Nullable 216 public ServletConfig getServletConfig() { 217 return null; 218 } 219 220 @Override 221 public void destroy() { 222 } 223 224 225 /** 226 * We cannot combine ERROR_LISTENER and HandlerResultSubscriber due to: 227 * https://issues.jboss.org/browse/WFLY-8515. 228 */ 229 private static void runIfAsyncNotComplete(AsyncContext asyncContext, AtomicBoolean isCompleted, Runnable task) { 230 try { 231 if (asyncContext.getRequest().isAsyncStarted() && isCompleted.compareAndSet(false, true)) { 232 task.run(); 233 } 234 } 235 catch (IllegalStateException ex) { 236 // Ignore: AsyncContext recycled and should not be used 237 // e.g. TIMEOUT_LISTENER (above) may have completed the AsyncContext 238 } 239 } 240 241 242 private static class HandlerResultAsyncListener implements AsyncListener { 243 244 private final AtomicBoolean isCompleted; 245 246 private final String logPrefix; 247 248 public HandlerResultAsyncListener(AtomicBoolean isCompleted, ServletServerHttpRequest httpRequest) { 249 this.isCompleted = isCompleted; 250 this.logPrefix = httpRequest.getLogPrefix(); 251 } 252 253 @Override 254 public void onTimeout(AsyncEvent event) { 255 logger.debug(this.logPrefix + "Timeout notification"); 256 AsyncContext context = event.getAsyncContext(); 257 runIfAsyncNotComplete(context, this.isCompleted, context::complete); 258 } 259 260 @Override 261 public void onError(AsyncEvent event) { 262 Throwable ex = event.getThrowable(); 263 logger.debug(this.logPrefix + "Error notification: " + (ex != null ? ex : "<no Throwable>")); 264 AsyncContext context = event.getAsyncContext(); 265 runIfAsyncNotComplete(context, this.isCompleted, context::complete); 266 } 267 268 @Override 269 public void onStartAsync(AsyncEvent event) { 270 // no-op 271 } 272 273 @Override 274 public void onComplete(AsyncEvent event) { 275 // no-op 276 } 277 } 278 279 280 private class HandlerResultSubscriber implements Subscriber<Void> { 281 282 private final AsyncContext asyncContext; 283 284 private final AtomicBoolean isCompleted; 285 286 private final String logPrefix; 287 288 public HandlerResultSubscriber( 289 AsyncContext asyncContext, AtomicBoolean isCompleted, ServletServerHttpRequest httpRequest) { 290 291 this.asyncContext = asyncContext; 292 this.isCompleted = isCompleted; 293 this.logPrefix = httpRequest.getLogPrefix(); 294 } 295 296 @Override 297 public void onSubscribe(Subscription subscription) { 298 subscription.request(Long.MAX_VALUE); 299 } 300 301 @Override 302 public void onNext(Void aVoid) { 303 // no-op 304 } 305 306 @Override 307 public void onError(Throwable ex) { 308 logger.trace(this.logPrefix + "Failed to complete: " + ex.getMessage()); 309 runIfAsyncNotComplete(this.asyncContext, this.isCompleted, () -> { 310 if (this.asyncContext.getResponse().isCommitted()) { 311 logger.trace(this.logPrefix + "Dispatch to container, to raise the error on servlet thread"); 312 this.asyncContext.getRequest().setAttribute(WRITE_ERROR_ATTRIBUTE_NAME, ex); 313 this.asyncContext.dispatch(); 314 } 315 else { 316 try { 317 logger.trace(this.logPrefix + "Setting ServletResponse status to 500 Server Error"); 318 this.asyncContext.getResponse().resetBuffer(); 319 ((HttpServletResponse) this.asyncContext.getResponse()).setStatus(500); 320 } 321 finally { 322 this.asyncContext.complete(); 323 } 324 } 325 }); 326 } 327 328 @Override 329 public void onComplete() { 330 logger.trace(this.logPrefix + "Handling completed"); 331 runIfAsyncNotComplete(this.asyncContext, this.isCompleted, this.asyncContext::complete); 332 } 333 } 334 335}