001/*
002 * Copyright 2002-2020 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.web.reactive.function.server;
018
019import java.net.InetSocketAddress;
020import java.net.URI;
021import java.security.Principal;
022import java.util.ArrayList;
023import java.util.Arrays;
024import java.util.Collections;
025import java.util.EnumSet;
026import java.util.HashMap;
027import java.util.HashSet;
028import java.util.LinkedHashMap;
029import java.util.List;
030import java.util.Map;
031import java.util.Optional;
032import java.util.Set;
033import java.util.concurrent.ConcurrentHashMap;
034import java.util.function.Function;
035import java.util.function.Predicate;
036
037import org.apache.commons.logging.Log;
038import org.apache.commons.logging.LogFactory;
039import reactor.core.publisher.Flux;
040import reactor.core.publisher.Mono;
041
042import org.springframework.core.ParameterizedTypeReference;
043import org.springframework.http.HttpCookie;
044import org.springframework.http.HttpHeaders;
045import org.springframework.http.HttpMethod;
046import org.springframework.http.MediaType;
047import org.springframework.http.codec.HttpMessageReader;
048import org.springframework.http.codec.multipart.Part;
049import org.springframework.http.server.PathContainer;
050import org.springframework.http.server.reactive.ServerHttpRequest;
051import org.springframework.lang.NonNull;
052import org.springframework.lang.Nullable;
053import org.springframework.util.Assert;
054import org.springframework.util.MultiValueMap;
055import org.springframework.web.cors.reactive.CorsUtils;
056import org.springframework.web.reactive.function.BodyExtractor;
057import org.springframework.web.server.ServerWebExchange;
058import org.springframework.web.server.WebSession;
059import org.springframework.web.util.UriBuilder;
060import org.springframework.web.util.UriUtils;
061import org.springframework.web.util.pattern.PathPattern;
062import org.springframework.web.util.pattern.PathPatternParser;
063
064/**
065 * Implementations of {@link RequestPredicate} that implement various useful
066 * request matching operations, such as matching based on path, HTTP method, etc.
067 *
068 * @author Arjen Poutsma
069 * @since 5.0
070 */
071public abstract class RequestPredicates {
072
073        private static final Log logger = LogFactory.getLog(RequestPredicates.class);
074
075        /**
076         * Return a {@code RequestPredicate} that always matches.
077         * @return a predicate that always matches
078         */
079        public static RequestPredicate all() {
080                return request -> true;
081        }
082
083        /**
084         * Return a {@code RequestPredicate} that matches if the request's
085         * HTTP method is equal to the given method.
086         * @param httpMethod the HTTP method to match against
087         * @return a predicate that tests against the given HTTP method
088         */
089        public static RequestPredicate method(HttpMethod httpMethod) {
090                return new HttpMethodPredicate(httpMethod);
091        }
092
093        /**
094         * Return a {@code RequestPredicate} that matches if the request's
095         * HTTP method is equal to one the of the given methods.
096         * @param httpMethods the HTTP methods to match against
097         * @return a predicate that tests against the given HTTP methods
098         * @since 5.1
099         */
100        public static RequestPredicate methods(HttpMethod... httpMethods) {
101                return new HttpMethodPredicate(httpMethods);
102        }
103
104        /**
105         * Return a {@code RequestPredicate} that tests the request path
106         * against the given path pattern.
107         * @param pattern the pattern to match to
108         * @return a predicate that tests against the given path pattern
109         */
110        public static RequestPredicate path(String pattern) {
111                Assert.notNull(pattern, "'pattern' must not be null");
112                if (!pattern.isEmpty() && !pattern.startsWith("/")) {
113                        pattern = "/" + pattern;
114                }
115                return pathPredicates(PathPatternParser.defaultInstance).apply(pattern);
116        }
117
118        /**
119         * Return a function that creates new path-matching {@code RequestPredicates}
120         * from pattern Strings using the given {@link PathPatternParser}.
121         * <p>This method can be used to specify a non-default, customized
122         * {@code PathPatternParser} when resolving path patterns.
123         * @param patternParser the parser used to parse patterns given to the returned function
124         * @return a function that resolves a pattern String into a path-matching
125         * {@code RequestPredicates} instance
126         */
127        public static Function<String, RequestPredicate> pathPredicates(PathPatternParser patternParser) {
128                Assert.notNull(patternParser, "PathPatternParser must not be null");
129                return pattern -> new PathPatternPredicate(patternParser.parse(pattern));
130        }
131
132        /**
133         * Return a {@code RequestPredicate} that tests the request's headers
134         * against the given headers predicate.
135         * @param headersPredicate a predicate that tests against the request headers
136         * @return a predicate that tests against the given header predicate
137         */
138        public static RequestPredicate headers(Predicate<ServerRequest.Headers> headersPredicate) {
139                return new HeadersPredicate(headersPredicate);
140        }
141
142        /**
143         * Return a {@code RequestPredicate} that tests if the request's
144         * {@linkplain ServerRequest.Headers#contentType() content type} is
145         * {@linkplain MediaType#includes(MediaType) included} by any of the given media types.
146         * @param mediaTypes the media types to match the request's content type against
147         * @return a predicate that tests the request's content type against the given media types
148         */
149        public static RequestPredicate contentType(MediaType... mediaTypes) {
150                Assert.notEmpty(mediaTypes, "'mediaTypes' must not be empty");
151                return new ContentTypePredicate(mediaTypes);
152        }
153
154        /**
155         * Return a {@code RequestPredicate} that tests if the request's
156         * {@linkplain ServerRequest.Headers#accept() accept} header is
157         * {@linkplain MediaType#isCompatibleWith(MediaType) compatible} with any of the given media types.
158         * @param mediaTypes the media types to match the request's accept header against
159         * @return a predicate that tests the request's accept header against the given media types
160         */
161        public static RequestPredicate accept(MediaType... mediaTypes) {
162                Assert.notEmpty(mediaTypes, "'mediaTypes' must not be empty");
163                return new AcceptPredicate(mediaTypes);
164        }
165
166        /**
167         * Return a {@code RequestPredicate} that matches if request's HTTP method is {@code GET}
168         * and the given {@code pattern} matches against the request path.
169         * @param pattern the path pattern to match against
170         * @return a predicate that matches if the request method is GET and if the given pattern
171         * matches against the request path
172         */
173        public static RequestPredicate GET(String pattern) {
174                return method(HttpMethod.GET).and(path(pattern));
175        }
176
177        /**
178         * Return a {@code RequestPredicate} that matches if request's HTTP method is {@code HEAD}
179         * and the given {@code pattern} matches against the request path.
180         * @param pattern the path pattern to match against
181         * @return a predicate that matches if the request method is HEAD and if the given pattern
182         * matches against the request path
183         */
184        public static RequestPredicate HEAD(String pattern) {
185                return method(HttpMethod.HEAD).and(path(pattern));
186        }
187
188        /**
189         * Return a {@code RequestPredicate} that matches if request's HTTP method is {@code POST}
190         * and the given {@code pattern} matches against the request path.
191         * @param pattern the path pattern to match against
192         * @return a predicate that matches if the request method is POST and if the given pattern
193         * matches against the request path
194         */
195        public static RequestPredicate POST(String pattern) {
196                return method(HttpMethod.POST).and(path(pattern));
197        }
198
199        /**
200         * Return a {@code RequestPredicate} that matches if request's HTTP method is {@code PUT}
201         * and the given {@code pattern} matches against the request path.
202         * @param pattern the path pattern to match against
203         * @return a predicate that matches if the request method is PUT and if the given pattern
204         * matches against the request path
205         */
206        public static RequestPredicate PUT(String pattern) {
207                return method(HttpMethod.PUT).and(path(pattern));
208        }
209
210        /**
211         * Return a {@code RequestPredicate} that matches if request's HTTP method is {@code PATCH}
212         * and the given {@code pattern} matches against the request path.
213         * @param pattern the path pattern to match against
214         * @return a predicate that matches if the request method is PATCH and if the given pattern
215         * matches against the request path
216         */
217        public static RequestPredicate PATCH(String pattern) {
218                return method(HttpMethod.PATCH).and(path(pattern));
219        }
220
221        /**
222         * Return a {@code RequestPredicate} that matches if request's HTTP method is {@code DELETE}
223         * and the given {@code pattern} matches against the request path.
224         * @param pattern the path pattern to match against
225         * @return a predicate that matches if the request method is DELETE and if the given pattern
226         * matches against the request path
227         */
228        public static RequestPredicate DELETE(String pattern) {
229                return method(HttpMethod.DELETE).and(path(pattern));
230        }
231
232        /**
233         * Return a {@code RequestPredicate} that matches if request's HTTP method is {@code OPTIONS}
234         * and the given {@code pattern} matches against the request path.
235         * @param pattern the path pattern to match against
236         * @return a predicate that matches if the request method is OPTIONS and if the given pattern
237         * matches against the request path
238         */
239        public static RequestPredicate OPTIONS(String pattern) {
240                return method(HttpMethod.OPTIONS).and(path(pattern));
241        }
242
243        /**
244         * Return a {@code RequestPredicate} that matches if the request's path has the given extension.
245         * @param extension the path extension to match against, ignoring case
246         * @return a predicate that matches if the request's path has the given file extension
247         */
248        public static RequestPredicate pathExtension(String extension) {
249                Assert.notNull(extension, "'extension' must not be null");
250                return new PathExtensionPredicate(extension);
251        }
252
253        /**
254         * Return a {@code RequestPredicate} that matches if the request's path matches the given
255         * predicate.
256         * @param extensionPredicate the predicate to test against the request path extension
257         * @return a predicate that matches if the given predicate matches against the request's path
258         * file extension
259         */
260        public static RequestPredicate pathExtension(Predicate<String> extensionPredicate) {
261                return new PathExtensionPredicate(extensionPredicate);
262        }
263
264        /**
265         * Return a {@code RequestPredicate} that matches if the request's query parameter of the given name
266         * has the given value.
267         * @param name the name of the query parameter to test against
268         * @param value the value of the query parameter to test against
269         * @return a predicate that matches if the query parameter has the given value
270         * @since 5.0.7
271         * @see ServerRequest#queryParam(String)
272         */
273        public static RequestPredicate queryParam(String name, String value) {
274                return new QueryParamPredicate(name, value);
275        }
276
277        /**
278         * Return a {@code RequestPredicate} that tests the request's query parameter of the given name
279         * against the given predicate.
280         * @param name the name of the query parameter to test against
281         * @param predicate the predicate to test against the query parameter value
282         * @return a predicate that matches the given predicate against the query parameter of the given name
283         * @see ServerRequest#queryParam(String)
284         */
285        public static RequestPredicate queryParam(String name, Predicate<String> predicate) {
286                return new QueryParamPredicate(name, predicate);
287        }
288
289
290        private static void traceMatch(String prefix, Object desired, @Nullable Object actual, boolean match) {
291                if (logger.isTraceEnabled()) {
292                        logger.trace(String.format("%s \"%s\" %s against value \"%s\"",
293                                        prefix, desired, match ? "matches" : "does not match", actual));
294                }
295        }
296
        /**
         * Replace the request's attributes with the given snapshot. All current
         * attributes are cleared first, and then every entry of {@code attributes}
         * is copied in.
         * <p>NOTE(review): if {@code attributes} is the same live map that
         * {@code request.attributes()} returns, rather than a copy, {@code clear()}
         * empties it before {@code putAll} runs and all attributes are lost. Callers
         * are presumably expected to pass a copy; verify at call sites.
         */
        private static void restoreAttributes(ServerRequest request, Map<String, Object> attributes) {
                request.attributes().clear();
                request.attributes().putAll(attributes);
        }
301
302        private static Map<String, String> mergePathVariables(Map<String, String> oldVariables,
303                        Map<String, String> newVariables) {
304
305                if (!newVariables.isEmpty()) {
306                        Map<String, String> mergedVariables = new LinkedHashMap<>(oldVariables);
307                        mergedVariables.putAll(newVariables);
308                        return mergedVariables;
309                }
310                else {
311                        return oldVariables;
312                }
313        }
314
315        private static PathPattern mergePatterns(@Nullable PathPattern oldPattern, PathPattern newPattern) {
316                if (oldPattern != null) {
317                        return oldPattern.combine(newPattern);
318                }
319                else {
320                        return newPattern;
321                }
322
323        }
324
325
326        /**
327         * Receives notifications from the logical structure of request predicates.
328         */
329        public interface Visitor {
330
331                /**
332                 * Receive notification of an HTTP method predicate.
333                 * @param methods the HTTP methods that make up the predicate
334                 * @see RequestPredicates#method(HttpMethod)
335                 */
336                void method(Set<HttpMethod> methods);
337
338                /**
339                 * Receive notification of an path predicate.
340                 * @param pattern the path pattern that makes up the predicate
341                 * @see RequestPredicates#path(String)
342                 */
343                void path(String pattern);
344
345                /**
346                 * Receive notification of an path extension predicate.
347                 * @param extension the path extension that makes up the predicate
348                 * @see RequestPredicates#pathExtension(String)
349                 */
350                void pathExtension(String extension);
351
352                /**
353                 * Receive notification of an HTTP header predicate.
354                 * @param name the name of the HTTP header to check
355                 * @param value the desired value of the HTTP header
356                 * @see RequestPredicates#headers(Predicate)
357                 * @see RequestPredicates#contentType(MediaType...)
358                 * @see RequestPredicates#accept(MediaType...)
359                 */
360                void header(String name, String value);
361
362                /**
363                 * Receive notification of a query parameter predicate.
364                 * @param name the name of the query parameter
365                 * @param value the desired value of the parameter
366                 * @see RequestPredicates#queryParam(String, String)
367                 */
368                void queryParam(String name, String value);
369
370                /**
371                 * Receive first notification of a logical AND predicate.
372                 * The first subsequent notification will contain the left-hand side of the AND-predicate;
373                 * followed by {@link #and()}, followed by the right-hand side, followed by {@link #endAnd()}.
374                 * @see RequestPredicate#and(RequestPredicate)
375                 */
376                void startAnd();
377
378                /**
379                 * Receive "middle" notification of a logical AND predicate.
380                 * The following notification contains the right-hand side, followed by {@link #endAnd()}.
381                 * @see RequestPredicate#and(RequestPredicate)
382                 */
383                void and();
384
385                /**
386                 * Receive last notification of a logical AND predicate.
387                 * @see RequestPredicate#and(RequestPredicate)
388                 */
389                void endAnd();
390
391                /**
392                 * Receive first notification of a logical OR predicate.
393                 * The first subsequent notification will contain the left-hand side of the OR-predicate;
394                 * the second notification contains the right-hand side, followed by {@link #endOr()}.
395                 * @see RequestPredicate#or(RequestPredicate)
396                 */
397                void startOr();
398
399                /**
400                 * Receive "middle" notification of a logical OR predicate.
401                 * The following notification contains the right-hand side, followed by {@link #endOr()}.
402                 * @see RequestPredicate#or(RequestPredicate)
403                 */
404                void or();
405
406                /**
407                 * Receive last notification of a logical OR predicate.
408                 * @see RequestPredicate#or(RequestPredicate)
409                 */
410                void endOr();
411
412                /**
413                 * Receive first notification of a negated predicate.
414                 * The first subsequent notification will contain the negated predicated, followed
415                 * by {@link #endNegate()}.
416                 * @see RequestPredicate#negate()
417                 */
418                void startNegate();
419
420                /**
421                 * Receive last notification of a negated predicate.
422                 * @see RequestPredicate#negate()
423                 */
424                void endNegate();
425
426                /**
427                 * Receive first notification of an unknown predicate.
428                 */
429                void unknown(RequestPredicate predicate);
430        }
431
432
433        private static class HttpMethodPredicate implements RequestPredicate {
434
435                private final Set<HttpMethod> httpMethods;
436
437                public HttpMethodPredicate(HttpMethod httpMethod) {
438                        Assert.notNull(httpMethod, "HttpMethod must not be null");
439                        this.httpMethods = EnumSet.of(httpMethod);
440                }
441
442                public HttpMethodPredicate(HttpMethod... httpMethods) {
443                        Assert.notEmpty(httpMethods, "HttpMethods must not be empty");
444
445                        this.httpMethods = EnumSet.copyOf(Arrays.asList(httpMethods));
446                }
447
448                @Override
449                public boolean test(ServerRequest request) {
450                        HttpMethod method = method(request);
451                        boolean match = this.httpMethods.contains(method);
452                        traceMatch("Method", this.httpMethods, method, match);
453                        return match;
454                }
455
456                @Nullable
457                private static HttpMethod method(ServerRequest request) {
458                        if (CorsUtils.isPreFlightRequest(request.exchange().getRequest())) {
459                                String accessControlRequestMethod =
460                                                request.headers().firstHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD);
461                                return HttpMethod.resolve(accessControlRequestMethod);
462                        }
463                        else {
464                                return request.method();
465                        }
466                }
467
468
469                @Override
470                public void accept(Visitor visitor) {
471                        visitor.method(Collections.unmodifiableSet(this.httpMethods));
472                }
473
474                @Override
475                public String toString() {
476                        if (this.httpMethods.size() == 1) {
477                                return this.httpMethods.iterator().next().toString();
478                        }
479                        else {
480                                return this.httpMethods.toString();
481                        }
482                }
483        }
484
485
486        private static class PathPatternPredicate implements RequestPredicate {
487
488                private final PathPattern pattern;
489
490                public PathPatternPredicate(PathPattern pattern) {
491                        Assert.notNull(pattern, "'pattern' must not be null");
492                        this.pattern = pattern;
493                }
494
495                @Override
496                public boolean test(ServerRequest request) {
497                        PathContainer pathContainer = request.pathContainer();
498                        PathPattern.PathMatchInfo info = this.pattern.matchAndExtract(pathContainer);
499                        traceMatch("Pattern", this.pattern.getPatternString(), request.path(), info != null);
500                        if (info != null) {
501                                mergeAttributes(request, info.getUriVariables(), this.pattern);
502                                return true;
503                        }
504                        else {
505                                return false;
506                        }
507                }
508
509                private static void mergeAttributes(ServerRequest request, Map<String, String> variables,
510                                PathPattern pattern) {
511                        Map<String, String> pathVariables = mergePathVariables(request.pathVariables(), variables);
512                        request.attributes().put(RouterFunctions.URI_TEMPLATE_VARIABLES_ATTRIBUTE,
513                                                Collections.unmodifiableMap(pathVariables));
514
515                        pattern = mergePatterns(
516                                        (PathPattern) request.attributes().get(RouterFunctions.MATCHING_PATTERN_ATTRIBUTE),
517                                        pattern);
518                        request.attributes().put(RouterFunctions.MATCHING_PATTERN_ATTRIBUTE, pattern);
519                }
520
521                @Override
522                public Optional<ServerRequest> nest(ServerRequest request) {
523                        return Optional.ofNullable(this.pattern.matchStartOfPath(request.pathContainer()))
524                                        .map(info -> new SubPathServerRequestWrapper(request, info, this.pattern));
525                }
526
527                @Override
528                public void accept(Visitor visitor) {
529                        visitor.path(this.pattern.getPatternString());
530                }
531
532                @Override
533                public String toString() {
534                        return this.pattern.getPatternString();
535                }
536        }
537
538
539        private static class HeadersPredicate implements RequestPredicate {
540
541                private final Predicate<ServerRequest.Headers> headersPredicate;
542
543                public HeadersPredicate(Predicate<ServerRequest.Headers> headersPredicate) {
544                        Assert.notNull(headersPredicate, "Predicate must not be null");
545                        this.headersPredicate = headersPredicate;
546                }
547
548                @Override
549                public boolean test(ServerRequest request) {
550                        if (CorsUtils.isPreFlightRequest(request.exchange().getRequest())) {
551                                return true;
552                        }
553                        else {
554                                return this.headersPredicate.test(request.headers());
555                        }
556                }
557
558                @Override
559                public String toString() {
560                        return this.headersPredicate.toString();
561                }
562        }
563
564        private static class ContentTypePredicate extends HeadersPredicate {
565
566                private final Set<MediaType> mediaTypes;
567
568                public ContentTypePredicate(MediaType... mediaTypes) {
569                        this(new HashSet<>(Arrays.asList(mediaTypes)));
570                }
571
572                private ContentTypePredicate(Set<MediaType> mediaTypes) {
573                        super(headers -> {
574                                MediaType contentType =
575                                                headers.contentType().orElse(MediaType.APPLICATION_OCTET_STREAM);
576                                boolean match = mediaTypes.stream()
577                                                .anyMatch(mediaType -> mediaType.includes(contentType));
578                                traceMatch("Content-Type", mediaTypes, contentType, match);
579                                return match;
580                        });
581                        this.mediaTypes = mediaTypes;
582                }
583
584                @Override
585                public void accept(Visitor visitor) {
586                        visitor.header(HttpHeaders.CONTENT_TYPE,
587                                        (this.mediaTypes.size() == 1) ?
588                                                        this.mediaTypes.iterator().next().toString() :
589                                                        this.mediaTypes.toString());
590                }
591
592                @Override
593                public String toString() {
594                        return String.format("Content-Type: %s",
595                                        (this.mediaTypes.size() == 1) ?
596                                                        this.mediaTypes.iterator().next().toString() :
597                                                        this.mediaTypes.toString());
598                }
599        }
600
601        private static class AcceptPredicate extends HeadersPredicate {
602
603                private final Set<MediaType> mediaTypes;
604
605                public AcceptPredicate(MediaType... mediaTypes) {
606                        this(new HashSet<>(Arrays.asList(mediaTypes)));
607                }
608
609                private AcceptPredicate(Set<MediaType> mediaTypes) {
610                        super(headers -> {
611                                List<MediaType> acceptedMediaTypes = acceptedMediaTypes(headers);
612                                boolean match = acceptedMediaTypes.stream()
613                                                .anyMatch(acceptedMediaType -> mediaTypes.stream()
614                                                                .anyMatch(acceptedMediaType::isCompatibleWith));
615                                traceMatch("Accept", mediaTypes, acceptedMediaTypes, match);
616                                return match;
617                        });
618                        this.mediaTypes = mediaTypes;
619                }
620
621                @NonNull
622                private static List<MediaType> acceptedMediaTypes(ServerRequest.Headers headers) {
623                        List<MediaType> acceptedMediaTypes = headers.accept();
624                        if (acceptedMediaTypes.isEmpty()) {
625                                acceptedMediaTypes = Collections.singletonList(MediaType.ALL);
626                        }
627                        else {
628                                MediaType.sortBySpecificityAndQuality(acceptedMediaTypes);
629                        }
630                        return acceptedMediaTypes;
631                }
632
633                @Override
634                public void accept(Visitor visitor) {
635                        visitor.header(HttpHeaders.ACCEPT,
636                                        (this.mediaTypes.size() == 1) ?
637                                                        this.mediaTypes.iterator().next().toString() :
638                                                        this.mediaTypes.toString());
639                }
640
641                @Override
642                public String toString() {
643                        return String.format("Accept: %s",
644                                        (this.mediaTypes.size() == 1) ?
645                                                        this.mediaTypes.iterator().next().toString() :
646                                                        this.mediaTypes.toString());
647                }
648        }
649
650
651        private static class PathExtensionPredicate implements RequestPredicate {
652
653                private final Predicate<String> extensionPredicate;
654
655                @Nullable
656                private final String extension;
657
658                public PathExtensionPredicate(Predicate<String> extensionPredicate) {
659                        Assert.notNull(extensionPredicate, "Predicate must not be null");
660                        this.extensionPredicate = extensionPredicate;
661                        this.extension = null;
662                }
663
664                public PathExtensionPredicate(String extension) {
665                        Assert.notNull(extension, "Extension must not be null");
666
667                        this.extensionPredicate = s -> {
668                                boolean match = extension.equalsIgnoreCase(s);
669                                traceMatch("Extension", extension, s, match);
670                                return match;
671                        };
672                        this.extension = extension;
673                }
674
675                @Override
676                public boolean test(ServerRequest request) {
677                        String pathExtension = UriUtils.extractFileExtension(request.path());
678                        return this.extensionPredicate.test(pathExtension);
679                }
680
681                @Override
682                public void accept(Visitor visitor) {
683                        visitor.pathExtension(
684                                        (this.extension != null) ?
685                                                        this.extension :
686                                                        this.extensionPredicate.toString());
687                }
688
689                @Override
690                public String toString() {
691                        return String.format("*.%s",
692                                        (this.extension != null) ?
693                                                        this.extension :
694                                                        this.extensionPredicate);
695                }
696
697        }
698
699
700        private static class QueryParamPredicate implements RequestPredicate {
701
702                private final String name;
703
704                private final Predicate<String> valuePredicate;
705
706                @Nullable
707                private final String value;
708
709                public QueryParamPredicate(String name, Predicate<String> valuePredicate) {
710                        Assert.notNull(name, "Name must not be null");
711                        Assert.notNull(valuePredicate, "Predicate must not be null");
712                        this.name = name;
713                        this.valuePredicate = valuePredicate;
714                        this.value = null;
715                }
716
717                public QueryParamPredicate(String name, String value) {
718                        Assert.notNull(name, "Name must not be null");
719                        Assert.notNull(value, "Value must not be null");
720                        this.name = name;
721                        this.valuePredicate = value::equals;
722                        this.value = value;
723                }
724
725                @Override
726                public boolean test(ServerRequest request) {
727                        Optional<String> s = request.queryParam(this.name);
728                        return s.filter(this.valuePredicate).isPresent();
729                }
730
731                @Override
732                public void accept(Visitor visitor) {
733                        visitor.queryParam(this.name,
734                                        (this.value != null) ?
735                                                        this.value :
736                                                        this.valuePredicate.toString());
737                }
738
739                @Override
740                public String toString() {
741                        return String.format("?%s %s", this.name,
742                                        (this.value != null) ?
743                                                        this.value :
744                                                        this.valuePredicate);
745                }
746        }
747
748
749        /**
750         * {@link RequestPredicate} for where both {@code left} and {@code right} predicates
751         * must match.
752         */
753        static class AndRequestPredicate implements RequestPredicate {
754
755                private final RequestPredicate left;
756
757                private final RequestPredicate right;
758
759                public AndRequestPredicate(RequestPredicate left, RequestPredicate right) {
760                        Assert.notNull(left, "Left RequestPredicate must not be null");
761                        Assert.notNull(right, "Right RequestPredicate must not be null");
762                        this.left = left;
763                        this.right = right;
764                }
765
766                @Override
767                public boolean test(ServerRequest request) {
768                        Map<String, Object> oldAttributes = new HashMap<>(request.attributes());
769
770                        if (this.left.test(request) && this.right.test(request)) {
771                                return true;
772                        }
773                        restoreAttributes(request, oldAttributes);
774                        return false;
775                }
776
777                @Override
778                public Optional<ServerRequest> nest(ServerRequest request) {
779                        return this.left.nest(request).flatMap(this.right::nest);
780                }
781
782                @Override
783                public void accept(Visitor visitor) {
784                        visitor.startAnd();
785                        this.left.accept(visitor);
786                        visitor.and();
787                        this.right.accept(visitor);
788                        visitor.endAnd();
789                }
790
791                @Override
792                public String toString() {
793                        return String.format("(%s && %s)", this.left, this.right);
794                }
795        }
796
797        /**
798         * {@link RequestPredicate} that negates a delegate predicate.
799         */
800        static class NegateRequestPredicate implements RequestPredicate {
801                private final RequestPredicate delegate;
802
803                public NegateRequestPredicate(RequestPredicate delegate) {
804                        Assert.notNull(delegate, "Delegate must not be null");
805                        this.delegate = delegate;
806                }
807
808                @Override
809                public boolean test(ServerRequest request) {
810                        Map<String, Object> oldAttributes = new HashMap<>(request.attributes());
811                        boolean result = !this.delegate.test(request);
812                        if (!result) {
813                                restoreAttributes(request, oldAttributes);
814                        }
815                        return result;
816                }
817
818                @Override
819                public void accept(Visitor visitor) {
820                        visitor.startNegate();
821                        this.delegate.accept(visitor);
822                        visitor.endNegate();
823                }
824
825                @Override
826                public String toString() {
827                        return "!" + this.delegate.toString();
828                }
829        }
830
831        /**
832         * {@link RequestPredicate} where either {@code left} or {@code right} predicates
833         * may match.
834         */
835        static class OrRequestPredicate implements RequestPredicate {
836
837                private final RequestPredicate left;
838
839                private final RequestPredicate right;
840
841                public OrRequestPredicate(RequestPredicate left, RequestPredicate right) {
842                        Assert.notNull(left, "Left RequestPredicate must not be null");
843                        Assert.notNull(right, "Right RequestPredicate must not be null");
844                        this.left = left;
845                        this.right = right;
846                }
847
848                @Override
849                public boolean test(ServerRequest request) {
850                        Map<String, Object> oldAttributes = new HashMap<>(request.attributes());
851
852                        if (this.left.test(request)) {
853                                return true;
854                        }
855                        else {
856                                restoreAttributes(request, oldAttributes);
857                                if (this.right.test(request)) {
858                                        return true;
859                                }
860                        }
861                        restoreAttributes(request, oldAttributes);
862                        return false;
863                }
864
865                @Override
866                public Optional<ServerRequest> nest(ServerRequest request) {
867                        Optional<ServerRequest> leftResult = this.left.nest(request);
868                        if (leftResult.isPresent()) {
869                                return leftResult;
870                        }
871                        else {
872                                return this.right.nest(request);
873                        }
874                }
875
876                @Override
877                public void accept(Visitor visitor) {
878                        visitor.startOr();
879                        this.left.accept(visitor);
880                        visitor.or();
881                        this.right.accept(visitor);
882                        visitor.endOr();
883                }
884
885
886                @Override
887                public String toString() {
888                        return String.format("(%s || %s)", this.left, this.right);
889                }
890        }
891
892
893        private static class SubPathServerRequestWrapper implements ServerRequest {
894
895                private final ServerRequest request;
896
897                private final PathContainer pathContainer;
898
899                private final Map<String, Object> attributes;
900
901                public SubPathServerRequestWrapper(ServerRequest request,
902                                PathPattern.PathRemainingMatchInfo info, PathPattern pattern) {
903                        this.request = request;
904                        this.pathContainer = new SubPathContainer(info.getPathRemaining());
905                        this.attributes = mergeAttributes(request, info.getUriVariables(), pattern);
906                }
907
908                private static Map<String, Object> mergeAttributes(ServerRequest request,
909                Map<String, String> pathVariables, PathPattern pattern) {
910                        Map<String, Object> result = new ConcurrentHashMap<>(request.attributes());
911
912                        result.put(RouterFunctions.URI_TEMPLATE_VARIABLES_ATTRIBUTE,
913                                        mergePathVariables(request.pathVariables(), pathVariables));
914
915                        pattern = mergePatterns(
916                                        (PathPattern) request.attributes().get(RouterFunctions.MATCHING_PATTERN_ATTRIBUTE),
917                                        pattern);
918                        result.put(RouterFunctions.MATCHING_PATTERN_ATTRIBUTE, pattern);
919                        return result;
920                }
921
922                @Override
923                public HttpMethod method() {
924                        return this.request.method();
925                }
926
927                @Override
928                public String methodName() {
929                        return this.request.methodName();
930                }
931
932                @Override
933                public URI uri() {
934                        return this.request.uri();
935                }
936
937                @Override
938                public UriBuilder uriBuilder() {
939                        return this.request.uriBuilder();
940                }
941
942                @Override
943                public String path() {
944                        return this.pathContainer.value();
945                }
946
947                @Override
948                public PathContainer pathContainer() {
949                        return this.pathContainer;
950                }
951
952                @Override
953                public Headers headers() {
954                        return this.request.headers();
955                }
956
957                @Override
958                public MultiValueMap<String, HttpCookie> cookies() {
959                        return this.request.cookies();
960                }
961
962                @Override
963                public Optional<InetSocketAddress> remoteAddress() {
964                        return this.request.remoteAddress();
965                }
966
967                @Override
968                public Optional<InetSocketAddress> localAddress() {
969                        return this.request.localAddress();
970                }
971
972                @Override
973                public List<HttpMessageReader<?>> messageReaders() {
974                        return this.request.messageReaders();
975                }
976
977                @Override
978                public <T> T body(BodyExtractor<T, ? super ServerHttpRequest> extractor) {
979                        return this.request.body(extractor);
980                }
981
982                @Override
983                public <T> T body(BodyExtractor<T, ? super ServerHttpRequest> extractor, Map<String, Object> hints) {
984                        return this.request.body(extractor, hints);
985                }
986
987                @Override
988                public <T> Mono<T> bodyToMono(Class<? extends T> elementClass) {
989                        return this.request.bodyToMono(elementClass);
990                }
991
992                @Override
993                public <T> Mono<T> bodyToMono(ParameterizedTypeReference<T> typeReference) {
994                        return this.request.bodyToMono(typeReference);
995                }
996
997                @Override
998                public <T> Flux<T> bodyToFlux(Class<? extends T> elementClass) {
999                        return this.request.bodyToFlux(elementClass);
1000                }
1001
1002                @Override
1003                public <T> Flux<T> bodyToFlux(ParameterizedTypeReference<T> typeReference) {
1004                        return this.request.bodyToFlux(typeReference);
1005                }
1006
1007                @Override
1008                public Map<String, Object> attributes() {
1009                        return this.attributes;
1010                }
1011
1012                @Override
1013                public Optional<String> queryParam(String name) {
1014                        return this.request.queryParam(name);
1015                }
1016
1017                @Override
1018                public MultiValueMap<String, String> queryParams() {
1019                        return this.request.queryParams();
1020                }
1021
1022                @Override
1023                @SuppressWarnings("unchecked")
1024                public Map<String, String> pathVariables() {
1025                        return (Map<String, String>) this.attributes.getOrDefault(
1026                                        RouterFunctions.URI_TEMPLATE_VARIABLES_ATTRIBUTE, Collections.emptyMap());
1027
1028                }
1029
1030                @Override
1031                public Mono<WebSession> session() {
1032                        return this.request.session();
1033                }
1034
1035                @Override
1036                public Mono<? extends Principal> principal() {
1037                        return this.request.principal();
1038                }
1039
1040                @Override
1041                public Mono<MultiValueMap<String, String>> formData() {
1042                        return this.request.formData();
1043                }
1044
1045                @Override
1046                public Mono<MultiValueMap<String, Part>> multipartData() {
1047                        return this.request.multipartData();
1048                }
1049
1050                @Override
1051                public ServerWebExchange exchange() {
1052                        return this.request.exchange();
1053                }
1054
1055                @Override
1056                public String toString() {
1057                        return method() + " " +  path();
1058                }
1059
1060                private static class SubPathContainer implements PathContainer {
1061
1062                        private static final PathContainer.Separator SEPARATOR = () -> "/";
1063
1064
1065                        private final String value;
1066
1067                        private final List<Element> elements;
1068
1069                        public SubPathContainer(PathContainer original) {
1070                                this.value = prefixWithSlash(original.value());
1071                                this.elements = prependWithSeparator(original.elements());
1072                        }
1073
1074                        private static String prefixWithSlash(String path) {
1075                                if (!path.startsWith("/")) {
1076                                        path = "/" + path;
1077                                }
1078                                return path;
1079                        }
1080
1081                        private static List<Element> prependWithSeparator(List<Element> elements) {
1082                                List<Element> result = new ArrayList<>(elements);
1083                                if (result.isEmpty() || !(result.get(0) instanceof Separator)) {
1084                                        result.add(0, SEPARATOR);
1085                                }
1086                                return Collections.unmodifiableList(result);
1087                        }
1088
1089
1090                        @Override
1091                        public String value() {
1092                                return this.value;
1093                        }
1094
1095                        @Override
1096                        public List<Element> elements() {
1097                                return this.elements;
1098                        }
1099                }
1100        }
1101
1102}