001/*
002 * Copyright 2012-2018 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.boot.autoconfigure.web.reactive.error;
018
019import java.util.Collections;
020import java.util.EnumMap;
021import java.util.List;
022import java.util.Map;
023
024import org.apache.commons.logging.Log;
025import reactor.core.publisher.Flux;
026import reactor.core.publisher.Mono;
027
028import org.springframework.boot.autoconfigure.web.ErrorProperties;
029import org.springframework.boot.autoconfigure.web.ResourceProperties;
030import org.springframework.boot.web.reactive.error.ErrorAttributes;
031import org.springframework.context.ApplicationContext;
032import org.springframework.http.HttpLogging;
033import org.springframework.http.HttpStatus;
034import org.springframework.http.InvalidMediaTypeException;
035import org.springframework.http.MediaType;
036import org.springframework.web.reactive.function.BodyInserters;
037import org.springframework.web.reactive.function.server.RequestPredicate;
038import org.springframework.web.reactive.function.server.RouterFunction;
039import org.springframework.web.reactive.function.server.ServerRequest;
040import org.springframework.web.reactive.function.server.ServerResponse;
041
042import static org.springframework.web.reactive.function.server.RequestPredicates.all;
043import static org.springframework.web.reactive.function.server.RouterFunctions.route;
044
045/**
046 * Basic global {@link org.springframework.web.server.WebExceptionHandler}, rendering
047 * {@link ErrorAttributes}.
048 * <p>
049 * More specific errors can be handled either using Spring WebFlux abstractions (e.g.
050 * {@code @ExceptionHandler} with the annotation model) or by adding
051 * {@link RouterFunction} to the chain.
052 * <p>
053 * This implementation will render error as HTML views if the client explicitly supports
054 * that media type. It attempts to resolve error views using well known conventions. Will
055 * search for templates and static assets under {@code '/error'} using the
056 * {@link HttpStatus status code} and the {@link HttpStatus#series() status series}.
057 * <p>
058 * For example, an {@code HTTP 404} will search (in the specific order):
059 * <ul>
060 * <li>{@code '/<templates>/error/404.<ext>'}</li>
061 * <li>{@code '/<static>/error/404.html'}</li>
062 * <li>{@code '/<templates>/error/4xx.<ext>'}</li>
063 * <li>{@code '/<static>/error/4xx.html'}</li>
064 * <li>{@code '/<templates>/error/error'}</li>
065 * <li>{@code '/<static>/error/error.html'}</li>
066 * </ul>
067 * <p>
068 * If none found, a default "Whitelabel Error" HTML view will be rendered.
069 * <p>
070 * If the client doesn't support HTML, the error information will be rendered as a JSON
071 * payload.
072 *
073 * @author Brian Clozel
074 * @since 2.0.0
075 */
076public class DefaultErrorWebExceptionHandler extends AbstractErrorWebExceptionHandler {
077
078        private static final Map<HttpStatus.Series, String> SERIES_VIEWS;
079
080        private static final Log logger = HttpLogging
081                        .forLogName(DefaultErrorWebExceptionHandler.class);
082
083        static {
084                Map<HttpStatus.Series, String> views = new EnumMap<>(HttpStatus.Series.class);
085                views.put(HttpStatus.Series.CLIENT_ERROR, "4xx");
086                views.put(HttpStatus.Series.SERVER_ERROR, "5xx");
087                SERIES_VIEWS = Collections.unmodifiableMap(views);
088        }
089
090        private final ErrorProperties errorProperties;
091
092        /**
093         * Create a new {@code DefaultErrorWebExceptionHandler} instance.
094         * @param errorAttributes the error attributes
095         * @param resourceProperties the resources configuration properties
096         * @param errorProperties the error configuration properties
097         * @param applicationContext the current application context
098         */
099        public DefaultErrorWebExceptionHandler(ErrorAttributes errorAttributes,
100                        ResourceProperties resourceProperties, ErrorProperties errorProperties,
101                        ApplicationContext applicationContext) {
102                super(errorAttributes, resourceProperties, applicationContext);
103                this.errorProperties = errorProperties;
104        }
105
106        @Override
107        protected RouterFunction<ServerResponse> getRoutingFunction(
108                        ErrorAttributes errorAttributes) {
109                return route(acceptsTextHtml(), this::renderErrorView).andRoute(all(),
110                                this::renderErrorResponse);
111        }
112
113        /**
114         * Render the error information as an HTML view.
115         * @param request the current request
116         * @return a {@code Publisher} of the HTTP response
117         */
118        protected Mono<ServerResponse> renderErrorView(ServerRequest request) {
119                boolean includeStackTrace = isIncludeStackTrace(request, MediaType.TEXT_HTML);
120                Map<String, Object> error = getErrorAttributes(request, includeStackTrace);
121                HttpStatus errorStatus = getHttpStatus(error);
122                ServerResponse.BodyBuilder responseBody = ServerResponse.status(errorStatus)
123                                .contentType(MediaType.TEXT_HTML);
124                return Flux
125                                .just("error/" + errorStatus.value(),
126                                                "error/" + SERIES_VIEWS.get(errorStatus.series()), "error/error")
127                                .flatMap((viewName) -> renderErrorView(viewName, responseBody, error))
128                                .switchIfEmpty(this.errorProperties.getWhitelabel().isEnabled()
129                                                ? renderDefaultErrorView(responseBody, error)
130                                                : Mono.error(getError(request)))
131                                .next().doOnNext((response) -> logError(request, errorStatus));
132        }
133
134        /**
135         * Render the error information as a JSON payload.
136         * @param request the current request
137         * @return a {@code Publisher} of the HTTP response
138         */
139        protected Mono<ServerResponse> renderErrorResponse(ServerRequest request) {
140                boolean includeStackTrace = isIncludeStackTrace(request, MediaType.ALL);
141                Map<String, Object> error = getErrorAttributes(request, includeStackTrace);
142                HttpStatus errorStatus = getHttpStatus(error);
143                return ServerResponse.status(getHttpStatus(error))
144                                .contentType(MediaType.APPLICATION_JSON_UTF8)
145                                .body(BodyInserters.fromObject(error))
146                                .doOnNext((resp) -> logError(request, errorStatus));
147        }
148
149        /**
150         * Determine if the stacktrace attribute should be included.
151         * @param request the source request
152         * @param produces the media type produced (or {@code MediaType.ALL})
153         * @return if the stacktrace attribute should be included
154         */
155        protected boolean isIncludeStackTrace(ServerRequest request, MediaType produces) {
156                ErrorProperties.IncludeStacktrace include = this.errorProperties
157                                .getIncludeStacktrace();
158                if (include == ErrorProperties.IncludeStacktrace.ALWAYS) {
159                        return true;
160                }
161                if (include == ErrorProperties.IncludeStacktrace.ON_TRACE_PARAM) {
162                        return isTraceEnabled(request);
163                }
164                return false;
165        }
166
167        /**
168         * Get the HTTP error status information from the error map.
169         * @param errorAttributes the current error information
170         * @return the error HTTP status
171         */
172        protected HttpStatus getHttpStatus(Map<String, Object> errorAttributes) {
173                int statusCode = (int) errorAttributes.get("status");
174                return HttpStatus.valueOf(statusCode);
175        }
176
177        /**
178         * Predicate that checks whether the current request explicitly support
179         * {@code "text/html"} media type.
180         * <p>
181         * The "match-all" media type is not considered here.
182         * @return the request predicate
183         */
184        protected RequestPredicate acceptsTextHtml() {
185                return (serverRequest) -> {
186                        try {
187                                List<MediaType> acceptedMediaTypes = serverRequest.headers().accept();
188                                acceptedMediaTypes.remove(MediaType.ALL);
189                                MediaType.sortBySpecificityAndQuality(acceptedMediaTypes);
190                                return acceptedMediaTypes.stream()
191                                                .anyMatch(MediaType.TEXT_HTML::isCompatibleWith);
192                        }
193                        catch (InvalidMediaTypeException ex) {
194                                return false;
195                        }
196                };
197        }
198
199        /**
200         * Log the original exception if handling it results in a Server Error or a Bad
201         * Request (Client Error with 400 status code) one.
202         * @param request the source request
203         * @param errorStatus the HTTP error status
204         */
205        protected void logError(ServerRequest request, HttpStatus errorStatus) {
206                Throwable ex = getError(request);
207                if (logger.isDebugEnabled()) {
208                        logger.debug(request.exchange().getLogPrefix() + formatError(ex, request));
209                }
210        }
211
212        private String formatError(Throwable ex, ServerRequest request) {
213                String reason = ex.getClass().getSimpleName() + ": " + ex.getMessage();
214                return "Resolved [" + reason + "] for HTTP " + request.methodName() + " "
215                                + request.path();
216        }
217
218}