001/*
002 * Copyright 2002-2018 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.http.server;
018
019import java.io.IOException;
020import java.io.OutputStream;
021import java.util.ArrayList;
022import java.util.Collection;
023import java.util.List;
024
025import javax.servlet.http.HttpServletResponse;
026
027import org.springframework.http.HttpHeaders;
028import org.springframework.http.HttpStatus;
029import org.springframework.lang.Nullable;
030import org.springframework.util.Assert;
031import org.springframework.util.CollectionUtils;
032
033/**
034 * {@link ServerHttpResponse} implementation that is based on a {@link HttpServletResponse}.
035 *
036 * @author Arjen Poutsma
037 * @author Rossen Stoyanchev
038 * @since 3.0
039 */
040public class ServletServerHttpResponse implements ServerHttpResponse {
041
042        private final HttpServletResponse servletResponse;
043
044        private final HttpHeaders headers;
045
046        private boolean headersWritten = false;
047
048        private boolean bodyUsed = false;
049
050
051        /**
052         * Construct a new instance of the ServletServerHttpResponse based on the given {@link HttpServletResponse}.
053         * @param servletResponse the servlet response
054         */
055        public ServletServerHttpResponse(HttpServletResponse servletResponse) {
056                Assert.notNull(servletResponse, "HttpServletResponse must not be null");
057                this.servletResponse = servletResponse;
058                this.headers = new ServletResponseHttpHeaders();
059        }
060
061
062        /**
063         * Return the {@code HttpServletResponse} this object is based on.
064         */
065        public HttpServletResponse getServletResponse() {
066                return this.servletResponse;
067        }
068
069        @Override
070        public void setStatusCode(HttpStatus status) {
071                Assert.notNull(status, "HttpStatus must not be null");
072                this.servletResponse.setStatus(status.value());
073        }
074
075        @Override
076        public HttpHeaders getHeaders() {
077                return (this.headersWritten ? HttpHeaders.readOnlyHttpHeaders(this.headers) : this.headers);
078        }
079
080        @Override
081        public OutputStream getBody() throws IOException {
082                this.bodyUsed = true;
083                writeHeaders();
084                return this.servletResponse.getOutputStream();
085        }
086
087        @Override
088        public void flush() throws IOException {
089                writeHeaders();
090                if (this.bodyUsed) {
091                        this.servletResponse.flushBuffer();
092                }
093        }
094
095        @Override
096        public void close() {
097                writeHeaders();
098        }
099
100        private void writeHeaders() {
101                if (!this.headersWritten) {
102                        getHeaders().forEach((headerName, headerValues) -> {
103                                for (String headerValue : headerValues) {
104                                        this.servletResponse.addHeader(headerName, headerValue);
105                                }
106                        });
107                        // HttpServletResponse exposes some headers as properties: we should include those if not already present
108                        if (this.servletResponse.getContentType() == null && this.headers.getContentType() != null) {
109                                this.servletResponse.setContentType(this.headers.getContentType().toString());
110                        }
111                        if (this.servletResponse.getCharacterEncoding() == null && this.headers.getContentType() != null &&
112                                        this.headers.getContentType().getCharset() != null) {
113                                this.servletResponse.setCharacterEncoding(this.headers.getContentType().getCharset().name());
114                        }
115                        this.headersWritten = true;
116                }
117        }
118
119
120        /**
121         * Extends HttpHeaders with the ability to look up headers already present in
122         * the underlying HttpServletResponse.
123         *
124         * <p>The intent is merely to expose what is available through the HttpServletResponse
125         * i.e. the ability to look up specific header values by name. All other
126         * map-related operations (e.g. iteration, removal, etc) apply only to values
127         * added directly through HttpHeaders methods.
128         *
129         * @since 4.0.3
130         */
131        private class ServletResponseHttpHeaders extends HttpHeaders {
132
133                private static final long serialVersionUID = 3410708522401046302L;
134
135                @Override
136                public boolean containsKey(Object key) {
137                        return (super.containsKey(key) || (get(key) != null));
138                }
139
140                @Override
141                @Nullable
142                public String getFirst(String headerName) {
143                        String value = servletResponse.getHeader(headerName);
144                        if (value != null) {
145                                return value;
146                        }
147                        else {
148                                return super.getFirst(headerName);
149                        }
150                }
151
152                @Override
153                public List<String> get(Object key) {
154                        Assert.isInstanceOf(String.class, key, "Key must be a String-based header name");
155
156                        Collection<String> values1 = servletResponse.getHeaders((String) key);
157                        if (headersWritten) {
158                                return new ArrayList<>(values1);
159                        }
160                        boolean isEmpty1 = CollectionUtils.isEmpty(values1);
161
162                        List<String> values2 = super.get(key);
163                        boolean isEmpty2 = CollectionUtils.isEmpty(values2);
164
165                        if (isEmpty1 && isEmpty2) {
166                                return null;
167                        }
168
169                        List<String> values = new ArrayList<>();
170                        if (!isEmpty1) {
171                                values.addAll(values1);
172                        }
173                        if (!isEmpty2) {
174                                values.addAll(values2);
175                        }
176                        return values;
177                }
178        }
179
180}