001/*
002 * Copyright 2002-2016 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.http.server;
018
019import java.io.IOException;
020import java.io.OutputStream;
021import java.util.ArrayList;
022import java.util.Collection;
023import java.util.List;
024import java.util.Map;
025import javax.servlet.http.HttpServletResponse;
026
027import org.springframework.http.HttpHeaders;
028import org.springframework.http.HttpStatus;
029import org.springframework.util.Assert;
030import org.springframework.util.ClassUtils;
031import org.springframework.util.CollectionUtils;
032
033/**
034 * {@link ServerHttpResponse} implementation that is based on a {@link HttpServletResponse}.
035 *
036 * @author Arjen Poutsma
037 * @author Rossen Stoyanchev
038 * @since 3.0
039 */
040public class ServletServerHttpResponse implements ServerHttpResponse {
041
042        /** Checking for Servlet 3.0+ HttpServletResponse.getHeader(String) */
043        private static final boolean servlet3Present =
044                        ClassUtils.hasMethod(HttpServletResponse.class, "getHeader", String.class);
045
046
047        private final HttpServletResponse servletResponse;
048
049        private final HttpHeaders headers;
050
051        private boolean headersWritten = false;
052
053        private boolean bodyUsed = false;
054
055
056        /**
057         * Construct a new instance of the ServletServerHttpResponse based on the given {@link HttpServletResponse}.
058         * @param servletResponse the servlet response
059         */
060        public ServletServerHttpResponse(HttpServletResponse servletResponse) {
061                Assert.notNull(servletResponse, "HttpServletResponse must not be null");
062                this.servletResponse = servletResponse;
063                this.headers = (servlet3Present ? new ServletResponseHttpHeaders() : new HttpHeaders());
064        }
065
066
067        /**
068         * Return the {@code HttpServletResponse} this object is based on.
069         */
070        public HttpServletResponse getServletResponse() {
071                return this.servletResponse;
072        }
073
074        @Override
075        public void setStatusCode(HttpStatus status) {
076                Assert.notNull(status, "HttpStatus must not be null");
077                this.servletResponse.setStatus(status.value());
078        }
079
080        @Override
081        public HttpHeaders getHeaders() {
082                return (this.headersWritten ? HttpHeaders.readOnlyHttpHeaders(this.headers) : this.headers);
083        }
084
085        @Override
086        public OutputStream getBody() throws IOException {
087                this.bodyUsed = true;
088                writeHeaders();
089                return this.servletResponse.getOutputStream();
090        }
091
092        @Override
093        public void flush() throws IOException {
094                writeHeaders();
095                if (this.bodyUsed) {
096                        this.servletResponse.flushBuffer();
097                }
098        }
099
100        @Override
101        public void close() {
102                writeHeaders();
103        }
104
105        private void writeHeaders() {
106                if (!this.headersWritten) {
107                        for (Map.Entry<String, List<String>> entry : this.headers.entrySet()) {
108                                String headerName = entry.getKey();
109                                for (String headerValue : entry.getValue()) {
110                                        this.servletResponse.addHeader(headerName, headerValue);
111                                }
112                        }
113                        // HttpServletResponse exposes some headers as properties: we should include those if not already present
114                        if (this.servletResponse.getContentType() == null && this.headers.getContentType() != null) {
115                                this.servletResponse.setContentType(this.headers.getContentType().toString());
116                        }
117                        if (this.servletResponse.getCharacterEncoding() == null && this.headers.getContentType() != null &&
118                                        this.headers.getContentType().getCharset() != null) {
119                                this.servletResponse.setCharacterEncoding(this.headers.getContentType().getCharset().name());
120                        }
121                        this.headersWritten = true;
122                }
123        }
124
125
126        /**
127         * Extends HttpHeaders with the ability to look up headers already present in
128         * the underlying HttpServletResponse.
129         *
130         * <p>The intent is merely to expose what is available through the HttpServletResponse
131         * i.e. the ability to look up specific header values by name. All other
132         * map-related operations (e.g. iteration, removal, etc) apply only to values
133         * added directly through HttpHeaders methods.
134         *
135         * @since 4.0.3
136         */
137        private class ServletResponseHttpHeaders extends HttpHeaders {
138
139                private static final long serialVersionUID = 3410708522401046302L;
140
141                @Override
142                public boolean containsKey(Object key) {
143                        return (super.containsKey(key) || (get(key) != null));
144                }
145
146                @Override
147                public String getFirst(String headerName) {
148                        String value = servletResponse.getHeader(headerName);
149                        if (value != null) {
150                                return value;
151                        }
152                        else {
153                                return super.getFirst(headerName);
154                        }
155                }
156
157                @Override
158                public List<String> get(Object key) {
159                        Assert.isInstanceOf(String.class, key, "Key must be a String-based header name");
160
161                        Collection<String> values1 = servletResponse.getHeaders((String) key);
162                        boolean isEmpty1 = CollectionUtils.isEmpty(values1);
163
164                        List<String> values2 = super.get(key);
165                        boolean isEmpty2 = CollectionUtils.isEmpty(values2);
166
167                        if (isEmpty1 && isEmpty2) {
168                                return null;
169                        }
170
171                        List<String> values = new ArrayList<String>();
172                        if (!isEmpty1) {
173                                values.addAll(values1);
174                        }
175                        if (!isEmpty2) {
176                                values.addAll(values2);
177                        }
178                        return values;
179                }
180        }
181
182}