001/*
002 * Copyright 2002-2018 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.web.filter;
018
019import java.io.IOException;
020import java.util.Collections;
021import java.util.Enumeration;
022import java.util.List;
023import java.util.Locale;
024import java.util.Map;
025import java.util.Set;
026import javax.servlet.FilterChain;
027import javax.servlet.ServletException;
028import javax.servlet.http.HttpServletRequest;
029import javax.servlet.http.HttpServletRequestWrapper;
030import javax.servlet.http.HttpServletResponse;
031import javax.servlet.http.HttpServletResponseWrapper;
032
033import org.springframework.http.HttpRequest;
034import org.springframework.http.HttpStatus;
035import org.springframework.http.server.ServletServerHttpRequest;
036import org.springframework.util.CollectionUtils;
037import org.springframework.util.LinkedCaseInsensitiveMap;
038import org.springframework.util.StringUtils;
039import org.springframework.web.util.UriComponents;
040import org.springframework.web.util.UriComponentsBuilder;
041import org.springframework.web.util.UrlPathHelper;
042
043/**
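 * Extract values from "Forwarded" and "X-Forwarded-*" headers in order to wrap
 * and override the following from the request and response:
 * {@link HttpServletRequest#getServerName() getServerName()},
 * {@link HttpServletRequest#getServerPort() getServerPort()},
 * {@link HttpServletRequest#getScheme() getScheme()},
 * {@link HttpServletRequest#isSecure() isSecure()},
 * {@link HttpServletRequest#getContextPath() getContextPath()} (from "X-Forwarded-Prefix"),
 * {@link HttpServletRequest#getRequestURI() getRequestURI()},
 * {@link HttpServletRequest#getRequestURL() getRequestURL()}, and, unless
 * relative redirects are enabled,
 * {@link HttpServletResponse#sendRedirect(String) sendRedirect(String)}.
 * The forwarded headers themselves are hidden from the wrapped request's
 * header accessors. In effect, the wrapped request and response reflect the
 * client-originated protocol and address.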
053 *
054 * <p><strong>Note:</strong> This filter can also be used in a
055 * {@link #setRemoveOnly removeOnly} mode where "Forwarded" and "X-Forwarded-*"
056 * headers are only eliminated without being used.
057 *
058 * @author Rossen Stoyanchev
 * @author Eddú Meléndez
060 * @author Rob Winch
061 * @since 4.3
062 * @see <a href="https://tools.ietf.org/html/rfc7239">https://tools.ietf.org/html/rfc7239</a>
063 */
064public class ForwardedHeaderFilter extends OncePerRequestFilter {
065
066        private static final Set<String> FORWARDED_HEADER_NAMES =
067                        Collections.newSetFromMap(new LinkedCaseInsensitiveMap<Boolean>(5, Locale.ENGLISH));
068
069        static {
070                FORWARDED_HEADER_NAMES.add("Forwarded");
071                FORWARDED_HEADER_NAMES.add("X-Forwarded-Host");
072                FORWARDED_HEADER_NAMES.add("X-Forwarded-Port");
073                FORWARDED_HEADER_NAMES.add("X-Forwarded-Proto");
074                FORWARDED_HEADER_NAMES.add("X-Forwarded-Prefix");
075        }
076
077
078        private boolean removeOnly;
079
080        private boolean relativeRedirects;
081
082
083        /**
084         * Enables mode in which any "Forwarded" or "X-Forwarded-*" headers are
085         * removed only and the information in them ignored.
086         * @param removeOnly whether to discard and ignore forwarded headers
087         * @since 4.3.9
088         */
089        public void setRemoveOnly(boolean removeOnly) {
090                this.removeOnly = removeOnly;
091        }
092
093        /**
094         * Use this property to enable relative redirects as explained in
095         * {@link RelativeRedirectFilter}, and also using the same response wrapper
096         * as that filter does, or if both are configured, only one will wrap.
097         * <p>By default, if this property is set to false, in which case calls to
098         * {@link HttpServletResponse#sendRedirect(String)} are overridden in order
099         * to turn relative into absolute URLs, also taking into account forwarded
100         * headers.
101         * @param relativeRedirects whether to use relative redirects
102         * @since 4.3.10
103         */
104        public void setRelativeRedirects(boolean relativeRedirects) {
105                this.relativeRedirects = relativeRedirects;
106        }
107
108
109        @Override
110        protected boolean shouldNotFilter(HttpServletRequest request) throws ServletException {
111                Enumeration<String> names = request.getHeaderNames();
112                while (names.hasMoreElements()) {
113                        String name = names.nextElement();
114                        if (FORWARDED_HEADER_NAMES.contains(name)) {
115                                return false;
116                        }
117                }
118                return true;
119        }
120
121        @Override
122        protected boolean shouldNotFilterAsyncDispatch() {
123                return false;
124        }
125
126        @Override
127        protected boolean shouldNotFilterErrorDispatch() {
128                return false;
129        }
130
131        @Override
132        protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response,
133                        FilterChain filterChain) throws ServletException, IOException {
134
135                if (this.removeOnly) {
136                        ForwardedHeaderRemovingRequest theRequest = new ForwardedHeaderRemovingRequest(request);
137                        filterChain.doFilter(theRequest, response);
138                }
139                else {
140                        HttpServletRequest theRequest = new ForwardedHeaderExtractingRequest(request);
141                        HttpServletResponse theResponse = (this.relativeRedirects ?
142                                        RelativeRedirectResponseWrapper.wrapIfNecessary(response, HttpStatus.SEE_OTHER) :
143                                        new ForwardedHeaderExtractingResponse(response, theRequest));
144                        filterChain.doFilter(theRequest, theResponse);
145                }
146        }
147
148
149        /**
150         * Hide "Forwarded" or "X-Forwarded-*" headers.
151         */
152        private static class ForwardedHeaderRemovingRequest extends HttpServletRequestWrapper {
153
154                private final Map<String, List<String>> headers;
155
156                public ForwardedHeaderRemovingRequest(HttpServletRequest request) {
157                        super(request);
158                        this.headers = initHeaders(request);
159                }
160
161                private static Map<String, List<String>> initHeaders(HttpServletRequest request) {
162                        Map<String, List<String>> headers = new LinkedCaseInsensitiveMap<List<String>>(Locale.ENGLISH);
163                        Enumeration<String> names = request.getHeaderNames();
164                        while (names.hasMoreElements()) {
165                                String name = names.nextElement();
166                                if (!FORWARDED_HEADER_NAMES.contains(name)) {
167                                        headers.put(name, Collections.list(request.getHeaders(name)));
168                                }
169                        }
170                        return headers;
171                }
172
173                // Override header accessors to not expose forwarded headers
174
175                @Override
176                public String getHeader(String name) {
177                        List<String> value = this.headers.get(name);
178                        return (CollectionUtils.isEmpty(value) ? null : value.get(0));
179                }
180
181                @Override
182                public Enumeration<String> getHeaders(String name) {
183                        List<String> value = this.headers.get(name);
184                        return (Collections.enumeration(value != null ? value : Collections.<String>emptySet()));
185                }
186
187                @Override
188                public Enumeration<String> getHeaderNames() {
189                        return Collections.enumeration(this.headers.keySet());
190                }
191        }
192
193
194        /**
195         * Extract and use "Forwarded" or "X-Forwarded-*" headers.
196         */
197        private static class ForwardedHeaderExtractingRequest extends ForwardedHeaderRemovingRequest {
198
199                private final String scheme;
200
201                private final boolean secure;
202
203                private final String host;
204
205                private final int port;
206
207                private final String contextPath;
208
209                private final String requestUri;
210
211                private final String requestUrl;
212
213                public ForwardedHeaderExtractingRequest(HttpServletRequest request) {
214                        super(request);
215
216                        HttpRequest httpRequest = new ServletServerHttpRequest(request);
217                        UriComponents uriComponents = UriComponentsBuilder.fromHttpRequest(httpRequest).build();
218                        int port = uriComponents.getPort();
219
220                        this.scheme = uriComponents.getScheme();
221                        this.secure = "https".equals(scheme);
222                        this.host = uriComponents.getHost();
223                        this.port = (port == -1 ? (this.secure ? 443 : 80) : port);
224
225                        String prefix = getForwardedPrefix(request);
226                        this.contextPath = (prefix != null ? prefix : request.getContextPath());
227                        this.requestUri = this.contextPath + UrlPathHelper.rawPathInstance.getPathWithinApplication(request);
228                        this.requestUrl = this.scheme + "://" + this.host + (port == -1 ? "" : ":" + port) + this.requestUri;
229                }
230
231                private static String getForwardedPrefix(HttpServletRequest request) {
232                        String prefix = null;
233                        Enumeration<String> names = request.getHeaderNames();
234                        while (names.hasMoreElements()) {
235                                String name = names.nextElement();
236                                if ("X-Forwarded-Prefix".equalsIgnoreCase(name)) {
237                                        prefix = request.getHeader(name);
238                                }
239                        }
240                        if (prefix != null) {
241                                while (prefix.endsWith("/")) {
242                                        prefix = prefix.substring(0, prefix.length() - 1);
243                                }
244                        }
245                        return prefix;
246                }
247
248                @Override
249                public String getScheme() {
250                        return this.scheme;
251                }
252
253                @Override
254                public String getServerName() {
255                        return this.host;
256                }
257
258                @Override
259                public int getServerPort() {
260                        return this.port;
261                }
262
263                @Override
264                public boolean isSecure() {
265                        return this.secure;
266                }
267
268                @Override
269                public String getContextPath() {
270                        return this.contextPath;
271                }
272
273                @Override
274                public String getRequestURI() {
275                        return this.requestUri;
276                }
277
278                @Override
279                public StringBuffer getRequestURL() {
280                        return new StringBuffer(this.requestUrl);
281                }
282        }
283
284
285        private static class ForwardedHeaderExtractingResponse extends HttpServletResponseWrapper {
286
287                private static final String FOLDER_SEPARATOR = "/";
288
289                private final HttpServletRequest request;
290
291                public ForwardedHeaderExtractingResponse(HttpServletResponse response, HttpServletRequest request) {
292                        super(response);
293                        this.request = request;
294                }
295
296                @Override
297                public void sendRedirect(String location) throws IOException {
298
299                        UriComponentsBuilder builder = UriComponentsBuilder.fromUriString(location);
300                        UriComponents uriComponents = builder.build();
301
302                        // Absolute location
303                        if (uriComponents.getScheme() != null) {
304                                super.sendRedirect(location);
305                                return;
306                        }
307
308                        // Network-path reference
309                        if (location.startsWith("//")) {
310                                String scheme = this.request.getScheme();
311                                super.sendRedirect(builder.scheme(scheme).toUriString());
312                                return;
313                        }
314
315                        String path = uriComponents.getPath();
316                        if (path != null) {
317                                // Relative to Servlet container root or to current request
318                                path = (path.startsWith(FOLDER_SEPARATOR) ? path :
319                                                StringUtils.applyRelativePath(this.request.getRequestURI(), path));
320                        }
321
322                        String result = UriComponentsBuilder
323                                        .fromHttpRequest(new ServletServerHttpRequest(this.request))
324                                        .replacePath(path)
325                                        .replaceQuery(uriComponents.getQuery())
326                                        .fragment(uriComponents.getFragment())
327                                        .build().normalize().toUriString();
328
329                        super.sendRedirect(result);
330                }
331        }
332
333}