001/*
002 * Copyright 2012-2018 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.boot.actuate.web.trace.reactive;
018
019import java.security.Principal;
020import java.util.Set;
021
022import reactor.core.publisher.Mono;
023
024import org.springframework.boot.actuate.trace.http.HttpExchangeTracer;
025import org.springframework.boot.actuate.trace.http.HttpTrace;
026import org.springframework.boot.actuate.trace.http.HttpTraceRepository;
027import org.springframework.boot.actuate.trace.http.Include;
028import org.springframework.core.Ordered;
029import org.springframework.http.HttpStatus;
030import org.springframework.http.server.reactive.ServerHttpResponse;
031import org.springframework.http.server.reactive.ServerHttpResponseDecorator;
032import org.springframework.web.server.ResponseStatusException;
033import org.springframework.web.server.ServerWebExchange;
034import org.springframework.web.server.WebFilter;
035import org.springframework.web.server.WebFilterChain;
036import org.springframework.web.server.WebSession;
037
038/**
039 * A {@link WebFilter} for tracing HTTP requests.
040 *
041 * @author Andy Wilkinson
042 * @since 2.0.0
043 */
044public class HttpTraceWebFilter implements WebFilter, Ordered {
045
046        private static final Object NONE = new Object();
047
048        // Not LOWEST_PRECEDENCE, but near the end, so it has a good chance of catching all
049        // enriched headers, but users can add stuff after this if they want to
050        private int order = Ordered.LOWEST_PRECEDENCE - 10;
051
052        private final HttpTraceRepository repository;
053
054        private final HttpExchangeTracer tracer;
055
056        private final Set<Include> includes;
057
058        public HttpTraceWebFilter(HttpTraceRepository repository, HttpExchangeTracer tracer,
059                        Set<Include> includes) {
060                this.repository = repository;
061                this.tracer = tracer;
062                this.includes = includes;
063        }
064
065        @Override
066        public int getOrder() {
067                return this.order;
068        }
069
070        public void setOrder(int order) {
071                this.order = order;
072        }
073
074        @Override
075        public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
076                Mono<?> principal = (this.includes.contains(Include.PRINCIPAL)
077                                ? exchange.getPrincipal().cast(Object.class).defaultIfEmpty(NONE)
078                                : Mono.just(NONE));
079                Mono<?> session = (this.includes.contains(Include.SESSION_ID)
080                                ? exchange.getSession() : Mono.just(NONE));
081                return Mono.zip(principal, session)
082                                .flatMap((tuple) -> filter(exchange, chain,
083                                                asType(tuple.getT1(), Principal.class),
084                                                asType(tuple.getT2(), WebSession.class)));
085        }
086
087        private <T> T asType(Object object, Class<T> type) {
088                if (type.isInstance(object)) {
089                        return type.cast(object);
090                }
091                return null;
092        }
093
094        private Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain,
095                        Principal principal, WebSession session) {
096                ServerWebExchangeTraceableRequest request = new ServerWebExchangeTraceableRequest(
097                                exchange);
098                HttpTrace trace = this.tracer.receivedRequest(request);
099                return chain.filter(exchange).doAfterSuccessOrError((aVoid, ex) -> {
100                        TraceableServerHttpResponse response = new TraceableServerHttpResponse(
101                                        (ex != null) ? new CustomStatusResponseDecorator(ex,
102                                                        exchange.getResponse()) : exchange.getResponse());
103                        this.tracer.sendingResponse(trace, response, () -> principal,
104                                        () -> getStartedSessionId(session));
105                        this.repository.add(trace);
106                });
107        }
108
109        private String getStartedSessionId(WebSession session) {
110                return (session != null && session.isStarted()) ? session.getId() : null;
111        }
112
113        private static final class CustomStatusResponseDecorator
114                        extends ServerHttpResponseDecorator {
115
116                private final HttpStatus status;
117
118                private CustomStatusResponseDecorator(Throwable ex, ServerHttpResponse delegate) {
119                        super(delegate);
120                        this.status = (ex instanceof ResponseStatusException)
121                                        ? ((ResponseStatusException) ex).getStatus()
122                                        : HttpStatus.INTERNAL_SERVER_ERROR;
123                }
124
125                @Override
126                public HttpStatus getStatusCode() {
127                        return this.status;
128                }
129
130        }
131
132}