001/*
002 * Copyright 2002-2020 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.mock.web;
018
019import java.io.ByteArrayOutputStream;
020import java.io.IOException;
021import java.io.OutputStream;
022import java.io.OutputStreamWriter;
023import java.io.PrintWriter;
024import java.io.UnsupportedEncodingException;
025import java.io.Writer;
026import java.text.DateFormat;
027import java.text.ParseException;
028import java.text.SimpleDateFormat;
029import java.util.ArrayList;
030import java.util.Collection;
031import java.util.Collections;
032import java.util.Date;
033import java.util.List;
034import java.util.Locale;
035import java.util.Map;
036import java.util.TimeZone;
037import javax.servlet.ServletOutputStream;
038import javax.servlet.http.Cookie;
039import javax.servlet.http.HttpServletResponse;
040
041import org.springframework.http.MediaType;
042import org.springframework.util.Assert;
043import org.springframework.util.LinkedCaseInsensitiveMap;
044import org.springframework.web.util.WebUtils;
045
046/**
047 * Mock implementation of the {@link javax.servlet.http.HttpServletResponse} interface.
048 *
049 * <p>As of Spring 4.0, this set of mocks is designed on a Servlet 3.0 baseline.
050 * Beyond that, {@code MockHttpServletResponse} is also compatible with Servlet
051 * 3.1's {@code setContentLengthLong()} method.
052 *
053 * @author Juergen Hoeller
054 * @author Rod Johnson
055 * @author Brian Clozel
056 * @since 1.0.2
057 */
058public class MockHttpServletResponse implements HttpServletResponse {
059
060        private static final String CHARSET_PREFIX = "charset=";
061
062        private static final String CONTENT_TYPE_HEADER = "Content-Type";
063
064        private static final String CONTENT_LENGTH_HEADER = "Content-Length";
065
066        private static final String LOCATION_HEADER = "Location";
067
068        private static final String DATE_FORMAT = "EEE, dd MMM yyyy HH:mm:ss zzz";
069
070        private static final TimeZone GMT = TimeZone.getTimeZone("GMT");
071
072
073        //---------------------------------------------------------------------
074        // ServletResponse properties
075        //---------------------------------------------------------------------
076
077        private boolean outputStreamAccessAllowed = true;
078
079        private boolean writerAccessAllowed = true;
080
081        private String characterEncoding = WebUtils.DEFAULT_CHARACTER_ENCODING;
082
083        private boolean charset = false;
084
085        private final ByteArrayOutputStream content = new ByteArrayOutputStream(1024);
086
087        private final ServletOutputStream outputStream = new ResponseServletOutputStream(this.content);
088
089        private PrintWriter writer;
090
091        private long contentLength = 0;
092
093        private String contentType;
094
095        private int bufferSize = 4096;
096
097        private boolean committed;
098
099        private Locale locale = Locale.getDefault();
100
101
102        //---------------------------------------------------------------------
103        // HttpServletResponse properties
104        //---------------------------------------------------------------------
105
106        private final List<Cookie> cookies = new ArrayList<Cookie>();
107
108        private final Map<String, HeaderValueHolder> headers = new LinkedCaseInsensitiveMap<HeaderValueHolder>();
109
110        private int status = HttpServletResponse.SC_OK;
111
112        private String errorMessage;
113
114        private String forwardedUrl;
115
116        private final List<String> includedUrls = new ArrayList<String>();
117
118
119        //---------------------------------------------------------------------
120        // ServletResponse interface
121        //---------------------------------------------------------------------
122
123        /**
124         * Set whether {@link #getOutputStream()} access is allowed.
125         * <p>Default is {@code true}.
126         */
127        public void setOutputStreamAccessAllowed(boolean outputStreamAccessAllowed) {
128                this.outputStreamAccessAllowed = outputStreamAccessAllowed;
129        }
130
131        /**
132         * Return whether {@link #getOutputStream()} access is allowed.
133         */
134        public boolean isOutputStreamAccessAllowed() {
135                return this.outputStreamAccessAllowed;
136        }
137
138        /**
139         * Set whether {@link #getWriter()} access is allowed.
140         * <p>Default is {@code true}.
141         */
142        public void setWriterAccessAllowed(boolean writerAccessAllowed) {
143                this.writerAccessAllowed = writerAccessAllowed;
144        }
145
146        /**
147         * Return whether {@link #getOutputStream()} access is allowed.
148         */
149        public boolean isWriterAccessAllowed() {
150                return this.writerAccessAllowed;
151        }
152
153        /**
154         * Return whether the character encoding has been set.
155         * <p>If {@code false}, {@link #getCharacterEncoding()} will return a default encoding value.
156         */
157        public boolean isCharset() {
158                return this.charset;
159        }
160
161        @Override
162        public void setCharacterEncoding(String characterEncoding) {
163                this.characterEncoding = characterEncoding;
164                this.charset = true;
165                updateContentTypeHeader();
166        }
167
168        private void updateContentTypeHeader() {
169                if (this.contentType != null) {
170                        String value = this.contentType;
171                        if (this.charset && !this.contentType.toLowerCase().contains(CHARSET_PREFIX)) {
172                                value = value + ';' + CHARSET_PREFIX + this.characterEncoding;
173                        }
174                        doAddHeaderValue(CONTENT_TYPE_HEADER, value, true);
175                }
176        }
177
178        @Override
179        public String getCharacterEncoding() {
180                return this.characterEncoding;
181        }
182
183        @Override
184        public ServletOutputStream getOutputStream() {
185                if (!this.outputStreamAccessAllowed) {
186                        throw new IllegalStateException("OutputStream access not allowed");
187                }
188                return this.outputStream;
189        }
190
191        @Override
192        public PrintWriter getWriter() throws UnsupportedEncodingException {
193                if (!this.writerAccessAllowed) {
194                        throw new IllegalStateException("Writer access not allowed");
195                }
196                if (this.writer == null) {
197                        Writer targetWriter = (this.characterEncoding != null ?
198                                        new OutputStreamWriter(this.content, this.characterEncoding) :
199                                        new OutputStreamWriter(this.content));
200                        this.writer = new ResponsePrintWriter(targetWriter);
201                }
202                return this.writer;
203        }
204
205        public byte[] getContentAsByteArray() {
206                flushBuffer();
207                return this.content.toByteArray();
208        }
209
210        public String getContentAsString() throws UnsupportedEncodingException {
211                flushBuffer();
212                return (this.characterEncoding != null ?
213                                this.content.toString(this.characterEncoding) : this.content.toString());
214        }
215
216        @Override
217        public void setContentLength(int contentLength) {
218                this.contentLength = contentLength;
219                doAddHeaderValue(CONTENT_LENGTH_HEADER, contentLength, true);
220        }
221
222        public int getContentLength() {
223                return (int) this.contentLength;
224        }
225
226        public void setContentLengthLong(long contentLength) {
227                this.contentLength = contentLength;
228                doAddHeaderValue(CONTENT_LENGTH_HEADER, contentLength, true);
229        }
230
231        public long getContentLengthLong() {
232                return this.contentLength;
233        }
234
235        @Override
236        public void setContentType(String contentType) {
237                this.contentType = contentType;
238                if (contentType != null) {
239                        try {
240                                MediaType mediaType = MediaType.parseMediaType(contentType);
241                                if (mediaType.getCharset() != null) {
242                                        this.characterEncoding = mediaType.getCharset().name();
243                                        this.charset = true;
244                                }
245                        }
246                        catch (Exception ex) {
247                                // Try to get charset value anyway
248                                int charsetIndex = contentType.toLowerCase().indexOf(CHARSET_PREFIX);
249                                if (charsetIndex != -1) {
250                                        this.characterEncoding = contentType.substring(charsetIndex + CHARSET_PREFIX.length());
251                                        this.charset = true;
252                                }
253                        }
254                        updateContentTypeHeader();
255                }
256        }
257
258        @Override
259        public String getContentType() {
260                return this.contentType;
261        }
262
263        @Override
264        public void setBufferSize(int bufferSize) {
265                this.bufferSize = bufferSize;
266        }
267
268        @Override
269        public int getBufferSize() {
270                return this.bufferSize;
271        }
272
273        @Override
274        public void flushBuffer() {
275                setCommitted(true);
276        }
277
278        @Override
279        public void resetBuffer() {
280                if (isCommitted()) {
281                        throw new IllegalStateException("Cannot reset buffer - response is already committed");
282                }
283                this.content.reset();
284        }
285
286        private void setCommittedIfBufferSizeExceeded() {
287                int bufSize = getBufferSize();
288                if (bufSize > 0 && this.content.size() > bufSize) {
289                        setCommitted(true);
290                }
291        }
292
293        public void setCommitted(boolean committed) {
294                this.committed = committed;
295        }
296
297        @Override
298        public boolean isCommitted() {
299                return this.committed;
300        }
301
302        @Override
303        public void reset() {
304                resetBuffer();
305                this.characterEncoding = null;
306                this.charset = false;
307                this.contentLength = 0;
308                this.contentType = null;
309                this.locale = null;
310                this.cookies.clear();
311                this.headers.clear();
312                this.status = HttpServletResponse.SC_OK;
313                this.errorMessage = null;
314        }
315
316        @Override
317        public void setLocale(Locale locale) {
318                this.locale = locale;
319        }
320
321        @Override
322        public Locale getLocale() {
323                return this.locale;
324        }
325
326
327        //---------------------------------------------------------------------
328        // HttpServletResponse interface
329        //---------------------------------------------------------------------
330
331        @Override
332        public void addCookie(Cookie cookie) {
333                Assert.notNull(cookie, "Cookie must not be null");
334                this.cookies.add(cookie);
335        }
336
337        public Cookie[] getCookies() {
338                return this.cookies.toArray(new Cookie[this.cookies.size()]);
339        }
340
341        public Cookie getCookie(String name) {
342                Assert.notNull(name, "Cookie name must not be null");
343                for (Cookie cookie : this.cookies) {
344                        if (name.equals(cookie.getName())) {
345                                return cookie;
346                        }
347                }
348                return null;
349        }
350
351        @Override
352        public boolean containsHeader(String name) {
353                return (HeaderValueHolder.getByName(this.headers, name) != null);
354        }
355
356        /**
357         * Return the names of all specified headers as a Set of Strings.
358         * <p>As of Servlet 3.0, this method is also defined in {@link HttpServletResponse}.
359         * @return the {@code Set} of header name {@code Strings}, or an empty {@code Set} if none
360         */
361        @Override
362        public Collection<String> getHeaderNames() {
363                return this.headers.keySet();
364        }
365
366        /**
367         * Return the primary value for the given header as a String, if any.
368         * Will return the first value in case of multiple values.
369         * <p>As of Servlet 3.0, this method is also defined in {@link HttpServletResponse}.
370         * As of Spring 3.1, it returns a stringified value for Servlet 3.0 compatibility.
371         * Consider using {@link #getHeaderValue(String)} for raw Object access.
372         * @param name the name of the header
373         * @return the associated header value, or {@code null} if none
374         */
375        @Override
376        public String getHeader(String name) {
377                HeaderValueHolder header = HeaderValueHolder.getByName(this.headers, name);
378                return (header != null ? header.getStringValue() : null);
379        }
380
381        /**
382         * Return all values for the given header as a List of Strings.
383         * <p>As of Servlet 3.0, this method is also defined in {@link HttpServletResponse}.
384         * As of Spring 3.1, it returns a List of stringified values for Servlet 3.0 compatibility.
385         * Consider using {@link #getHeaderValues(String)} for raw Object access.
386         * @param name the name of the header
387         * @return the associated header values, or an empty List if none
388         */
389        @Override
390        public List<String> getHeaders(String name) {
391                HeaderValueHolder header = HeaderValueHolder.getByName(this.headers, name);
392                if (header != null) {
393                        return header.getStringValues();
394                }
395                else {
396                        return Collections.emptyList();
397                }
398        }
399
400        /**
401         * Return the primary value for the given header, if any.
402         * <p>Will return the first value in case of multiple values.
403         * @param name the name of the header
404         * @return the associated header value, or {@code null} if none
405         */
406        public Object getHeaderValue(String name) {
407                HeaderValueHolder header = HeaderValueHolder.getByName(this.headers, name);
408                return (header != null ? header.getValue() : null);
409        }
410
411        /**
412         * Return all values for the given header as a List of value objects.
413         * @param name the name of the header
414         * @return the associated header values, or an empty List if none
415         */
416        public List<Object> getHeaderValues(String name) {
417                HeaderValueHolder header = HeaderValueHolder.getByName(this.headers, name);
418                if (header != null) {
419                        return header.getValues();
420                }
421                else {
422                        return Collections.emptyList();
423                }
424        }
425
426        /**
427         * The default implementation returns the given URL String as-is.
428         * <p>Can be overridden in subclasses, appending a session id or the like.
429         */
430        @Override
431        public String encodeURL(String url) {
432                return url;
433        }
434
435        /**
436         * The default implementation delegates to {@link #encodeURL},
437         * returning the given URL String as-is.
438         * <p>Can be overridden in subclasses, appending a session id or the like
439         * in a redirect-specific fashion. For general URL encoding rules,
440         * override the common {@link #encodeURL} method instead, applying
441         * to redirect URLs as well as to general URLs.
442         */
443        @Override
444        public String encodeRedirectURL(String url) {
445                return encodeURL(url);
446        }
447
448        @Override
449        @Deprecated
450        public String encodeUrl(String url) {
451                return encodeURL(url);
452        }
453
454        @Override
455        @Deprecated
456        public String encodeRedirectUrl(String url) {
457                return encodeRedirectURL(url);
458        }
459
460        @Override
461        public void sendError(int status, String errorMessage) throws IOException {
462                if (isCommitted()) {
463                        throw new IllegalStateException("Cannot set error status - response is already committed");
464                }
465                this.status = status;
466                this.errorMessage = errorMessage;
467                setCommitted(true);
468        }
469
470        @Override
471        public void sendError(int status) throws IOException {
472                if (isCommitted()) {
473                        throw new IllegalStateException("Cannot set error status - response is already committed");
474                }
475                this.status = status;
476                setCommitted(true);
477        }
478
479        @Override
480        public void sendRedirect(String url) throws IOException {
481                if (isCommitted()) {
482                        throw new IllegalStateException("Cannot send redirect - response is already committed");
483                }
484                Assert.notNull(url, "Redirect URL must not be null");
485                setHeader(LOCATION_HEADER, url);
486                setStatus(HttpServletResponse.SC_MOVED_TEMPORARILY);
487                setCommitted(true);
488        }
489
490        public String getRedirectedUrl() {
491                return getHeader(LOCATION_HEADER);
492        }
493
494        @Override
495        public void setDateHeader(String name, long value) {
496                setHeaderValue(name, formatDate(value));
497        }
498
499        @Override
500        public void addDateHeader(String name, long value) {
501                addHeaderValue(name, formatDate(value));
502        }
503
504        public long getDateHeader(String name) {
505                String headerValue = getHeader(name);
506                if (headerValue == null) {
507                        return -1;
508                }
509                try {
510                        return newDateFormat().parse(getHeader(name)).getTime();
511                }
512                catch (ParseException ex) {
513                        throw new IllegalArgumentException(
514                                        "Value for header '" + name + "' is not a valid Date: " + headerValue);
515                }
516        }
517
518        private String formatDate(long date) {
519                return newDateFormat().format(new Date(date));
520        }
521
522        private DateFormat newDateFormat() {
523                SimpleDateFormat dateFormat = new SimpleDateFormat(DATE_FORMAT, Locale.US);
524                dateFormat.setTimeZone(GMT);
525                return dateFormat;
526        }
527
528        @Override
529        public void setHeader(String name, String value) {
530                setHeaderValue(name, value);
531        }
532
533        @Override
534        public void addHeader(String name, String value) {
535                addHeaderValue(name, value);
536        }
537
538        @Override
539        public void setIntHeader(String name, int value) {
540                setHeaderValue(name, value);
541        }
542
543        @Override
544        public void addIntHeader(String name, int value) {
545                addHeaderValue(name, value);
546        }
547
548        private void setHeaderValue(String name, Object value) {
549                if (setSpecialHeader(name, value)) {
550                        return;
551                }
552                doAddHeaderValue(name, value, true);
553        }
554
555        private void addHeaderValue(String name, Object value) {
556                if (setSpecialHeader(name, value)) {
557                        return;
558                }
559                doAddHeaderValue(name, value, false);
560        }
561
562        private boolean setSpecialHeader(String name, Object value) {
563                if (CONTENT_TYPE_HEADER.equalsIgnoreCase(name)) {
564                        setContentType(value.toString());
565                        return true;
566                }
567                else if (CONTENT_LENGTH_HEADER.equalsIgnoreCase(name)) {
568                        setContentLength(value instanceof Number ? ((Number) value).intValue() :
569                                        Integer.parseInt(value.toString()));
570                        return true;
571                }
572                else {
573                        return false;
574                }
575        }
576
577        private void doAddHeaderValue(String name, Object value, boolean replace) {
578                HeaderValueHolder header = HeaderValueHolder.getByName(this.headers, name);
579                Assert.notNull(value, "Header value must not be null");
580                if (header == null) {
581                        header = new HeaderValueHolder();
582                        this.headers.put(name, header);
583                }
584                if (replace) {
585                        header.setValue(value);
586                }
587                else {
588                        header.addValue(value);
589                }
590        }
591
592        @Override
593        public void setStatus(int status) {
594                if (!this.isCommitted()) {
595                        this.status = status;
596                }
597        }
598
599        @Override
600        @Deprecated
601        public void setStatus(int status, String errorMessage) {
602                if (!this.isCommitted()) {
603                        this.status = status;
604                        this.errorMessage = errorMessage;
605                }
606        }
607
608        @Override
609        public int getStatus() {
610                return this.status;
611        }
612
613        public String getErrorMessage() {
614                return this.errorMessage;
615        }
616
617
618        //---------------------------------------------------------------------
619        // Methods for MockRequestDispatcher
620        //---------------------------------------------------------------------
621
622        public void setForwardedUrl(String forwardedUrl) {
623                this.forwardedUrl = forwardedUrl;
624        }
625
626        public String getForwardedUrl() {
627                return this.forwardedUrl;
628        }
629
630        public void setIncludedUrl(String includedUrl) {
631                this.includedUrls.clear();
632                if (includedUrl != null) {
633                        this.includedUrls.add(includedUrl);
634                }
635        }
636
637        public String getIncludedUrl() {
638                int count = this.includedUrls.size();
639                if (count > 1) {
640                        throw new IllegalStateException(
641                                        "More than 1 URL included - check getIncludedUrls instead: " + this.includedUrls);
642                }
643                return (count == 1 ? this.includedUrls.get(0) : null);
644        }
645
646        public void addIncludedUrl(String includedUrl) {
647                Assert.notNull(includedUrl, "Included URL must not be null");
648                this.includedUrls.add(includedUrl);
649        }
650
651        public List<String> getIncludedUrls() {
652                return this.includedUrls;
653        }
654
655
656        /**
657         * Inner class that adapts the ServletOutputStream to mark the
658         * response as committed once the buffer size is exceeded.
659         */
660        private class ResponseServletOutputStream extends DelegatingServletOutputStream {
661
662                public ResponseServletOutputStream(OutputStream out) {
663                        super(out);
664                }
665
666                @Override
667                public void write(int b) throws IOException {
668                        super.write(b);
669                        super.flush();
670                        setCommittedIfBufferSizeExceeded();
671                }
672
673                @Override
674                public void flush() throws IOException {
675                        super.flush();
676                        setCommitted(true);
677                }
678        }
679
680
681        /**
682         * Inner class that adapts the PrintWriter to mark the
683         * response as committed once the buffer size is exceeded.
684         */
685        private class ResponsePrintWriter extends PrintWriter {
686
687                public ResponsePrintWriter(Writer out) {
688                        super(out, true);
689                }
690
691                @Override
692                public void write(char[] buf, int off, int len) {
693                        super.write(buf, off, len);
694                        super.flush();
695                        setCommittedIfBufferSizeExceeded();
696                }
697
698                @Override
699                public void write(String s, int off, int len) {
700                        super.write(s, off, len);
701                        super.flush();
702                        setCommittedIfBufferSizeExceeded();
703                }
704
705                @Override
706                public void write(int c) {
707                        super.write(c);
708                        super.flush();
709                        setCommittedIfBufferSizeExceeded();
710                }
711
712                @Override
713                public void flush() {
714                        super.flush();
715                        setCommitted(true);
716                }
717        }
718
719}