001/*
002 * Copyright 2002-2020 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.http.codec.multipart;
018
019import java.nio.charset.Charset;
020import java.nio.charset.StandardCharsets;
021import java.util.ArrayList;
022import java.util.Arrays;
023import java.util.Collections;
024import java.util.HashMap;
025import java.util.List;
026import java.util.Map;
027import java.util.Optional;
028import java.util.concurrent.atomic.AtomicBoolean;
029import java.util.function.Supplier;
030
031import org.reactivestreams.Publisher;
032import reactor.core.publisher.Flux;
033import reactor.core.publisher.Mono;
034
035import org.springframework.core.ResolvableType;
036import org.springframework.core.ResolvableTypeProvider;
037import org.springframework.core.codec.CharSequenceEncoder;
038import org.springframework.core.codec.CodecException;
039import org.springframework.core.codec.Hints;
040import org.springframework.core.io.Resource;
041import org.springframework.core.io.buffer.DataBuffer;
042import org.springframework.core.io.buffer.DataBufferFactory;
043import org.springframework.core.io.buffer.DataBufferUtils;
044import org.springframework.core.io.buffer.PooledDataBuffer;
045import org.springframework.core.log.LogFormatUtils;
046import org.springframework.http.HttpEntity;
047import org.springframework.http.HttpHeaders;
048import org.springframework.http.MediaType;
049import org.springframework.http.ReactiveHttpOutputMessage;
050import org.springframework.http.codec.EncoderHttpMessageWriter;
051import org.springframework.http.codec.FormHttpMessageWriter;
052import org.springframework.http.codec.HttpMessageWriter;
053import org.springframework.http.codec.LoggingCodecSupport;
054import org.springframework.http.codec.ResourceHttpMessageWriter;
055import org.springframework.lang.Nullable;
056import org.springframework.util.Assert;
057import org.springframework.util.MimeTypeUtils;
058import org.springframework.util.MultiValueMap;
059
060/**
061 * {@link HttpMessageWriter} for writing a {@code MultiValueMap<String, ?>}
062 * as multipart form data, i.e. {@code "multipart/form-data"}, to the body
063 * of a request.
064 *
065 * <p>The serialization of individual parts is delegated to other writers.
066 * By default only {@link String} and {@link Resource} parts are supported but
067 * you can configure others through a constructor argument.
068 *
069 * <p>This writer can be configured with a {@link FormHttpMessageWriter} to
070 * delegate to. It is the preferred way of supporting both form data and
071 * multipart data (as opposed to registering each writer separately) so that
072 * when the {@link MediaType} is not specified and generics are not present on
073 * the target element type, we can inspect the values in the actual map and
074 * decide whether to write plain form data (String values only) or otherwise.
075 *
076 * @author Sebastien Deleuze
077 * @author Rossen Stoyanchev
078 * @since 5.0
079 * @see FormHttpMessageWriter
080 */
081public class MultipartHttpMessageWriter extends LoggingCodecSupport
082                implements HttpMessageWriter<MultiValueMap<String, ?>> {
083
084        /**
085         * THe default charset used by the writer.
086         */
087        public static final Charset DEFAULT_CHARSET = StandardCharsets.UTF_8;
088
089        /** Suppress logging from individual part writers (full map logged at this level). */
090        private static final Map<String, Object> DEFAULT_HINTS = Hints.from(Hints.SUPPRESS_LOGGING_HINT, true);
091
092
093        private final List<HttpMessageWriter<?>> partWriters;
094
095        @Nullable
096        private final HttpMessageWriter<MultiValueMap<String, String>> formWriter;
097
098        private Charset charset = DEFAULT_CHARSET;
099
100        private final List<MediaType> supportedMediaTypes;
101
102
103        /**
104         * Constructor with a default list of part writers (String and Resource).
105         */
106        public MultipartHttpMessageWriter() {
107                this(Arrays.asList(
108                                new EncoderHttpMessageWriter<>(CharSequenceEncoder.textPlainOnly()),
109                                new ResourceHttpMessageWriter()
110                ));
111        }
112
113        /**
114         * Constructor with explicit list of writers for serializing parts.
115         */
116        public MultipartHttpMessageWriter(List<HttpMessageWriter<?>> partWriters) {
117                this(partWriters, new FormHttpMessageWriter());
118        }
119
120        /**
121         * Constructor with explicit list of writers for serializing parts and a
122         * writer for plain form data to fall back when no media type is specified
123         * and the actual map consists of String values only.
124         * @param partWriters the writers for serializing parts
125         * @param formWriter the fallback writer for form data, {@code null} by default
126         */
127        public MultipartHttpMessageWriter(List<HttpMessageWriter<?>> partWriters,
128                        @Nullable  HttpMessageWriter<MultiValueMap<String, String>> formWriter) {
129
130                this.partWriters = partWriters;
131                this.formWriter = formWriter;
132                this.supportedMediaTypes = initMediaTypes(formWriter);
133        }
134
135        private static List<MediaType> initMediaTypes(@Nullable HttpMessageWriter<?> formWriter) {
136                List<MediaType> result = new ArrayList<>(MultipartHttpMessageReader.MIME_TYPES);
137                if (formWriter != null) {
138                        result.addAll(formWriter.getWritableMediaTypes());
139                }
140                return Collections.unmodifiableList(result);
141        }
142
143
144        /**
145         * Return the configured part writers.
146         * @since 5.0.7
147         */
148        public List<HttpMessageWriter<?>> getPartWriters() {
149                return Collections.unmodifiableList(this.partWriters);
150        }
151
152
153        /**
154         * Return the configured form writer.
155         * @since 5.1.13
156         */
157        @Nullable
158        public HttpMessageWriter<MultiValueMap<String, String>> getFormWriter() {
159                return this.formWriter;
160        }
161
162        /**
163         * Set the character set to use for part headers such as
164         * "Content-Disposition" (and its filename parameter).
165         * <p>By default this is set to "UTF-8".
166         */
167        public void setCharset(Charset charset) {
168                Assert.notNull(charset, "Charset must not be null");
169                this.charset = charset;
170        }
171
172        /**
173         * Return the configured charset for part headers.
174         */
175        public Charset getCharset() {
176                return this.charset;
177        }
178
179
180        @Override
181        public List<MediaType> getWritableMediaTypes() {
182                return this.supportedMediaTypes;
183        }
184
185        @Override
186        public boolean canWrite(ResolvableType elementType, @Nullable MediaType mediaType) {
187                return (MultiValueMap.class.isAssignableFrom(elementType.toClass()) &&
188                                (mediaType == null ||
189                                                this.supportedMediaTypes.stream().anyMatch(element -> element.isCompatibleWith(mediaType))));
190        }
191
192        @Override
193        public Mono<Void> write(Publisher<? extends MultiValueMap<String, ?>> inputStream,
194                        ResolvableType elementType, @Nullable MediaType mediaType, ReactiveHttpOutputMessage outputMessage,
195                        Map<String, Object> hints) {
196
197                return Mono.from(inputStream)
198                                .flatMap(map -> {
199                                        if (this.formWriter == null || isMultipart(map, mediaType)) {
200                                                return writeMultipart(map, outputMessage, mediaType, hints);
201                                        }
202                                        else {
203                                                @SuppressWarnings("unchecked")
204                                                Mono<MultiValueMap<String, String>> input = Mono.just((MultiValueMap<String, String>) map);
205                                                return this.formWriter.write(input, elementType, mediaType, outputMessage, hints);
206                                        }
207                                });
208        }
209
210        private boolean isMultipart(MultiValueMap<String, ?> map, @Nullable MediaType contentType) {
211                if (contentType != null) {
212                        return contentType.getType().equalsIgnoreCase("multipart");
213                }
214                for (List<?> values : map.values()) {
215                        for (Object value : values) {
216                                if (value != null && !(value instanceof String)) {
217                                        return true;
218                                }
219                        }
220                }
221                return false;
222        }
223
224        private Mono<Void> writeMultipart(MultiValueMap<String, ?> map,
225                        ReactiveHttpOutputMessage outputMessage, @Nullable MediaType mediaType, Map<String, Object> hints) {
226
227                byte[] boundary = generateMultipartBoundary();
228
229                Map<String, String> params = new HashMap<>();
230                if (mediaType != null) {
231                        params.putAll(mediaType.getParameters());
232                }
233                params.put("boundary", new String(boundary, StandardCharsets.US_ASCII));
234                params.put("charset", getCharset().name());
235
236                mediaType = (mediaType != null ? mediaType : MediaType.MULTIPART_FORM_DATA);
237                mediaType = new MediaType(mediaType, params);
238
239                outputMessage.getHeaders().setContentType(mediaType);
240
241                LogFormatUtils.traceDebug(logger, traceOn -> Hints.getLogPrefix(hints) + "Encoding " +
242                                (isEnableLoggingRequestDetails() ?
243                                                LogFormatUtils.formatValue(map, !traceOn) :
244                                                "parts " + map.keySet() + " (content masked)"));
245
246                DataBufferFactory bufferFactory = outputMessage.bufferFactory();
247
248                Flux<DataBuffer> body = Flux.fromIterable(map.entrySet())
249                                .concatMap(entry -> encodePartValues(boundary, entry.getKey(), entry.getValue(), bufferFactory))
250                                .concatWith(generateLastLine(boundary, bufferFactory))
251                                .doOnDiscard(PooledDataBuffer.class, DataBufferUtils::release);
252
253                return outputMessage.writeWith(body);
254        }
255
256        /**
257         * Generate a multipart boundary.
258         * <p>By default delegates to {@link MimeTypeUtils#generateMultipartBoundary()}.
259         */
260        protected byte[] generateMultipartBoundary() {
261                return MimeTypeUtils.generateMultipartBoundary();
262        }
263
264        private Flux<DataBuffer> encodePartValues(
265                        byte[] boundary, String name, List<?> values, DataBufferFactory bufferFactory) {
266
267                return Flux.fromIterable(values)
268                                .concatMap(value -> encodePart(boundary, name, value, bufferFactory));
269        }
270
271        @SuppressWarnings("unchecked")
272        private <T> Flux<DataBuffer> encodePart(byte[] boundary, String name, T value, DataBufferFactory bufferFactory) {
273                MultipartHttpOutputMessage outputMessage = new MultipartHttpOutputMessage(bufferFactory, getCharset());
274                HttpHeaders outputHeaders = outputMessage.getHeaders();
275
276                T body;
277                ResolvableType resolvableType = null;
278                if (value instanceof HttpEntity) {
279                        HttpEntity<T> httpEntity = (HttpEntity<T>) value;
280                        outputHeaders.putAll(httpEntity.getHeaders());
281                        body = httpEntity.getBody();
282                        Assert.state(body != null, "MultipartHttpMessageWriter only supports HttpEntity with body");
283                        if (httpEntity instanceof ResolvableTypeProvider) {
284                                resolvableType = ((ResolvableTypeProvider) httpEntity).getResolvableType();
285                        }
286                }
287                else {
288                        body = value;
289                }
290                if (resolvableType == null) {
291                        resolvableType = ResolvableType.forClass(body.getClass());
292                }
293
294                if (!outputHeaders.containsKey(HttpHeaders.CONTENT_DISPOSITION)) {
295                        if (body instanceof Resource) {
296                                outputHeaders.setContentDispositionFormData(name, ((Resource) body).getFilename());
297                        }
298                        else if (resolvableType.resolve() == Resource.class) {
299                                body = (T) Mono.from((Publisher<?>) body).doOnNext(o -> outputHeaders
300                                                .setContentDispositionFormData(name, ((Resource) o).getFilename()));
301                        }
302                        else {
303                                outputHeaders.setContentDispositionFormData(name, null);
304                        }
305                }
306
307                MediaType contentType = outputHeaders.getContentType();
308
309                final ResolvableType finalBodyType = resolvableType;
310                Optional<HttpMessageWriter<?>> writer = this.partWriters.stream()
311                                .filter(partWriter -> partWriter.canWrite(finalBodyType, contentType))
312                                .findFirst();
313
314                if (!writer.isPresent()) {
315                        return Flux.error(new CodecException("No suitable writer found for part: " + name));
316                }
317
318                Publisher<T> bodyPublisher =
319                                body instanceof Publisher ? (Publisher<T>) body : Mono.just(body);
320
321                // The writer will call MultipartHttpOutputMessage#write which doesn't actually write
322                // but only stores the body Flux and returns Mono.empty().
323
324                Mono<Void> partContentReady = ((HttpMessageWriter<T>) writer.get())
325                                .write(bodyPublisher, resolvableType, contentType, outputMessage, DEFAULT_HINTS);
326
327                // After partContentReady, we can access the part content from MultipartHttpOutputMessage
328                // and use it for writing to the actual request body
329
330                Flux<DataBuffer> partContent = partContentReady.thenMany(Flux.defer(outputMessage::getBody));
331
332                return Flux.concat(
333                                generateBoundaryLine(boundary, bufferFactory),
334                                partContent,
335                                generateNewLine(bufferFactory));
336        }
337
338
339        private Mono<DataBuffer> generateBoundaryLine(byte[] boundary, DataBufferFactory bufferFactory) {
340                return Mono.fromCallable(() -> {
341                        DataBuffer buffer = bufferFactory.allocateBuffer(boundary.length + 4);
342                        buffer.write((byte)'-');
343                        buffer.write((byte)'-');
344                        buffer.write(boundary);
345                        buffer.write((byte)'\r');
346                        buffer.write((byte)'\n');
347                        return buffer;
348                });
349        }
350
351        private Mono<DataBuffer> generateNewLine(DataBufferFactory bufferFactory) {
352                return Mono.fromCallable(() -> {
353                        DataBuffer buffer = bufferFactory.allocateBuffer(2);
354                        buffer.write((byte)'\r');
355                        buffer.write((byte)'\n');
356                        return buffer;
357                });
358        }
359
360        private Mono<DataBuffer> generateLastLine(byte[] boundary, DataBufferFactory bufferFactory) {
361                return Mono.fromCallable(() -> {
362                        DataBuffer buffer = bufferFactory.allocateBuffer(boundary.length + 6);
363                        buffer.write((byte)'-');
364                        buffer.write((byte)'-');
365                        buffer.write(boundary);
366                        buffer.write((byte)'-');
367                        buffer.write((byte)'-');
368                        buffer.write((byte)'\r');
369                        buffer.write((byte)'\n');
370                        return buffer;
371                });
372        }
373
374
375        private static class MultipartHttpOutputMessage implements ReactiveHttpOutputMessage {
376
377                private final DataBufferFactory bufferFactory;
378
379                private final Charset charset;
380
381                private final HttpHeaders headers = new HttpHeaders();
382
383                private final AtomicBoolean committed = new AtomicBoolean();
384
385                @Nullable
386                private Flux<DataBuffer> body;
387
388                public MultipartHttpOutputMessage(DataBufferFactory bufferFactory, Charset charset) {
389                        this.bufferFactory = bufferFactory;
390                        this.charset = charset;
391                }
392
393                @Override
394                public HttpHeaders getHeaders() {
395                        return (this.body != null ? HttpHeaders.readOnlyHttpHeaders(this.headers) : this.headers);
396                }
397
398                @Override
399                public DataBufferFactory bufferFactory() {
400                        return this.bufferFactory;
401                }
402
403                @Override
404                public void beforeCommit(Supplier<? extends Mono<Void>> action) {
405                        this.committed.set(true);
406                }
407
408                @Override
409                public boolean isCommitted() {
410                        return this.committed.get();
411                }
412
413                @Override
414                public Mono<Void> writeWith(Publisher<? extends DataBuffer> body) {
415                        if (this.body != null) {
416                                return Mono.error(new IllegalStateException("Multiple calls to writeWith() not supported"));
417                        }
418                        this.body = generateHeaders().concatWith(body);
419
420                        // We don't actually want to write (just save the body Flux)
421                        return Mono.empty();
422                }
423
424                private Mono<DataBuffer> generateHeaders() {
425                        return Mono.fromCallable(() -> {
426                                DataBuffer buffer = this.bufferFactory.allocateBuffer();
427                                for (Map.Entry<String, List<String>> entry : this.headers.entrySet()) {
428                                        byte[] headerName = entry.getKey().getBytes(this.charset);
429                                        for (String headerValueString : entry.getValue()) {
430                                                byte[] headerValue = headerValueString.getBytes(this.charset);
431                                                buffer.write(headerName);
432                                                buffer.write((byte)':');
433                                                buffer.write((byte)' ');
434                                                buffer.write(headerValue);
435                                                buffer.write((byte)'\r');
436                                                buffer.write((byte)'\n');
437                                        }
438                                }
439                                buffer.write((byte)'\r');
440                                buffer.write((byte)'\n');
441                                return buffer;
442                        });
443                }
444
445                @Override
446                public Mono<Void> writeAndFlushWith(Publisher<? extends Publisher<? extends DataBuffer>> body) {
447                        return Mono.error(new UnsupportedOperationException());
448                }
449
450                public Flux<DataBuffer> getBody() {
451                        return (this.body != null ? this.body :
452                                        Flux.error(new IllegalStateException("Body has not been written yet")));
453                }
454
455                @Override
456                public Mono<Void> setComplete() {
457                        return Mono.error(new UnsupportedOperationException());
458                }
459        }
460
461}