001/*
002 * Copyright 2002-2020 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.mock.web;
018
019import java.io.IOException;
020import java.util.Collections;
021import java.util.Enumeration;
022import java.util.Iterator;
023import java.util.List;
024import java.util.Map;
025
026import javax.servlet.ServletContext;
027import javax.servlet.ServletException;
028import javax.servlet.http.Part;
029
030import org.springframework.http.HttpHeaders;
031import org.springframework.http.HttpMethod;
032import org.springframework.lang.Nullable;
033import org.springframework.util.Assert;
034import org.springframework.util.LinkedMultiValueMap;
035import org.springframework.util.MultiValueMap;
036import org.springframework.web.multipart.MultipartFile;
037import org.springframework.web.multipart.MultipartHttpServletRequest;
038
039/**
040 * Mock implementation of the
041 * {@link org.springframework.web.multipart.MultipartHttpServletRequest} interface.
042 *
043 * <p>As of Spring 5.0, this set of mocks is designed on a Servlet 4.0 baseline.
044 *
045 * <p>Useful for testing application controllers that access multipart uploads.
046 * {@link MockMultipartFile} can be used to populate these mock requests with files.
047 *
048 * @author Juergen Hoeller
049 * @author Eric Crampton
050 * @author Arjen Poutsma
051 * @since 2.0
052 * @see MockMultipartFile
053 */
054public class MockMultipartHttpServletRequest extends MockHttpServletRequest implements MultipartHttpServletRequest {
055
056        private final MultiValueMap<String, MultipartFile> multipartFiles = new LinkedMultiValueMap<>();
057
058
059        /**
060         * Create a new {@code MockMultipartHttpServletRequest} with a default
061         * {@link MockServletContext}.
062         * @see #MockMultipartHttpServletRequest(ServletContext)
063         */
064        public MockMultipartHttpServletRequest() {
065                this(null);
066        }
067
068        /**
069         * Create a new {@code MockMultipartHttpServletRequest} with the supplied {@link ServletContext}.
070         * @param servletContext the ServletContext that the request runs in
071         * (may be {@code null} to use a default {@link MockServletContext})
072         */
073        public MockMultipartHttpServletRequest(@Nullable ServletContext servletContext) {
074                super(servletContext);
075                setMethod("POST");
076                setContentType("multipart/form-data");
077        }
078
079
080        /**
081         * Add a file to this request. The parameter name from the multipart
082         * form is taken from the {@link MultipartFile#getName()}.
083         * @param file multipart file to be added
084         */
085        public void addFile(MultipartFile file) {
086                Assert.notNull(file, "MultipartFile must not be null");
087                this.multipartFiles.add(file.getName(), file);
088        }
089
090        @Override
091        public Iterator<String> getFileNames() {
092                return this.multipartFiles.keySet().iterator();
093        }
094
095        @Override
096        public MultipartFile getFile(String name) {
097                return this.multipartFiles.getFirst(name);
098        }
099
100        @Override
101        public List<MultipartFile> getFiles(String name) {
102                List<MultipartFile> multipartFiles = this.multipartFiles.get(name);
103                if (multipartFiles != null) {
104                        return multipartFiles;
105                }
106                else {
107                        return Collections.emptyList();
108                }
109        }
110
111        @Override
112        public Map<String, MultipartFile> getFileMap() {
113                return this.multipartFiles.toSingleValueMap();
114        }
115
116        @Override
117        public MultiValueMap<String, MultipartFile> getMultiFileMap() {
118                return new LinkedMultiValueMap<>(this.multipartFiles);
119        }
120
121        @Override
122        public String getMultipartContentType(String paramOrFileName) {
123                MultipartFile file = getFile(paramOrFileName);
124                if (file != null) {
125                        return file.getContentType();
126                }
127                try {
128                        Part part = getPart(paramOrFileName);
129                        if (part != null) {
130                                return part.getContentType();
131                        }
132                }
133                catch (ServletException | IOException ex) {
134                        // Should never happen (we're not actually parsing)
135                        throw new IllegalStateException(ex);
136                }
137                return null;
138        }
139
140        @Override
141        public HttpMethod getRequestMethod() {
142                return HttpMethod.resolve(getMethod());
143        }
144
145        @Override
146        public HttpHeaders getRequestHeaders() {
147                HttpHeaders headers = new HttpHeaders();
148                Enumeration<String> headerNames = getHeaderNames();
149                while (headerNames.hasMoreElements()) {
150                        String headerName = headerNames.nextElement();
151                        headers.put(headerName, Collections.list(getHeaders(headerName)));
152                }
153                return headers;
154        }
155
156        @Override
157        public HttpHeaders getMultipartHeaders(String paramOrFileName) {
158                String contentType = getMultipartContentType(paramOrFileName);
159                if (contentType != null) {
160                        HttpHeaders headers = new HttpHeaders();
161                        headers.add(HttpHeaders.CONTENT_TYPE, contentType);
162                        return headers;
163                }
164                else {
165                        return null;
166                }
167        }
168
169}