001/*
002 * Copyright 2002-2020 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.mock.web;
018
019import java.io.BufferedReader;
020import java.io.ByteArrayInputStream;
021import java.io.IOException;
022import java.io.InputStream;
023import java.io.InputStreamReader;
024import java.io.Reader;
025import java.io.StringReader;
026import java.io.UnsupportedEncodingException;
027import java.security.Principal;
028import java.text.ParseException;
029import java.text.SimpleDateFormat;
030import java.util.Collection;
031import java.util.Collections;
032import java.util.Date;
033import java.util.Enumeration;
034import java.util.HashSet;
035import java.util.LinkedHashMap;
036import java.util.LinkedHashSet;
037import java.util.LinkedList;
038import java.util.List;
039import java.util.Locale;
040import java.util.Map;
041import java.util.Set;
042import java.util.TimeZone;
043import javax.servlet.AsyncContext;
044import javax.servlet.DispatcherType;
045import javax.servlet.RequestDispatcher;
046import javax.servlet.ServletContext;
047import javax.servlet.ServletException;
048import javax.servlet.ServletInputStream;
049import javax.servlet.ServletRequest;
050import javax.servlet.ServletResponse;
051import javax.servlet.http.Cookie;
052import javax.servlet.http.HttpServletRequest;
053import javax.servlet.http.HttpServletResponse;
054import javax.servlet.http.HttpSession;
055import javax.servlet.http.Part;
056
057import org.springframework.http.MediaType;
058import org.springframework.util.Assert;
059import org.springframework.util.LinkedCaseInsensitiveMap;
060import org.springframework.util.LinkedMultiValueMap;
061import org.springframework.util.MultiValueMap;
062import org.springframework.util.StreamUtils;
063import org.springframework.util.StringUtils;
064
065/**
066 * Mock implementation of the {@link javax.servlet.http.HttpServletRequest} interface.
067 *
068 * <p>The default, preferred {@link Locale} for the <em>server</em> mocked by this request
069 * is {@link Locale#ENGLISH}. This value can be changed via {@link #addPreferredLocale}
070 * or {@link #setPreferredLocales}.
071 *
072 * <p>As of Spring Framework 4.0, this set of mocks is designed on a Servlet 3.0 baseline.
073 *
074 * @author Juergen Hoeller
075 * @author Rod Johnson
076 * @author Rick Evans
077 * @author Mark Fisher
078 * @author Chris Beams
079 * @author Sam Brannen
080 * @author Brian Clozel
081 * @since 1.0.2
082 */
083public class MockHttpServletRequest implements HttpServletRequest {
084
        /** The plain HTTP scheme name; also the value of {@link #DEFAULT_SCHEME}. */
        private static final String HTTP = "http";

        /** The HTTPS scheme name; {@link #isSecure()} compares the scheme against it case-insensitively. */
        private static final String HTTPS = "https";

        /** Name of the header kept in line with the content type and character encoding. */
        private static final String CONTENT_TYPE_HEADER = "Content-Type";

        /** Header consulted first by {@link #getServerName()} and {@link #getServerPort()}. */
        private static final String HOST_HEADER = "Host";

        /** Marker searched for in lower-cased content type values to detect a charset parameter. */
        private static final String CHARSET_PREFIX = "charset=";

        // NOTE(review): not referenced in the visible part of this file; presumably used
        // together with DATE_FORMATS for date header handling further down — confirm.
        private static final TimeZone GMT = TimeZone.getTimeZone("GMT");

        /** Empty input stream shared by all instances, for requests without content. */
        private static final ServletInputStream EMPTY_SERVLET_INPUT_STREAM =
                        new DelegatingServletInputStream(StreamUtils.emptyInput());

        /**
         * Empty reader shared by all instances.
         * <p>NOTE(review): a {@link BufferedReader} rejects reads once it has been
         * closed, so a {@code close()} by any holder of this instance affects every
         * other holder of it.
         */
        private static final BufferedReader EMPTY_BUFFERED_READER =
                        new BufferedReader(new StringReader(""));

        /**
         * Date formats as specified in the HTTP RFC: the preferred IMF-fixdate,
         * the obsolete RFC 850 format, and the ANSI C {@code asctime()} format.
         * @see <a href="https://tools.ietf.org/html/rfc7231#section-7.1.1.1">Section 7.1.1.1 of RFC 7231</a>
         */
        private static final String[] DATE_FORMATS = new String[] {
                        "EEE, dd MMM yyyy HH:mm:ss zzz",
                        "EEE, dd-MMM-yy HH:mm:ss zzz",
                        "EEE MMM dd HH:mm:ss yyyy"
        };
112
113
114        // ---------------------------------------------------------------------
115        // Public constants
116        // ---------------------------------------------------------------------
117
        /**
         * The default protocol: 'HTTP/1.1'.
         * @since 4.3.7
         */
        public static final String DEFAULT_PROTOCOL = "HTTP/1.1";

        /**
         * The default scheme: 'http'.
         * @since 4.3.7
         */
        public static final String DEFAULT_SCHEME = HTTP;

        /**
         * The default server address: '127.0.0.1'.
         * <p>Also used as the initial {@linkplain #getLocalAddr() local address}.
         */
        public static final String DEFAULT_SERVER_ADDR = "127.0.0.1";

        /**
         * The default server name: 'localhost'.
         * <p>Also used as the initial {@linkplain #getLocalName() local name}.
         */
        public static final String DEFAULT_SERVER_NAME = "localhost";

        /**
         * The default server port: '80'.
         * <p>Also used as the initial {@linkplain #getRemotePort() remote port}
         * and {@linkplain #getLocalPort() local port}.
         */
        public static final int DEFAULT_SERVER_PORT = 80;

        /**
         * The default remote address: '127.0.0.1'.
         */
        public static final String DEFAULT_REMOTE_ADDR = "127.0.0.1";

        /**
         * The default remote host: 'localhost'.
         */
        public static final String DEFAULT_REMOTE_HOST = "localhost";
154
155
156        // ---------------------------------------------------------------------
157        // Lifecycle properties
158        // ---------------------------------------------------------------------
159
        /** The context this request belongs to; never {@code null} once constructed. */
        private final ServletContext servletContext;

        /** Set to {@code false} by {@link #close()}; enforced by {@link #checkActive()}. */
        private boolean active = true;
163
164
165        // ---------------------------------------------------------------------
166        // ServletRequest properties
167        // ---------------------------------------------------------------------
168
        /** Request attributes in insertion order; {@code null} values are never stored. */
        private final Map<String, Object> attributes = new LinkedHashMap<String, Object>();

        /** Character encoding, set directly or derived from a {@code charset} in the content type. */
        private String characterEncoding;

        /** Raw request body, or {@code null} if none has been set. */
        private byte[] content;

        /** Content type exactly as supplied by the caller (no charset appended here). */
        private String contentType;

        /** Request parameters: name to values, in insertion order. */
        private final Map<String, String[]> parameters = new LinkedHashMap<String, String[]>();

        private String protocol = DEFAULT_PROTOCOL;

        private String scheme = DEFAULT_SCHEME;

        /** Server name reported when no {@code Host} header is present. */
        private String serverName = DEFAULT_SERVER_NAME;

        /** Server port reported when the {@code Host} header does not supply one. */
        private int serverPort = DEFAULT_SERVER_PORT;

        private String remoteAddr = DEFAULT_REMOTE_ADDR;

        private String remoteHost = DEFAULT_REMOTE_HOST;

        /** List of locales in descending order */
        private final List<Locale> locales = new LinkedList<Locale>();

        /** Explicit secure flag; {@link #isSecure()} also reports {@code true} for an https scheme. */
        private boolean secure = false;

        // NOTE(review): the remote port starts out as DEFAULT_SERVER_PORT (80);
        // there is no dedicated remote-port default constant.
        private int remotePort = DEFAULT_SERVER_PORT;

        private String localName = DEFAULT_SERVER_NAME;

        private String localAddr = DEFAULT_SERVER_ADDR;

        private int localPort = DEFAULT_SERVER_PORT;

        private boolean asyncStarted = false;

        private boolean asyncSupported = false;

        /** Context created by the last {@code startAsync} call, or one set explicitly. */
        private MockAsyncContext asyncContext;

        private DispatcherType dispatcherType = DispatcherType.REQUEST;
211
212
213        // ---------------------------------------------------------------------
214        // HttpServletRequest properties
215        // ---------------------------------------------------------------------
216
        private String authType;

        private Cookie[] cookies;

        /** Request headers, keyed case-insensitively by header name. */
        private final Map<String, HeaderValueHolder> headers = new LinkedCaseInsensitiveMap<HeaderValueHolder>();

        private String method;

        private String pathInfo;

        /** Context path; the empty string unless set otherwise. */
        private String contextPath = "";

        private String queryString;

        private String remoteUser;

        // NOTE(review): presumably consulted by isUserInRole (not visible here) — confirm.
        private final Set<String> userRoles = new HashSet<String>();

        private Principal userPrincipal;

        private String requestedSessionId;

        private String requestURI;

        /** Servlet path; the empty string unless set otherwise. */
        private String servletPath = "";

        private HttpSession session;

        private boolean requestedSessionIdValid = true;

        private boolean requestedSessionIdFromCookie = true;

        private boolean requestedSessionIdFromURL = false;

        /** Multipart parts keyed by part name; a single name may map to several parts. */
        private final MultiValueMap<String, Part> parts = new LinkedMultiValueMap<String, Part>();
252
253
254        // ---------------------------------------------------------------------
255        // Constructors
256        // ---------------------------------------------------------------------
257
258        /**
259         * Create a new {@code MockHttpServletRequest} with a default
260         * {@link MockServletContext}.
261         * @see #MockHttpServletRequest(ServletContext, String, String)
262         */
263        public MockHttpServletRequest() {
264                this(null, "", "");
265        }
266
267        /**
268         * Create a new {@code MockHttpServletRequest} with a default
269         * {@link MockServletContext}.
270         * @param method the request method (may be {@code null})
271         * @param requestURI the request URI (may be {@code null})
272         * @see #setMethod
273         * @see #setRequestURI
274         * @see #MockHttpServletRequest(ServletContext, String, String)
275         */
276        public MockHttpServletRequest(String method, String requestURI) {
277                this(null, method, requestURI);
278        }
279
280        /**
281         * Create a new {@code MockHttpServletRequest} with the supplied {@link ServletContext}.
282         * @param servletContext the ServletContext that the request runs in
283         * (may be {@code null} to use a default {@link MockServletContext})
284         * @see #MockHttpServletRequest(ServletContext, String, String)
285         */
286        public MockHttpServletRequest(ServletContext servletContext) {
287                this(servletContext, "", "");
288        }
289
290        /**
291         * Create a new {@code MockHttpServletRequest} with the supplied {@link ServletContext},
292         * {@code method}, and {@code requestURI}.
293         * <p>The preferred locale will be set to {@link Locale#ENGLISH}.
294         * @param servletContext the ServletContext that the request runs in (may be
295         * {@code null} to use a default {@link MockServletContext})
296         * @param method the request method (may be {@code null})
297         * @param requestURI the request URI (may be {@code null})
298         * @see #setMethod
299         * @see #setRequestURI
300         * @see #setPreferredLocales
301         * @see MockServletContext
302         */
303        public MockHttpServletRequest(ServletContext servletContext, String method, String requestURI) {
304                this.servletContext = (servletContext != null ? servletContext : new MockServletContext());
305                this.method = method;
306                this.requestURI = requestURI;
307                this.locales.add(Locale.ENGLISH);
308        }
309
310
311        // ---------------------------------------------------------------------
312        // Lifecycle methods
313        // ---------------------------------------------------------------------
314
315        /**
316         * Return the ServletContext that this request is associated with. (Not
317         * available in the standard HttpServletRequest interface for some reason.)
318         */
319        @Override
320        public ServletContext getServletContext() {
321                return this.servletContext;
322        }
323
324        /**
325         * Return whether this request is still active (that is, not completed yet).
326         */
327        public boolean isActive() {
328                return this.active;
329        }
330
331        /**
332         * Mark this request as completed, keeping its state.
333         */
334        public void close() {
335                this.active = false;
336        }
337
338        /**
339         * Invalidate this request, clearing its state.
340         */
341        public void invalidate() {
342                close();
343                clearAttributes();
344        }
345
346        /**
347         * Check whether this request is still active (that is, not completed yet),
348         * throwing an IllegalStateException if not active anymore.
349         */
350        protected void checkActive() throws IllegalStateException {
351                if (!this.active) {
352                        throw new IllegalStateException("Request is not active anymore");
353                }
354        }
355
356
357        // ---------------------------------------------------------------------
358        // ServletRequest interface
359        // ---------------------------------------------------------------------
360
361        @Override
362        public Object getAttribute(String name) {
363                checkActive();
364                return this.attributes.get(name);
365        }
366
367        @Override
368        public Enumeration<String> getAttributeNames() {
369                checkActive();
370                return Collections.enumeration(new LinkedHashSet<String>(this.attributes.keySet()));
371        }
372
373        @Override
374        public String getCharacterEncoding() {
375                return this.characterEncoding;
376        }
377
378        @Override
379        public void setCharacterEncoding(String characterEncoding) {
380                this.characterEncoding = characterEncoding;
381                updateContentTypeHeader();
382        }
383
384        private void updateContentTypeHeader() {
385                if (StringUtils.hasLength(this.contentType)) {
386                        StringBuilder sb = new StringBuilder(this.contentType);
387                        if (!this.contentType.toLowerCase().contains(CHARSET_PREFIX) &&
388                                        StringUtils.hasLength(this.characterEncoding)) {
389                                sb.append(";").append(CHARSET_PREFIX).append(this.characterEncoding);
390                        }
391                        doAddHeaderValue(CONTENT_TYPE_HEADER, sb.toString(), true);
392                }
393        }
394
395        public void setContent(byte[] content) {
396                this.content = content;
397        }
398
399        @Override
400        public int getContentLength() {
401                return (this.content != null ? this.content.length : -1);
402        }
403
404        public long getContentLengthLong() {
405                return getContentLength();
406        }
407
408        public void setContentType(String contentType) {
409                this.contentType = contentType;
410                if (contentType != null) {
411                        try {
412                                MediaType mediaType = MediaType.parseMediaType(contentType);
413                                if (mediaType.getCharset() != null) {
414                                        this.characterEncoding = mediaType.getCharset().name();
415                                }
416                        }
417                        catch (Exception ex) {
418                                // Try to get charset value anyway
419                                int charsetIndex = contentType.toLowerCase().indexOf(CHARSET_PREFIX);
420                                if (charsetIndex != -1) {
421                                        this.characterEncoding = contentType.substring(charsetIndex + CHARSET_PREFIX.length());
422                                }
423                        }
424                        updateContentTypeHeader();
425                }
426        }
427
428        @Override
429        public String getContentType() {
430                return this.contentType;
431        }
432
433        @Override
434        public ServletInputStream getInputStream() {
435                if (this.content != null) {
436                        return new DelegatingServletInputStream(new ByteArrayInputStream(this.content));
437                }
438                else {
439                        return EMPTY_SERVLET_INPUT_STREAM;
440                }
441        }
442
443        /**
444         * Set a single value for the specified HTTP parameter.
445         * <p>If there are already one or more values registered for the given
446         * parameter name, they will be replaced.
447         */
448        public void setParameter(String name, String value) {
449                setParameter(name, new String[] {value});
450        }
451
452        /**
453         * Set an array of values for the specified HTTP parameter.
454         * <p>If there are already one or more values registered for the given
455         * parameter name, they will be replaced.
456         */
457        public void setParameter(String name, String... values) {
458                Assert.notNull(name, "Parameter name must not be null");
459                this.parameters.put(name, values);
460        }
461
462        /**
463         * Set all provided parameters <strong>replacing</strong> any existing
464         * values for the provided parameter names. To add without replacing
465         * existing values, use {@link #addParameters(java.util.Map)}.
466         */
467        public void setParameters(Map<String, ?> params) {
468                Assert.notNull(params, "Parameter map must not be null");
469                for (String key : params.keySet()) {
470                        Object value = params.get(key);
471                        if (value instanceof String) {
472                                setParameter(key, (String) value);
473                        }
474                        else if (value instanceof String[]) {
475                                setParameter(key, (String[]) value);
476                        }
477                        else {
478                                throw new IllegalArgumentException(
479                                                "Parameter map value must be single value " + " or array of type [" + String.class.getName() + "]");
480                        }
481                }
482        }
483
484        /**
485         * Add a single value for the specified HTTP parameter.
486         * <p>If there are already one or more values registered for the given
487         * parameter name, the given value will be added to the end of the list.
488         */
489        public void addParameter(String name, String value) {
490                addParameter(name, new String[] {value});
491        }
492
493        /**
494         * Add an array of values for the specified HTTP parameter.
495         * <p>If there are already one or more values registered for the given
496         * parameter name, the given values will be added to the end of the list.
497         */
498        public void addParameter(String name, String... values) {
499                Assert.notNull(name, "Parameter name must not be null");
500                String[] oldArr = this.parameters.get(name);
501                if (oldArr != null) {
502                        String[] newArr = new String[oldArr.length + values.length];
503                        System.arraycopy(oldArr, 0, newArr, 0, oldArr.length);
504                        System.arraycopy(values, 0, newArr, oldArr.length, values.length);
505                        this.parameters.put(name, newArr);
506                }
507                else {
508                        this.parameters.put(name, values);
509                }
510        }
511
512        /**
513         * Add all provided parameters <strong>without</strong> replacing any
514         * existing values. To replace existing values, use
515         * {@link #setParameters(java.util.Map)}.
516         */
517        public void addParameters(Map<String, ?> params) {
518                Assert.notNull(params, "Parameter map must not be null");
519                for (String key : params.keySet()) {
520                        Object value = params.get(key);
521                        if (value instanceof String) {
522                                addParameter(key, (String) value);
523                        }
524                        else if (value instanceof String[]) {
525                                addParameter(key, (String[]) value);
526                        }
527                        else {
528                                throw new IllegalArgumentException("Parameter map value must be single value " +
529                                                " or array of type [" + String.class.getName() + "]");
530                        }
531                }
532        }
533
534        /**
535         * Remove already registered values for the specified HTTP parameter, if any.
536         */
537        public void removeParameter(String name) {
538                Assert.notNull(name, "Parameter name must not be null");
539                this.parameters.remove(name);
540        }
541
542        /**
543         * Remove all existing parameters.
544         */
545        public void removeAllParameters() {
546                this.parameters.clear();
547        }
548
549        @Override
550        public String getParameter(String name) {
551                String[] arr = (name != null ? this.parameters.get(name) : null);
552                return (arr != null && arr.length > 0 ? arr[0] : null);
553        }
554
555        @Override
556        public Enumeration<String> getParameterNames() {
557                return Collections.enumeration(this.parameters.keySet());
558        }
559
560        @Override
561        public String[] getParameterValues(String name) {
562                return (name != null ? this.parameters.get(name) : null);
563        }
564
565        @Override
566        public Map<String, String[]> getParameterMap() {
567                return Collections.unmodifiableMap(this.parameters);
568        }
569
570        public void setProtocol(String protocol) {
571                this.protocol = protocol;
572        }
573
574        @Override
575        public String getProtocol() {
576                return this.protocol;
577        }
578
579        public void setScheme(String scheme) {
580                this.scheme = scheme;
581        }
582
583        @Override
584        public String getScheme() {
585                return this.scheme;
586        }
587
588        public void setServerName(String serverName) {
589                this.serverName = serverName;
590        }
591
592        @Override
593        public String getServerName() {
594                String rawHostHeader = getHeader(HOST_HEADER);
595                String host = rawHostHeader;
596                if (host != null) {
597                        host = host.trim();
598                        if (host.startsWith("[")) {
599                                int indexOfClosingBracket = host.indexOf(']');
600                                Assert.state(indexOfClosingBracket > -1, "Invalid Host header: " + rawHostHeader);
601                                host = host.substring(0, indexOfClosingBracket + 1);
602                        }
603                        else if (host.contains(":")) {
604                                host = host.substring(0, host.indexOf(':'));
605                        }
606                        return host;
607                }
608
609                // else
610                return this.serverName;
611        }
612
613        public void setServerPort(int serverPort) {
614                this.serverPort = serverPort;
615        }
616
617        @Override
618        public int getServerPort() {
619                String rawHostHeader = getHeader(HOST_HEADER);
620                String host = rawHostHeader;
621                if (host != null) {
622                        host = host.trim();
623                        int idx;
624                        if (host.startsWith("[")) {
625                                int indexOfClosingBracket = host.indexOf(']');
626                                Assert.state(indexOfClosingBracket > -1, "Invalid Host header: " + rawHostHeader);
627                                idx = host.indexOf(':', indexOfClosingBracket);
628                        }
629                        else {
630                                idx = host.indexOf(':');
631                        }
632                        if (idx != -1) {
633                                return Integer.parseInt(host.substring(idx + 1));
634                        }
635                }
636
637                // else
638                return this.serverPort;
639        }
640
641        @Override
642        public BufferedReader getReader() throws UnsupportedEncodingException {
643                if (this.content != null) {
644                        InputStream sourceStream = new ByteArrayInputStream(this.content);
645                        Reader sourceReader = (this.characterEncoding != null) ?
646                                        new InputStreamReader(sourceStream, this.characterEncoding) :
647                                        new InputStreamReader(sourceStream);
648                        return new BufferedReader(sourceReader);
649                }
650                else {
651                        return EMPTY_BUFFERED_READER;
652                }
653        }
654
655        public void setRemoteAddr(String remoteAddr) {
656                this.remoteAddr = remoteAddr;
657        }
658
659        @Override
660        public String getRemoteAddr() {
661                return this.remoteAddr;
662        }
663
664        public void setRemoteHost(String remoteHost) {
665                this.remoteHost = remoteHost;
666        }
667
668        @Override
669        public String getRemoteHost() {
670                return this.remoteHost;
671        }
672
673        @Override
674        public void setAttribute(String name, Object value) {
675                checkActive();
676                Assert.notNull(name, "Attribute name must not be null");
677                if (value != null) {
678                        this.attributes.put(name, value);
679                }
680                else {
681                        this.attributes.remove(name);
682                }
683        }
684
685        @Override
686        public void removeAttribute(String name) {
687                checkActive();
688                Assert.notNull(name, "Attribute name must not be null");
689                this.attributes.remove(name);
690        }
691
692        /**
693         * Clear all of this request's attributes.
694         */
695        public void clearAttributes() {
696                this.attributes.clear();
697        }
698
699        /**
700         * Add a new preferred locale, before any existing locales.
701         * @see #setPreferredLocales
702         */
703        public void addPreferredLocale(Locale locale) {
704                Assert.notNull(locale, "Locale must not be null");
705                this.locales.add(0, locale);
706        }
707
708        /**
709         * Set the list of preferred locales, in descending order, effectively replacing
710         * any existing locales.
711         * @since 3.2
712         * @see #addPreferredLocale
713         */
714        public void setPreferredLocales(List<Locale> locales) {
715                Assert.notEmpty(locales, "Locale list must not be empty");
716                this.locales.clear();
717                this.locales.addAll(locales);
718        }
719
720        /**
721         * Return the first preferred {@linkplain Locale locale} configured
722         * in this mock request.
723         * <p>If no locales have been explicitly configured, the default,
724         * preferred {@link Locale} for the <em>server</em> mocked by this
725         * request is {@link Locale#ENGLISH}.
726         * <p>In contrast to the Servlet specification, this mock implementation
727         * does <strong>not</strong> take into consideration any locales
728         * specified via the {@code Accept-Language} header.
729         * @see javax.servlet.ServletRequest#getLocale()
730         * @see #addPreferredLocale(Locale)
731         * @see #setPreferredLocales(List)
732         */
733        @Override
734        public Locale getLocale() {
735                return this.locales.get(0);
736        }
737
738        /**
739         * Return an {@linkplain Enumeration enumeration} of the preferred
740         * {@linkplain Locale locales} configured in this mock request.
741         * <p>If no locales have been explicitly configured, the default,
742         * preferred {@link Locale} for the <em>server</em> mocked by this
743         * request is {@link Locale#ENGLISH}.
744         * <p>In contrast to the Servlet specification, this mock implementation
745         * does <strong>not</strong> take into consideration any locales
746         * specified via the {@code Accept-Language} header.
747         * @see javax.servlet.ServletRequest#getLocales()
748         * @see #addPreferredLocale(Locale)
749         * @see #setPreferredLocales(List)
750         */
751        @Override
752        public Enumeration<Locale> getLocales() {
753                return Collections.enumeration(this.locales);
754        }
755
756        /**
757         * Set the boolean {@code secure} flag indicating whether the mock request
758         * was made using a secure channel, such as HTTPS.
759         * @see #isSecure()
760         * @see #getScheme()
761         * @see #setScheme(String)
762         */
763        public void setSecure(boolean secure) {
764                this.secure = secure;
765        }
766
767        /**
768         * Return {@code true} if the {@link #setSecure secure} flag has been set
769         * to {@code true} or if the {@link #getScheme scheme} is {@code https}.
770         * @see javax.servlet.ServletRequest#isSecure()
771         */
772        @Override
773        public boolean isSecure() {
774                return (this.secure || HTTPS.equalsIgnoreCase(this.scheme));
775        }
776
777        @Override
778        public RequestDispatcher getRequestDispatcher(String path) {
779                return new MockRequestDispatcher(path);
780        }
781
782        @Override
783        @Deprecated
784        public String getRealPath(String path) {
785                return this.servletContext.getRealPath(path);
786        }
787
788        public void setRemotePort(int remotePort) {
789                this.remotePort = remotePort;
790        }
791
792        @Override
793        public int getRemotePort() {
794                return this.remotePort;
795        }
796
797        public void setLocalName(String localName) {
798                this.localName = localName;
799        }
800
801        @Override
802        public String getLocalName() {
803                return this.localName;
804        }
805
806        public void setLocalAddr(String localAddr) {
807                this.localAddr = localAddr;
808        }
809
810        @Override
811        public String getLocalAddr() {
812                return this.localAddr;
813        }
814
815        public void setLocalPort(int localPort) {
816                this.localPort = localPort;
817        }
818
819        @Override
820        public int getLocalPort() {
821                return this.localPort;
822        }
823
824        @Override
825        public AsyncContext startAsync() {
826                return startAsync(this, null);
827        }
828
829        @Override
830        public AsyncContext startAsync(ServletRequest request, ServletResponse response) {
831                if (!this.asyncSupported) {
832                        throw new IllegalStateException("Async not supported");
833                }
834                this.asyncStarted = true;
835                this.asyncContext = new MockAsyncContext(request, response);
836                return this.asyncContext;
837        }
838
839        public void setAsyncStarted(boolean asyncStarted) {
840                this.asyncStarted = asyncStarted;
841        }
842
843        @Override
844        public boolean isAsyncStarted() {
845                return this.asyncStarted;
846        }
847
848        public void setAsyncSupported(boolean asyncSupported) {
849                this.asyncSupported = asyncSupported;
850        }
851
852        @Override
853        public boolean isAsyncSupported() {
854                return this.asyncSupported;
855        }
856
857        public void setAsyncContext(MockAsyncContext asyncContext) {
858                this.asyncContext = asyncContext;
859        }
860
861        @Override
862        public AsyncContext getAsyncContext() {
863                return this.asyncContext;
864        }
865
866        public void setDispatcherType(DispatcherType dispatcherType) {
867                this.dispatcherType = dispatcherType;
868        }
869
870        @Override
871        public DispatcherType getDispatcherType() {
872                return this.dispatcherType;
873        }
874
875
876        // ---------------------------------------------------------------------
877        // HttpServletRequest interface
878        // ---------------------------------------------------------------------
879
880        public void setAuthType(String authType) {
881                this.authType = authType;
882        }
883
884        @Override
885        public String getAuthType() {
886                return this.authType;
887        }
888
889        public void setCookies(Cookie... cookies) {
890                this.cookies = cookies;
891        }
892
893        @Override
894        public Cookie[] getCookies() {
895                return this.cookies;
896        }
897
898        /**
899         * Add an HTTP header entry for the given name.
900         * <p>While this method can take any {@code Object} as a parameter,
901         * it is recommended to use the following types:
902         * <ul>
903         * <li>String or any Object to be converted using {@code toString()}; see {@link #getHeader}.</li>
904         * <li>String, Number, or Date for date headers; see {@link #getDateHeader}.</li>
905         * <li>String or Number for integer headers; see {@link #getIntHeader}.</li>
906         * <li>{@code String[]} or {@code Collection<String>} for multiple values; see {@link #getHeaders}.</li>
907         * </ul>
908         * @see #getHeaderNames
909         * @see #getHeaders
910         * @see #getHeader
911         * @see #getDateHeader
912         */
913        public void addHeader(String name, Object value) {
914                if (CONTENT_TYPE_HEADER.equalsIgnoreCase(name) && !this.headers.containsKey(CONTENT_TYPE_HEADER)) {
915                        setContentType(value.toString());
916                }
917                else {
918                        doAddHeaderValue(name, value, false);
919                }
920        }
921
922        private void doAddHeaderValue(String name, Object value, boolean replace) {
923                HeaderValueHolder header = HeaderValueHolder.getByName(this.headers, name);
924                Assert.notNull(value, "Header value must not be null");
925                if (header == null || replace) {
926                        header = new HeaderValueHolder();
927                        this.headers.put(name, header);
928                }
929                if (value instanceof Collection) {
930                        header.addValues((Collection<?>) value);
931                }
932                else if (value.getClass().isArray()) {
933                        header.addValueArray(value);
934                }
935                else {
936                        header.addValue(value);
937                }
938        }
939
940        /**
941         * Remove already registered entries for the specified HTTP header, if any.
942         * @since 4.3.20
943         */
944        public void removeHeader(String name) {
945                Assert.notNull(name, "Header name must not be null");
946                this.headers.remove(name);
947        }
948
949        /**
950         * Return the long timestamp for the date header with the given {@code name}.
951         * <p>If the internal value representation is a String, this method will try
952         * to parse it as a date using the supported date formats:
953         * <ul>
954         * <li>"EEE, dd MMM yyyy HH:mm:ss zzz"</li>
955         * <li>"EEE, dd-MMM-yy HH:mm:ss zzz"</li>
956         * <li>"EEE MMM dd HH:mm:ss yyyy"</li>
957         * </ul>
958         * @param name the header name
959         * @see <a href="https://tools.ietf.org/html/rfc7231#section-7.1.1.1">Section 7.1.1.1 of RFC 7231</a>
960         */
961        @Override
962        public long getDateHeader(String name) {
963                HeaderValueHolder header = HeaderValueHolder.getByName(this.headers, name);
964                Object value = (header != null ? header.getValue() : null);
965                if (value instanceof Date) {
966                        return ((Date) value).getTime();
967                }
968                else if (value instanceof Number) {
969                        return ((Number) value).longValue();
970                }
971                else if (value instanceof String) {
972                        return parseDateHeader(name, (String) value);
973                }
974                else if (value != null) {
975                        throw new IllegalArgumentException(
976                                        "Value for header '" + name + "' is not a Date, Number, or String: " + value);
977                }
978                else {
979                        return -1L;
980                }
981        }
982
983        private long parseDateHeader(String name, String value) {
984                for (String dateFormat : DATE_FORMATS) {
985                        SimpleDateFormat simpleDateFormat = new SimpleDateFormat(dateFormat, Locale.US);
986                        simpleDateFormat.setTimeZone(GMT);
987                        try {
988                                return simpleDateFormat.parse(value).getTime();
989                        }
990                        catch (ParseException ex) {
991                                // ignore
992                        }
993                }
994                throw new IllegalArgumentException("Cannot parse date value '" + value + "' for '" + name + "' header");
995        }
996
997        @Override
998        public String getHeader(String name) {
999                HeaderValueHolder header = HeaderValueHolder.getByName(this.headers, name);
1000                return (header != null ? header.getStringValue() : null);
1001        }
1002
1003        @Override
1004        public Enumeration<String> getHeaders(String name) {
1005                HeaderValueHolder header = HeaderValueHolder.getByName(this.headers, name);
1006                return Collections.enumeration(header != null ? header.getStringValues() : new LinkedList<String>());
1007        }
1008
1009        @Override
1010        public Enumeration<String> getHeaderNames() {
1011                return Collections.enumeration(this.headers.keySet());
1012        }
1013
1014        @Override
1015        public int getIntHeader(String name) {
1016                HeaderValueHolder header = HeaderValueHolder.getByName(this.headers, name);
1017                Object value = (header != null ? header.getValue() : null);
1018                if (value instanceof Number) {
1019                        return ((Number) value).intValue();
1020                }
1021                else if (value instanceof String) {
1022                        return Integer.parseInt((String) value);
1023                }
1024                else if (value != null) {
1025                        throw new NumberFormatException("Value for header '" + name + "' is not a Number: " + value);
1026                }
1027                else {
1028                        return -1;
1029                }
1030        }
1031
1032        public void setMethod(String method) {
1033                this.method = method;
1034        }
1035
1036        @Override
1037        public String getMethod() {
1038                return this.method;
1039        }
1040
1041        public void setPathInfo(String pathInfo) {
1042                this.pathInfo = pathInfo;
1043        }
1044
1045        @Override
1046        public String getPathInfo() {
1047                return this.pathInfo;
1048        }
1049
1050        @Override
1051        public String getPathTranslated() {
1052                return (this.pathInfo != null ? getRealPath(this.pathInfo) : null);
1053        }
1054
1055        public void setContextPath(String contextPath) {
1056                this.contextPath = contextPath;
1057        }
1058
1059        @Override
1060        public String getContextPath() {
1061                return this.contextPath;
1062        }
1063
1064        public void setQueryString(String queryString) {
1065                this.queryString = queryString;
1066        }
1067
1068        @Override
1069        public String getQueryString() {
1070                return this.queryString;
1071        }
1072
1073        public void setRemoteUser(String remoteUser) {
1074                this.remoteUser = remoteUser;
1075        }
1076
1077        @Override
1078        public String getRemoteUser() {
1079                return this.remoteUser;
1080        }
1081
1082        public void addUserRole(String role) {
1083                this.userRoles.add(role);
1084        }
1085
1086        @Override
1087        public boolean isUserInRole(String role) {
1088                return (this.userRoles.contains(role) || (this.servletContext instanceof MockServletContext &&
1089                                ((MockServletContext) this.servletContext).getDeclaredRoles().contains(role)));
1090        }
1091
1092        public void setUserPrincipal(Principal userPrincipal) {
1093                this.userPrincipal = userPrincipal;
1094        }
1095
1096        @Override
1097        public Principal getUserPrincipal() {
1098                return this.userPrincipal;
1099        }
1100
1101        public void setRequestedSessionId(String requestedSessionId) {
1102                this.requestedSessionId = requestedSessionId;
1103        }
1104
1105        @Override
1106        public String getRequestedSessionId() {
1107                return this.requestedSessionId;
1108        }
1109
1110        public void setRequestURI(String requestURI) {
1111                this.requestURI = requestURI;
1112        }
1113
1114        @Override
1115        public String getRequestURI() {
1116                return this.requestURI;
1117        }
1118
1119        @Override
1120        public StringBuffer getRequestURL() {
1121                String scheme = getScheme();
1122                String server = getServerName();
1123                int port = getServerPort();
1124                String uri = getRequestURI();
1125
1126                StringBuffer url = new StringBuffer(scheme).append("://").append(server);
1127                if (port > 0 && ((HTTP.equalsIgnoreCase(scheme) && port != 80) ||
1128                                (HTTPS.equalsIgnoreCase(scheme) && port != 443))) {
1129                        url.append(':').append(port);
1130                }
1131                if (StringUtils.hasText(uri)) {
1132                        url.append(uri);
1133                }
1134                return url;
1135        }
1136
1137        public void setServletPath(String servletPath) {
1138                this.servletPath = servletPath;
1139        }
1140
1141        @Override
1142        public String getServletPath() {
1143                return this.servletPath;
1144        }
1145
1146        public void setSession(HttpSession session) {
1147                this.session = session;
1148                if (session instanceof MockHttpSession) {
1149                        MockHttpSession mockSession = ((MockHttpSession) session);
1150                        mockSession.access();
1151                }
1152        }
1153
1154        @Override
1155        public HttpSession getSession(boolean create) {
1156                checkActive();
1157                // Reset session if invalidated.
1158                if (this.session instanceof MockHttpSession && ((MockHttpSession) this.session).isInvalid()) {
1159                        this.session = null;
1160                }
1161                // Create new session if necessary.
1162                if (this.session == null && create) {
1163                        this.session = new MockHttpSession(this.servletContext);
1164                }
1165                return this.session;
1166        }
1167
1168        @Override
1169        public HttpSession getSession() {
1170                return getSession(true);
1171        }
1172
1173        /**
1174         * The implementation of this (Servlet 3.1+) method calls
1175         * {@link MockHttpSession#changeSessionId()} if the session is a mock session.
1176         * Otherwise it simply returns the current session id.
1177         * @since 4.0.3
1178         */
1179        public String changeSessionId() {
1180                Assert.isTrue(this.session != null, "The request does not have a session");
1181                if (this.session instanceof MockHttpSession) {
1182                        return ((MockHttpSession) this.session).changeSessionId();
1183                }
1184                return this.session.getId();
1185        }
1186
1187        public void setRequestedSessionIdValid(boolean requestedSessionIdValid) {
1188                this.requestedSessionIdValid = requestedSessionIdValid;
1189        }
1190
1191        @Override
1192        public boolean isRequestedSessionIdValid() {
1193                return this.requestedSessionIdValid;
1194        }
1195
1196        public void setRequestedSessionIdFromCookie(boolean requestedSessionIdFromCookie) {
1197                this.requestedSessionIdFromCookie = requestedSessionIdFromCookie;
1198        }
1199
1200        @Override
1201        public boolean isRequestedSessionIdFromCookie() {
1202                return this.requestedSessionIdFromCookie;
1203        }
1204
1205        public void setRequestedSessionIdFromURL(boolean requestedSessionIdFromURL) {
1206                this.requestedSessionIdFromURL = requestedSessionIdFromURL;
1207        }
1208
1209        @Override
1210        public boolean isRequestedSessionIdFromURL() {
1211                return this.requestedSessionIdFromURL;
1212        }
1213
1214        @Override
1215        @Deprecated
1216        public boolean isRequestedSessionIdFromUrl() {
1217                return isRequestedSessionIdFromURL();
1218        }
1219
1220        @Override
1221        public boolean authenticate(HttpServletResponse response) throws IOException, ServletException {
1222                throw new UnsupportedOperationException();
1223        }
1224
1225        @Override
1226        public void login(String username, String password) throws ServletException {
1227                throw new UnsupportedOperationException();
1228        }
1229
1230        @Override
1231        public void logout() throws ServletException {
1232                this.userPrincipal = null;
1233                this.remoteUser = null;
1234                this.authType = null;
1235        }
1236
1237        public void addPart(Part part) {
1238                this.parts.add(part.getName(), part);
1239        }
1240
1241        @Override
1242        public Part getPart(String name) throws IOException, ServletException {
1243                return this.parts.getFirst(name);
1244        }
1245
1246        @Override
1247        public Collection<Part> getParts() throws IOException, ServletException {
1248                List<Part> result = new LinkedList<Part>();
1249                for (List<Part> list : this.parts.values()) {
1250                        result.addAll(list);
1251                }
1252                return result;
1253        }
1254
1255}