001/*
002 * Copyright 2012-2018 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.boot.web.servlet.support;
018
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.Map;
import java.util.Set;

import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletOutputStream;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpServletResponseWrapper;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import org.springframework.boot.web.server.ErrorPage;
import org.springframework.boot.web.server.ErrorPageRegistrar;
import org.springframework.boot.web.server.ErrorPageRegistry;
import org.springframework.core.Ordered;
import org.springframework.core.annotation.Order;
import org.springframework.util.ClassUtils;
import org.springframework.web.filter.OncePerRequestFilter;
import org.springframework.web.util.NestedServletException;
050
051/**
052 * A Servlet {@link Filter} that provides an {@link ErrorPageRegistry} for non-embedded
053 * applications (i.e. deployed WAR files). It registers error pages and handles
054 * application errors by filtering requests and forwarding to the error pages instead of
055 * letting the server handle them. Error pages are a feature of the servlet spec but there
056 * is no Java API for registering them in the spec. This filter works around that by
057 * accepting error page registrations from Spring Boot's {@link ErrorPageRegistrar} (any
058 * beans of that type in the context will be applied to this server).
059 *
060 * @author Dave Syer
061 * @author Phillip Webb
062 * @author Andy Wilkinson
063 * @since 2.0.0
064 */
065@Order(Ordered.HIGHEST_PRECEDENCE + 1)
066public class ErrorPageFilter implements Filter, ErrorPageRegistry {
067
068        private static final Log logger = LogFactory.getLog(ErrorPageFilter.class);
069
070        // From RequestDispatcher but not referenced to remain compatible with Servlet 2.5
071
072        private static final String ERROR_EXCEPTION = "javax.servlet.error.exception";
073
074        private static final String ERROR_EXCEPTION_TYPE = "javax.servlet.error.exception_type";
075
076        private static final String ERROR_MESSAGE = "javax.servlet.error.message";
077
078        /**
079         * The name of the servlet attribute containing request URI.
080         */
081        public static final String ERROR_REQUEST_URI = "javax.servlet.error.request_uri";
082
083        private static final String ERROR_STATUS_CODE = "javax.servlet.error.status_code";
084
085        private static final Set<Class<?>> CLIENT_ABORT_EXCEPTIONS;
086        static {
087                Set<Class<?>> clientAbortExceptions = new HashSet<>();
088                addClassIfPresent(clientAbortExceptions,
089                                "org.apache.catalina.connector.ClientAbortException");
090                CLIENT_ABORT_EXCEPTIONS = Collections.unmodifiableSet(clientAbortExceptions);
091        }
092
093        private String global;
094
095        private final Map<Integer, String> statuses = new HashMap<>();
096
097        private final Map<Class<?>, String> exceptions = new HashMap<>();
098
099        private final OncePerRequestFilter delegate = new OncePerRequestFilter() {
100
101                @Override
102                protected void doFilterInternal(HttpServletRequest request,
103                                HttpServletResponse response, FilterChain chain)
104                                throws ServletException, IOException {
105                        ErrorPageFilter.this.doFilter(request, response, chain);
106                }
107
108                @Override
109                protected boolean shouldNotFilterAsyncDispatch() {
110                        return false;
111                }
112
113        };
114
115        @Override
116        public void init(FilterConfig filterConfig) throws ServletException {
117                this.delegate.init(filterConfig);
118        }
119
120        @Override
121        public void doFilter(ServletRequest request, ServletResponse response,
122                        FilterChain chain) throws IOException, ServletException {
123                this.delegate.doFilter(request, response, chain);
124        }
125
126        private void doFilter(HttpServletRequest request, HttpServletResponse response,
127                        FilterChain chain) throws IOException, ServletException {
128                ErrorWrapperResponse wrapped = new ErrorWrapperResponse(response);
129                try {
130                        chain.doFilter(request, wrapped);
131                        if (wrapped.hasErrorToSend()) {
132                                handleErrorStatus(request, response, wrapped.getStatus(),
133                                                wrapped.getMessage());
134                                response.flushBuffer();
135                        }
136                        else if (!request.isAsyncStarted() && !response.isCommitted()) {
137                                response.flushBuffer();
138                        }
139                }
140                catch (Throwable ex) {
141                        Throwable exceptionToHandle = ex;
142                        if (ex instanceof NestedServletException) {
143                                exceptionToHandle = ((NestedServletException) ex).getRootCause();
144                        }
145                        handleException(request, response, wrapped, exceptionToHandle);
146                        response.flushBuffer();
147                }
148        }
149
150        private void handleErrorStatus(HttpServletRequest request,
151                        HttpServletResponse response, int status, String message)
152                        throws ServletException, IOException {
153                if (response.isCommitted()) {
154                        handleCommittedResponse(request, null);
155                        return;
156                }
157                String errorPath = getErrorPath(this.statuses, status);
158                if (errorPath == null) {
159                        response.sendError(status, message);
160                        return;
161                }
162                response.setStatus(status);
163                setErrorAttributes(request, status, message);
164                request.getRequestDispatcher(errorPath).forward(request, response);
165        }
166
167        private void handleException(HttpServletRequest request, HttpServletResponse response,
168                        ErrorWrapperResponse wrapped, Throwable ex)
169                        throws IOException, ServletException {
170                Class<?> type = ex.getClass();
171                String errorPath = getErrorPath(type);
172                if (errorPath == null) {
173                        rethrow(ex);
174                        return;
175                }
176                if (response.isCommitted()) {
177                        handleCommittedResponse(request, ex);
178                        return;
179                }
180                forwardToErrorPage(errorPath, request, wrapped, ex);
181        }
182
183        private void forwardToErrorPage(String path, HttpServletRequest request,
184                        HttpServletResponse response, Throwable ex)
185                        throws ServletException, IOException {
186                if (logger.isErrorEnabled()) {
187                        String message = "Forwarding to error page from request "
188                                        + getDescription(request) + " due to exception [" + ex.getMessage()
189                                        + "]";
190                        logger.error(message, ex);
191                }
192                setErrorAttributes(request, 500, ex.getMessage());
193                request.setAttribute(ERROR_EXCEPTION, ex);
194                request.setAttribute(ERROR_EXCEPTION_TYPE, ex.getClass());
195                response.reset();
196                response.setStatus(500);
197                request.getRequestDispatcher(path).forward(request, response);
198                request.removeAttribute(ERROR_EXCEPTION);
199                request.removeAttribute(ERROR_EXCEPTION_TYPE);
200        }
201
202        /**
203         * Return the description for the given request. By default this method will return a
204         * description based on the request {@code servletPath} and {@code pathInfo}.
205         * @param request the source request
206         * @return the description
207         * @since 1.5.0
208         */
209        protected String getDescription(HttpServletRequest request) {
210                String pathInfo = (request.getPathInfo() != null) ? request.getPathInfo() : "";
211                return "[" + request.getServletPath() + pathInfo + "]";
212        }
213
214        private void handleCommittedResponse(HttpServletRequest request, Throwable ex) {
215                if (isClientAbortException(ex)) {
216                        return;
217                }
218                String message = "Cannot forward to error page for request "
219                                + getDescription(request) + " as the response has already been"
220                                + " committed. As a result, the response may have the wrong status"
221                                + " code. If your application is running on WebSphere Application"
222                                + " Server you may be able to resolve this problem by setting"
223                                + " com.ibm.ws.webcontainer.invokeFlushAfterService to false";
224                if (ex == null) {
225                        logger.error(message);
226                }
227                else {
228                        // User might see the error page without all the data here but throwing the
229                        // exception isn't going to help anyone (we'll log it to be on the safe side)
230                        logger.error(message, ex);
231                }
232        }
233
234        private boolean isClientAbortException(Throwable ex) {
235                if (ex == null) {
236                        return false;
237                }
238                for (Class<?> candidate : CLIENT_ABORT_EXCEPTIONS) {
239                        if (candidate.isInstance(ex)) {
240                                return true;
241                        }
242                }
243                return isClientAbortException(ex.getCause());
244        }
245
246        private String getErrorPath(Map<Integer, String> map, Integer status) {
247                if (map.containsKey(status)) {
248                        return map.get(status);
249                }
250                return this.global;
251        }
252
253        private String getErrorPath(Class<?> type) {
254                while (type != Object.class) {
255                        String path = this.exceptions.get(type);
256                        if (path != null) {
257                                return path;
258                        }
259                        type = type.getSuperclass();
260                }
261                return this.global;
262        }
263
264        private void setErrorAttributes(HttpServletRequest request, int status,
265                        String message) {
266                request.setAttribute(ERROR_STATUS_CODE, status);
267                request.setAttribute(ERROR_MESSAGE, message);
268                request.setAttribute(ERROR_REQUEST_URI, request.getRequestURI());
269        }
270
271        private void rethrow(Throwable ex) throws IOException, ServletException {
272                if (ex instanceof RuntimeException) {
273                        throw (RuntimeException) ex;
274                }
275                if (ex instanceof Error) {
276                        throw (Error) ex;
277                }
278                if (ex instanceof IOException) {
279                        throw (IOException) ex;
280                }
281                if (ex instanceof ServletException) {
282                        throw (ServletException) ex;
283                }
284                throw new IllegalStateException(ex);
285        }
286
287        @Override
288        public void addErrorPages(ErrorPage... errorPages) {
289                for (ErrorPage errorPage : errorPages) {
290                        if (errorPage.isGlobal()) {
291                                this.global = errorPage.getPath();
292                        }
293                        else if (errorPage.getStatus() != null) {
294                                this.statuses.put(errorPage.getStatus().value(), errorPage.getPath());
295                        }
296                        else {
297                                this.exceptions.put(errorPage.getException(), errorPage.getPath());
298                        }
299                }
300        }
301
302        @Override
303        public void destroy() {
304        }
305
306        private static void addClassIfPresent(Collection<Class<?>> collection,
307                        String className) {
308                try {
309                        collection.add(ClassUtils.forName(className, null));
310                }
311                catch (Throwable ex) {
312                }
313        }
314
315        private static class ErrorWrapperResponse extends HttpServletResponseWrapper {
316
317                private int status;
318
319                private String message;
320
321                private boolean hasErrorToSend = false;
322
323                ErrorWrapperResponse(HttpServletResponse response) {
324                        super(response);
325                }
326
327                @Override
328                public void sendError(int status) throws IOException {
329                        sendError(status, null);
330                }
331
332                @Override
333                public void sendError(int status, String message) throws IOException {
334                        this.status = status;
335                        this.message = message;
336                        this.hasErrorToSend = true;
337                        // Do not call super because the container may prevent us from handling the
338                        // error ourselves
339                }
340
341                @Override
342                public int getStatus() {
343                        if (this.hasErrorToSend) {
344                                return this.status;
345                        }
346                        // If there was no error we need to trust the wrapped response
347                        return super.getStatus();
348                }
349
350                @Override
351                public void flushBuffer() throws IOException {
352                        sendErrorIfNecessary();
353                        super.flushBuffer();
354                }
355
356                private void sendErrorIfNecessary() throws IOException {
357                        if (this.hasErrorToSend && !isCommitted()) {
358                                ((HttpServletResponse) getResponse()).sendError(this.status,
359                                                this.message);
360                        }
361                }
362
363                public String getMessage() {
364                        return this.message;
365                }
366
367                public boolean hasErrorToSend() {
368                        return this.hasErrorToSend;
369                }
370
371                @Override
372                public PrintWriter getWriter() throws IOException {
373                        sendErrorIfNecessary();
374                        return super.getWriter();
375
376                }
377
378                @Override
379                public ServletOutputStream getOutputStream() throws IOException {
380                        sendErrorIfNecessary();
381                        return super.getOutputStream();
382                }
383
384        }
385
386}