001/*
002 * Copyright 2002-2018 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.http.server.reactive;
018
019import java.io.IOException;
020import java.net.URISyntaxException;
021import java.util.Collection;
022import java.util.concurrent.atomic.AtomicBoolean;
023
024import javax.servlet.AsyncContext;
025import javax.servlet.AsyncEvent;
026import javax.servlet.AsyncListener;
027import javax.servlet.DispatcherType;
028import javax.servlet.Servlet;
029import javax.servlet.ServletConfig;
030import javax.servlet.ServletException;
031import javax.servlet.ServletRegistration;
032import javax.servlet.ServletRequest;
033import javax.servlet.ServletResponse;
034import javax.servlet.http.HttpServlet;
035import javax.servlet.http.HttpServletRequest;
036import javax.servlet.http.HttpServletResponse;
037
038import org.apache.commons.logging.Log;
039import org.reactivestreams.Subscriber;
040import org.reactivestreams.Subscription;
041
042import org.springframework.core.io.buffer.DataBufferFactory;
043import org.springframework.core.io.buffer.DefaultDataBufferFactory;
044import org.springframework.http.HttpLogging;
045import org.springframework.http.HttpMethod;
046import org.springframework.lang.Nullable;
047import org.springframework.util.Assert;
048
049/**
050 * Adapt {@link HttpHandler} to an {@link HttpServlet} using Servlet Async support
051 * and Servlet 3.1 non-blocking I/O.
052 *
053 * @author Arjen Poutsma
054 * @author Rossen Stoyanchev
055 * @since 5.0
056 * @see org.springframework.web.server.adapter.AbstractReactiveWebInitializer
057 */
058public class ServletHttpHandlerAdapter implements Servlet {
059
060        private static final Log logger = HttpLogging.forLogName(ServletHttpHandlerAdapter.class);
061
062        private static final int DEFAULT_BUFFER_SIZE = 8192;
063
064        private static final String WRITE_ERROR_ATTRIBUTE_NAME = ServletHttpHandlerAdapter.class.getName() + ".ERROR";
065
066
067        private final HttpHandler httpHandler;
068
069        private int bufferSize = DEFAULT_BUFFER_SIZE;
070
071        @Nullable
072        private String servletPath;
073
074        private DataBufferFactory dataBufferFactory = new DefaultDataBufferFactory(false);
075
076
077        public ServletHttpHandlerAdapter(HttpHandler httpHandler) {
078                Assert.notNull(httpHandler, "HttpHandler must not be null");
079                this.httpHandler = httpHandler;
080        }
081
082
083        /**
084         * Set the size of the input buffer used for reading in bytes.
085         * <p>By default this is set to 8192.
086         */
087        public void setBufferSize(int bufferSize) {
088                Assert.isTrue(bufferSize > 0, "Buffer size must be larger than zero");
089                this.bufferSize = bufferSize;
090        }
091
092        /**
093         * Return the configured input buffer size.
094         */
095        public int getBufferSize() {
096                return this.bufferSize;
097        }
098
099        /**
100         * Return the Servlet path under which the Servlet is deployed by checking
101         * the Servlet registration from {@link #init(ServletConfig)}.
102         * @return the path, or an empty string if the Servlet is deployed without
103         * a prefix (i.e. "/" or "/*"), or {@code null} if this method is invoked
104         * before the {@link #init(ServletConfig)} Servlet container callback.
105         */
106        @Nullable
107        public String getServletPath() {
108                return this.servletPath;
109        }
110
111        public void setDataBufferFactory(DataBufferFactory dataBufferFactory) {
112                Assert.notNull(dataBufferFactory, "DataBufferFactory must not be null");
113                this.dataBufferFactory = dataBufferFactory;
114        }
115
116        public DataBufferFactory getDataBufferFactory() {
117                return this.dataBufferFactory;
118        }
119
120
121        // Servlet methods...
122
123        @Override
124        public void init(ServletConfig config) {
125                this.servletPath = getServletPath(config);
126        }
127
128        private String getServletPath(ServletConfig config) {
129                String name = config.getServletName();
130                ServletRegistration registration = config.getServletContext().getServletRegistration(name);
131                if (registration == null) {
132                        throw new IllegalStateException("ServletRegistration not found for Servlet '" + name + "'");
133                }
134
135                Collection<String> mappings = registration.getMappings();
136                if (mappings.size() == 1) {
137                        String mapping = mappings.iterator().next();
138                        if (mapping.equals("/")) {
139                                return "";
140                        }
141                        if (mapping.endsWith("/*")) {
142                                String path = mapping.substring(0, mapping.length() - 2);
143                                if (!path.isEmpty() && logger.isDebugEnabled()) {
144                                        logger.debug("Found servlet mapping prefix '" + path + "' for '" + name + "'");
145                                }
146                                return path;
147                        }
148                }
149
150                throw new IllegalArgumentException("Expected a single Servlet mapping: " +
151                                "either the default Servlet mapping (i.e. '/'), " +
152                                "or a path based mapping (e.g. '/*', '/foo/*'). " +
153                                "Actual mappings: " + mappings + " for Servlet '" + name + "'");
154        }
155
156
157        @Override
158        public void service(ServletRequest request, ServletResponse response) throws ServletException, IOException {
159                // Check for existing error attribute first
160                if (DispatcherType.ASYNC.equals(request.getDispatcherType())) {
161                        Throwable ex = (Throwable) request.getAttribute(WRITE_ERROR_ATTRIBUTE_NAME);
162                        throw new ServletException("Failed to create response content", ex);
163                }
164
165                // Start async before Read/WriteListener registration
166                AsyncContext asyncContext = request.startAsync();
167                asyncContext.setTimeout(-1);
168
169                ServletServerHttpRequest httpRequest;
170                try {
171                        httpRequest = createRequest(((HttpServletRequest) request), asyncContext);
172                }
173                catch (URISyntaxException ex) {
174                        if (logger.isDebugEnabled()) {
175                                logger.debug("Failed to get request  URL: " + ex.getMessage());
176                        }
177                        ((HttpServletResponse) response).setStatus(400);
178                        asyncContext.complete();
179                        return;
180                }
181
182                ServerHttpResponse httpResponse = createResponse(((HttpServletResponse) response), asyncContext, httpRequest);
183                if (httpRequest.getMethod() == HttpMethod.HEAD) {
184                        httpResponse = new HttpHeadResponseDecorator(httpResponse);
185                }
186
187                AtomicBoolean isCompleted = new AtomicBoolean();
188                HandlerResultAsyncListener listener = new HandlerResultAsyncListener(isCompleted, httpRequest);
189                asyncContext.addListener(listener);
190
191                HandlerResultSubscriber subscriber = new HandlerResultSubscriber(asyncContext, isCompleted, httpRequest);
192                this.httpHandler.handle(httpRequest, httpResponse).subscribe(subscriber);
193        }
194
195        protected ServletServerHttpRequest createRequest(HttpServletRequest request, AsyncContext context)
196                        throws IOException, URISyntaxException {
197
198                Assert.notNull(this.servletPath, "Servlet path is not initialized");
199                return new ServletServerHttpRequest(
200                                request, context, this.servletPath, getDataBufferFactory(), getBufferSize());
201        }
202
203        protected ServletServerHttpResponse createResponse(HttpServletResponse response,
204                        AsyncContext context, ServletServerHttpRequest request) throws IOException {
205
206                return new ServletServerHttpResponse(response, context, getDataBufferFactory(), getBufferSize(), request);
207        }
208
209        @Override
210        public String getServletInfo() {
211                return "";
212        }
213
214        @Override
215        @Nullable
216        public ServletConfig getServletConfig() {
217                return null;
218        }
219
220        @Override
221        public void destroy() {
222        }
223
224
225        /**
226         * We cannot combine ERROR_LISTENER and HandlerResultSubscriber due to:
227         * https://issues.jboss.org/browse/WFLY-8515.
228         */
229        private static void runIfAsyncNotComplete(AsyncContext asyncContext, AtomicBoolean isCompleted, Runnable task) {
230                try {
231                        if (asyncContext.getRequest().isAsyncStarted() && isCompleted.compareAndSet(false, true)) {
232                                task.run();
233                        }
234                }
235                catch (IllegalStateException ex) {
236                        // Ignore: AsyncContext recycled and should not be used
237                        // e.g. TIMEOUT_LISTENER (above) may have completed the AsyncContext
238                }
239        }
240
241
242        private static class HandlerResultAsyncListener implements AsyncListener {
243
244                private final AtomicBoolean isCompleted;
245
246                private final String logPrefix;
247
248                public HandlerResultAsyncListener(AtomicBoolean isCompleted, ServletServerHttpRequest httpRequest) {
249                        this.isCompleted = isCompleted;
250                        this.logPrefix = httpRequest.getLogPrefix();
251                }
252
253                @Override
254                public void onTimeout(AsyncEvent event) {
255                        logger.debug(this.logPrefix + "Timeout notification");
256                        AsyncContext context = event.getAsyncContext();
257                        runIfAsyncNotComplete(context, this.isCompleted, context::complete);
258                }
259
260                @Override
261                public void onError(AsyncEvent event) {
262                        Throwable ex = event.getThrowable();
263                        logger.debug(this.logPrefix + "Error notification: " + (ex != null ? ex : "<no Throwable>"));
264                        AsyncContext context = event.getAsyncContext();
265                        runIfAsyncNotComplete(context, this.isCompleted, context::complete);
266                }
267
268                @Override
269                public void onStartAsync(AsyncEvent event) {
270                        // no-op
271                }
272
273                @Override
274                public void onComplete(AsyncEvent event) {
275                        // no-op
276                }
277        }
278
279
280        private class HandlerResultSubscriber implements Subscriber<Void> {
281
282                private final AsyncContext asyncContext;
283
284                private final AtomicBoolean isCompleted;
285
286                private final String logPrefix;
287
288                public HandlerResultSubscriber(
289                                AsyncContext asyncContext, AtomicBoolean isCompleted, ServletServerHttpRequest httpRequest) {
290
291                        this.asyncContext = asyncContext;
292                        this.isCompleted = isCompleted;
293                        this.logPrefix = httpRequest.getLogPrefix();
294                }
295
296                @Override
297                public void onSubscribe(Subscription subscription) {
298                        subscription.request(Long.MAX_VALUE);
299                }
300
301                @Override
302                public void onNext(Void aVoid) {
303                        // no-op
304                }
305
306                @Override
307                public void onError(Throwable ex) {
308                        logger.trace(this.logPrefix + "Failed to complete: " + ex.getMessage());
309                        runIfAsyncNotComplete(this.asyncContext, this.isCompleted, () -> {
310                                if (this.asyncContext.getResponse().isCommitted()) {
311                                        logger.trace(this.logPrefix + "Dispatch to container, to raise the error on servlet thread");
312                                        this.asyncContext.getRequest().setAttribute(WRITE_ERROR_ATTRIBUTE_NAME, ex);
313                                        this.asyncContext.dispatch();
314                                }
315                                else {
316                                        try {
317                                                logger.trace(this.logPrefix + "Setting ServletResponse status to 500 Server Error");
318                                                this.asyncContext.getResponse().resetBuffer();
319                                                ((HttpServletResponse) this.asyncContext.getResponse()).setStatus(500);
320                                        }
321                                        finally {
322                                                this.asyncContext.complete();
323                                        }
324                                }
325                        });
326                }
327
328                @Override
329                public void onComplete() {
330                        logger.trace(this.logPrefix + "Handling completed");
331                        runIfAsyncNotComplete(this.asyncContext, this.isCompleted, this.asyncContext::complete);
332                }
333        }
334
335}