001/*
002 * Copyright 2002-2020 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.http.codec.xml;
018
019import java.util.ArrayList;
020import java.util.Iterator;
021import java.util.List;
022import java.util.Map;
023import java.util.function.BiConsumer;
024import java.util.function.Function;
025
026import javax.xml.XMLConstants;
027import javax.xml.bind.JAXBElement;
028import javax.xml.bind.JAXBException;
029import javax.xml.bind.UnmarshalException;
030import javax.xml.bind.Unmarshaller;
031import javax.xml.bind.annotation.XmlRootElement;
032import javax.xml.bind.annotation.XmlSchema;
033import javax.xml.bind.annotation.XmlType;
034import javax.xml.namespace.QName;
035import javax.xml.stream.XMLEventReader;
036import javax.xml.stream.XMLInputFactory;
037import javax.xml.stream.XMLStreamException;
038import javax.xml.stream.events.XMLEvent;
039
040import org.reactivestreams.Publisher;
041import reactor.core.Exceptions;
042import reactor.core.publisher.Flux;
043import reactor.core.publisher.Mono;
044import reactor.core.publisher.SynchronousSink;
045
046import org.springframework.core.ResolvableType;
047import org.springframework.core.codec.AbstractDecoder;
048import org.springframework.core.codec.CodecException;
049import org.springframework.core.codec.DecodingException;
050import org.springframework.core.codec.Hints;
051import org.springframework.core.io.buffer.DataBuffer;
052import org.springframework.core.io.buffer.DataBufferLimitException;
053import org.springframework.core.io.buffer.DataBufferUtils;
054import org.springframework.core.log.LogFormatUtils;
055import org.springframework.http.MediaType;
056import org.springframework.lang.Nullable;
057import org.springframework.util.Assert;
058import org.springframework.util.ClassUtils;
059import org.springframework.util.MimeType;
060import org.springframework.util.MimeTypeUtils;
061import org.springframework.util.xml.StaxUtils;
062
063/**
064 * Decode from a bytes stream containing XML elements to a stream of
065 * {@code Object}s (POJOs).
066 *
067 * @author Sebastien Deleuze
068 * @author Arjen Poutsma
069 * @since 5.0
070 * @see Jaxb2XmlEncoder
071 */
072public class Jaxb2XmlDecoder extends AbstractDecoder<Object> {
073
074        /**
075         * The default value for JAXB annotations.
076         * @see XmlRootElement#name()
077         * @see XmlRootElement#namespace()
078         * @see XmlType#name()
079         * @see XmlType#namespace()
080         */
081        private static final String JAXB_DEFAULT_ANNOTATION_VALUE = "##default";
082
083        private static final XMLInputFactory inputFactory = StaxUtils.createDefensiveInputFactory();
084
085
086        private final XmlEventDecoder xmlEventDecoder = new XmlEventDecoder();
087
088        private final JaxbContextContainer jaxbContexts = new JaxbContextContainer();
089
090        private Function<Unmarshaller, Unmarshaller> unmarshallerProcessor = Function.identity();
091
092        private int maxInMemorySize = 256 * 1024;
093
094
095        public Jaxb2XmlDecoder() {
096                super(MimeTypeUtils.APPLICATION_XML, MimeTypeUtils.TEXT_XML, new MediaType("application", "*+xml"));
097        }
098
099        /**
100         * Create a {@code Jaxb2XmlDecoder} with the specified MIME types.
101         * @param supportedMimeTypes supported MIME types
102         * @since 5.1.9
103         */
104        public Jaxb2XmlDecoder(MimeType... supportedMimeTypes) {
105                super(supportedMimeTypes);
106        }
107
108
109        /**
110         * Configure a processor function to customize Unmarshaller instances.
111         * @param processor the function to use
112         * @since 5.1.3
113         */
114        public void setUnmarshallerProcessor(Function<Unmarshaller, Unmarshaller> processor) {
115                this.unmarshallerProcessor = this.unmarshallerProcessor.andThen(processor);
116        }
117
118        /**
119         * Return the configured processor for customizing Unmarshaller instances.
120         * @since 5.1.3
121         */
122        public Function<Unmarshaller, Unmarshaller> getUnmarshallerProcessor() {
123                return this.unmarshallerProcessor;
124        }
125
126        /**
127         * Set the max number of bytes that can be buffered by this decoder.
128         * This is either the size of the entire input when decoding as a whole, or when
129         * using async parsing with Aalto XML, it is the size of one top-level XML tree.
130         * When the limit is exceeded, {@link DataBufferLimitException} is raised.
131         * <p>By default this is set to 256K.
132         * @param byteCount the max number of bytes to buffer, or -1 for unlimited
133         * @since 5.1.11
134         */
135        public void setMaxInMemorySize(int byteCount) {
136                this.maxInMemorySize = byteCount;
137                this.xmlEventDecoder.setMaxInMemorySize(byteCount);
138        }
139
140        /**
141         * Return the {@link #setMaxInMemorySize configured} byte count limit.
142         * @since 5.1.11
143         */
144        public int getMaxInMemorySize() {
145                return this.maxInMemorySize;
146        }
147
148
149        @Override
150        public boolean canDecode(ResolvableType elementType, @Nullable MimeType mimeType) {
151                Class<?> outputClass = elementType.toClass();
152                return (outputClass.isAnnotationPresent(XmlRootElement.class) ||
153                                outputClass.isAnnotationPresent(XmlType.class)) && super.canDecode(elementType, mimeType);
154        }
155
156        @Override
157        public Flux<Object> decode(Publisher<DataBuffer> inputStream, ResolvableType elementType,
158                        @Nullable MimeType mimeType, @Nullable Map<String, Object> hints) {
159
160                Flux<XMLEvent> xmlEventFlux = this.xmlEventDecoder.decode(
161                                inputStream, ResolvableType.forClass(XMLEvent.class), mimeType, hints);
162
163                Class<?> outputClass = elementType.toClass();
164                QName typeName = toQName(outputClass);
165                Flux<List<XMLEvent>> splitEvents = split(xmlEventFlux, typeName);
166
167                return splitEvents.map(events -> {
168                        Object value = unmarshal(events, outputClass);
169                        LogFormatUtils.traceDebug(logger, traceOn -> {
170                                String formatted = LogFormatUtils.formatValue(value, !traceOn);
171                                return Hints.getLogPrefix(hints) + "Decoded [" + formatted + "]";
172                        });
173                        return value;
174                });
175        }
176
177        @Override
178        @SuppressWarnings({"rawtypes", "unchecked", "cast"})  // XMLEventReader is Iterator<Object> on JDK 9
179        public Mono<Object> decodeToMono(Publisher<DataBuffer> input, ResolvableType elementType,
180                        @Nullable MimeType mimeType, @Nullable Map<String, Object> hints) {
181
182                return DataBufferUtils.join(input, this.maxInMemorySize)
183                                .map(dataBuffer -> decode(dataBuffer, elementType, mimeType, hints));
184        }
185
186        @Override
187        @SuppressWarnings({"rawtypes", "unchecked", "cast"})  // XMLEventReader is Iterator<Object> on JDK 9
188        public Object decode(DataBuffer dataBuffer, ResolvableType targetType,
189                        @Nullable MimeType mimeType, @Nullable Map<String, Object> hints) throws DecodingException {
190
191                try {
192                        Iterator eventReader = inputFactory.createXMLEventReader(dataBuffer.asInputStream());
193                        List<XMLEvent> events = new ArrayList<>();
194                        eventReader.forEachRemaining(event -> events.add((XMLEvent) event));
195                        return unmarshal(events, targetType.toClass());
196                }
197                catch (XMLStreamException ex) {
198                        throw Exceptions.propagate(ex);
199                }
200                catch (Throwable ex) {
201                        ex = (ex.getCause() instanceof XMLStreamException ? ex.getCause() : ex);
202                        throw Exceptions.propagate(ex);
203                }
204                finally {
205                        DataBufferUtils.release(dataBuffer);
206                }
207        }
208
209        private Object unmarshal(List<XMLEvent> events, Class<?> outputClass) {
210                try {
211                        Unmarshaller unmarshaller = initUnmarshaller(outputClass);
212                        XMLEventReader eventReader = StaxUtils.createXMLEventReader(events);
213                        if (outputClass.isAnnotationPresent(XmlRootElement.class)) {
214                                return unmarshaller.unmarshal(eventReader);
215                        }
216                        else {
217                                JAXBElement<?> jaxbElement = unmarshaller.unmarshal(eventReader, outputClass);
218                                return jaxbElement.getValue();
219                        }
220                }
221                catch (UnmarshalException ex) {
222                        throw new DecodingException("Could not unmarshal XML to " + outputClass, ex);
223                }
224                catch (JAXBException ex) {
225                        throw new CodecException("Invalid JAXB configuration", ex);
226                }
227        }
228
229        private Unmarshaller initUnmarshaller(Class<?> outputClass) throws CodecException, JAXBException {
230                Unmarshaller unmarshaller = this.jaxbContexts.createUnmarshaller(outputClass);
231                return this.unmarshallerProcessor.apply(unmarshaller);
232        }
233
234        /**
235         * Returns the qualified name for the given class, according to the mapping rules
236         * in the JAXB specification.
237         */
238        QName toQName(Class<?> outputClass) {
239                String localPart;
240                String namespaceUri;
241
242                if (outputClass.isAnnotationPresent(XmlRootElement.class)) {
243                        XmlRootElement annotation = outputClass.getAnnotation(XmlRootElement.class);
244                        localPart = annotation.name();
245                        namespaceUri = annotation.namespace();
246                }
247                else if (outputClass.isAnnotationPresent(XmlType.class)) {
248                        XmlType annotation = outputClass.getAnnotation(XmlType.class);
249                        localPart = annotation.name();
250                        namespaceUri = annotation.namespace();
251                }
252                else {
253                        throw new IllegalArgumentException("Output class [" + outputClass.getName() +
254                                        "] is neither annotated with @XmlRootElement nor @XmlType");
255                }
256
257                if (JAXB_DEFAULT_ANNOTATION_VALUE.equals(localPart)) {
258                        localPart = ClassUtils.getShortNameAsProperty(outputClass);
259                }
260                if (JAXB_DEFAULT_ANNOTATION_VALUE.equals(namespaceUri)) {
261                        Package outputClassPackage = outputClass.getPackage();
262                        if (outputClassPackage != null && outputClassPackage.isAnnotationPresent(XmlSchema.class)) {
263                                XmlSchema annotation = outputClassPackage.getAnnotation(XmlSchema.class);
264                                namespaceUri = annotation.namespace();
265                        }
266                        else {
267                                namespaceUri = XMLConstants.NULL_NS_URI;
268                        }
269                }
270                return new QName(namespaceUri, localPart);
271        }
272
273        /**
274         * Split a flux of {@link XMLEvent XMLEvents} into a flux of XMLEvent lists, one list
275         * for each branch of the tree that starts with the given qualified name.
276         * That is, given the XMLEvents shown {@linkplain XmlEventDecoder here},
277         * and the {@code desiredName} "{@code child}", this method returns a flux
278         * of two lists, each of which containing the events of a particular branch
279         * of the tree that starts with "{@code child}".
280         * <ol>
281         * <li>The first list, dealing with the first branch of the tree:
282         * <ol>
283         * <li>{@link javax.xml.stream.events.StartElement} {@code child}</li>
284         * <li>{@link javax.xml.stream.events.Characters} {@code foo}</li>
285         * <li>{@link javax.xml.stream.events.EndElement} {@code child}</li>
286         * </ol>
287         * <li>The second list, dealing with the second branch of the tree:
288         * <ol>
289         * <li>{@link javax.xml.stream.events.StartElement} {@code child}</li>
290         * <li>{@link javax.xml.stream.events.Characters} {@code bar}</li>
291         * <li>{@link javax.xml.stream.events.EndElement} {@code child}</li>
292         * </ol>
293         * </li>
294         * </ol>
295         */
296        Flux<List<XMLEvent>> split(Flux<XMLEvent> xmlEventFlux, QName desiredName) {
297                return xmlEventFlux.handle(new SplitHandler(desiredName));
298        }
299
300
301        private static class SplitHandler implements BiConsumer<XMLEvent, SynchronousSink<List<XMLEvent>>> {
302
303                private final QName desiredName;
304
305                @Nullable
306                private List<XMLEvent> events;
307
308                private int elementDepth = 0;
309
310                private int barrier = Integer.MAX_VALUE;
311
312                public SplitHandler(QName desiredName) {
313                        this.desiredName = desiredName;
314                }
315
316                @Override
317                public void accept(XMLEvent event, SynchronousSink<List<XMLEvent>> sink) {
318                        if (event.isStartElement()) {
319                                if (this.barrier == Integer.MAX_VALUE) {
320                                        QName startElementName = event.asStartElement().getName();
321                                        if (this.desiredName.equals(startElementName)) {
322                                                this.events = new ArrayList<>();
323                                                this.barrier = this.elementDepth;
324                                        }
325                                }
326                                this.elementDepth++;
327                        }
328                        if (this.elementDepth > this.barrier) {
329                                Assert.state(this.events != null, "No XMLEvent List");
330                                this.events.add(event);
331                        }
332                        if (event.isEndElement()) {
333                                this.elementDepth--;
334                                if (this.elementDepth == this.barrier) {
335                                        this.barrier = Integer.MAX_VALUE;
336                                        Assert.state(this.events != null, "No XMLEvent List");
337                                        sink.next(this.events);
338                                }
339                        }
340                }
341        }
342
343}