001/*
002 * Copyright 2002-2019 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.http.codec.protobuf;
018
019import java.io.IOException;
020import java.lang.reflect.Method;
021import java.nio.ByteBuffer;
022import java.util.ArrayList;
023import java.util.List;
024import java.util.Map;
025import java.util.concurrent.ConcurrentMap;
026import java.util.function.Function;
027
028import com.google.protobuf.CodedInputStream;
029import com.google.protobuf.ExtensionRegistry;
030import com.google.protobuf.Message;
031import org.reactivestreams.Publisher;
032import reactor.core.publisher.Flux;
033import reactor.core.publisher.Mono;
034
035import org.springframework.core.ResolvableType;
036import org.springframework.core.codec.Decoder;
037import org.springframework.core.codec.DecodingException;
038import org.springframework.core.io.buffer.DataBuffer;
039import org.springframework.core.io.buffer.DataBufferLimitException;
040import org.springframework.core.io.buffer.DataBufferUtils;
041import org.springframework.lang.Nullable;
042import org.springframework.util.Assert;
043import org.springframework.util.ConcurrentReferenceHashMap;
044import org.springframework.util.MimeType;
045
046/**
047 * A {@code Decoder} that reads {@link com.google.protobuf.Message}s using
048 * <a href="https://developers.google.com/protocol-buffers/">Google Protocol Buffers</a>.
049 *
050 * <p>Flux deserialized via
051 * {@link #decode(Publisher, ResolvableType, MimeType, Map)} are expected to use
052 * <a href="https://developers.google.com/protocol-buffers/docs/techniques?hl=en#streaming">
053 * delimited Protobuf messages</a> with the size of each message specified before
054 * the message itself. Single values deserialized via
055 * {@link #decodeToMono(Publisher, ResolvableType, MimeType, Map)} are expected
056 * to use regular Protobuf message format (without the size prepended before
057 * the message).
058 *
059 * <p>Notice that default instance of Protobuf message produces empty byte
060 * array, so {@code Mono.just(Msg.getDefaultInstance())} sent over the network
061 * will be deserialized as an empty {@link Mono}.
062 *
063 * <p>To generate {@code Message} Java classes, you need to install the
064 * {@code protoc} binary.
065 *
066 * <p>This decoder requires Protobuf 3 or higher, and supports
067 * {@code "application/x-protobuf"} and {@code "application/octet-stream"} with
068 * the official {@code "com.google.protobuf:protobuf-java"} library.
069 *
070 * @author Sébastien Deleuze
071 * @since 5.1
072 * @see ProtobufEncoder
073 */
074public class ProtobufDecoder extends ProtobufCodecSupport implements Decoder<Message> {
075
076        /** The default max size for aggregating messages. */
077        protected static final int DEFAULT_MESSAGE_MAX_SIZE = 256 * 1024;
078
079        private static final ConcurrentMap<Class<?>, Method> methodCache = new ConcurrentReferenceHashMap<>();
080
081
082        private final ExtensionRegistry extensionRegistry;
083
084        private int maxMessageSize = DEFAULT_MESSAGE_MAX_SIZE;
085
086
087        /**
088         * Construct a new {@code ProtobufDecoder}.
089         */
090        public ProtobufDecoder() {
091                this(ExtensionRegistry.newInstance());
092        }
093
094        /**
095         * Construct a new {@code ProtobufDecoder} with an initializer that allows the
096         * registration of message extensions.
097         * @param extensionRegistry a message extension registry
098         */
099        public ProtobufDecoder(ExtensionRegistry extensionRegistry) {
100                Assert.notNull(extensionRegistry, "ExtensionRegistry must not be null");
101                this.extensionRegistry = extensionRegistry;
102        }
103
104
105        /**
106         * The max size allowed per message.
107         * <p>By default, this is set to 256K.
108         * @param maxMessageSize the max size per message, or -1 for unlimited
109         */
110        public void setMaxMessageSize(int maxMessageSize) {
111                this.maxMessageSize = maxMessageSize;
112        }
113
114        /**
115         * Return the {@link #setMaxMessageSize configured} message size limit.
116         * @since 5.1.11
117         */
118        public int getMaxMessageSize() {
119                return this.maxMessageSize;
120        }
121
122
123        @Override
124        public boolean canDecode(ResolvableType elementType, @Nullable MimeType mimeType) {
125                return Message.class.isAssignableFrom(elementType.toClass()) && supportsMimeType(mimeType);
126        }
127
128        @Override
129        public Flux<Message> decode(Publisher<DataBuffer> inputStream, ResolvableType elementType,
130                        @Nullable MimeType mimeType, @Nullable Map<String, Object> hints) {
131
132                MessageDecoderFunction decoderFunction =
133                                new MessageDecoderFunction(elementType, this.maxMessageSize);
134
135                return Flux.from(inputStream)
136                                .flatMapIterable(decoderFunction)
137                                .doOnTerminate(decoderFunction::discard);
138        }
139
140        @Override
141        public Mono<Message> decodeToMono(Publisher<DataBuffer> inputStream, ResolvableType elementType,
142                        @Nullable MimeType mimeType, @Nullable Map<String, Object> hints) {
143
144                return DataBufferUtils.join(inputStream, this.maxMessageSize)
145                                .map(dataBuffer -> decode(dataBuffer, elementType, mimeType, hints));
146        }
147
148        @Override
149        public Message decode(DataBuffer dataBuffer, ResolvableType targetType,
150                        @Nullable MimeType mimeType, @Nullable Map<String, Object> hints) throws DecodingException {
151
152                try {
153                        Message.Builder builder = getMessageBuilder(targetType.toClass());
154                        ByteBuffer buffer = dataBuffer.asByteBuffer();
155                        builder.mergeFrom(CodedInputStream.newInstance(buffer), this.extensionRegistry);
156                        return builder.build();
157                }
158                catch (IOException ex) {
159                        throw new DecodingException("I/O error while parsing input stream", ex);
160                }
161                catch (Exception ex) {
162                        throw new DecodingException("Could not read Protobuf message: " + ex.getMessage(), ex);
163                }
164                finally {
165                        DataBufferUtils.release(dataBuffer);
166                }
167        }
168
169
170        /**
171         * Create a new {@code Message.Builder} instance for the given class.
172         * <p>This method uses a ConcurrentHashMap for caching method lookups.
173         */
174        private static Message.Builder getMessageBuilder(Class<?> clazz) throws Exception {
175                Method method = methodCache.get(clazz);
176                if (method == null) {
177                        method = clazz.getMethod("newBuilder");
178                        methodCache.put(clazz, method);
179                }
180                return (Message.Builder) method.invoke(clazz);
181        }
182
183        @Override
184        public List<MimeType> getDecodableMimeTypes() {
185                return getMimeTypes();
186        }
187
188
189        private class MessageDecoderFunction implements Function<DataBuffer, Iterable<? extends Message>> {
190
191                private final ResolvableType elementType;
192
193                private final int maxMessageSize;
194
195                @Nullable
196                private DataBuffer output;
197
198                private int messageBytesToRead;
199
200                private int offset;
201
202
203                public MessageDecoderFunction(ResolvableType elementType, int maxMessageSize) {
204                        this.elementType = elementType;
205                        this.maxMessageSize = maxMessageSize;
206                }
207
208
209                @Override
210                public Iterable<? extends Message> apply(DataBuffer input) {
211                        try {
212                                List<Message> messages = new ArrayList<>();
213                                int remainingBytesToRead;
214                                int chunkBytesToRead;
215
216                                do {
217                                        if (this.output == null) {
218                                                if (!readMessageSize(input)) {
219                                                        return messages;
220                                                }
221                                                if (this.maxMessageSize > 0 && this.messageBytesToRead > this.maxMessageSize) {
222                                                        throw new DataBufferLimitException(
223                                                                        "The number of bytes to read for message " +
224                                                                                        "(" + this.messageBytesToRead + ") exceeds " +
225                                                                                        "the configured limit (" + this.maxMessageSize + ")");
226                                                }
227                                                this.output = input.factory().allocateBuffer(this.messageBytesToRead);
228                                        }
229
230                                        chunkBytesToRead = Math.min(this.messageBytesToRead, input.readableByteCount());
231                                        remainingBytesToRead = input.readableByteCount() - chunkBytesToRead;
232
233                                        byte[] bytesToWrite = new byte[chunkBytesToRead];
234                                        input.read(bytesToWrite, 0, chunkBytesToRead);
235                                        this.output.write(bytesToWrite);
236                                        this.messageBytesToRead -= chunkBytesToRead;
237
238                                        if (this.messageBytesToRead == 0) {
239                                                CodedInputStream stream = CodedInputStream.newInstance(this.output.asByteBuffer());
240                                                DataBufferUtils.release(this.output);
241                                                this.output = null;
242                                                Message message = getMessageBuilder(this.elementType.toClass())
243                                                                .mergeFrom(stream, extensionRegistry)
244                                                                .build();
245                                                messages.add(message);
246                                        }
247                                } while (remainingBytesToRead > 0);
248                                return messages;
249                        }
250                        catch (DecodingException ex) {
251                                throw ex;
252                        }
253                        catch (IOException ex) {
254                                throw new DecodingException("I/O error while parsing input stream", ex);
255                        }
256                        catch (Exception ex) {
257                                throw new DecodingException("Could not read Protobuf message: " + ex.getMessage(), ex);
258                        }
259                        finally {
260                                DataBufferUtils.release(input);
261                        }
262                }
263
264                /**
265                 * Parse message size as a varint from the input stream, updating {@code messageBytesToRead} and
266                 * {@code offset} fields if needed to allow processing of upcoming chunks.
267                 * Inspired from {@link CodedInputStream#readRawVarint32(int, java.io.InputStream)}
268                 *
269                 * @return {code true} when the message size is parsed successfully, {code false} when the message size is
270                 * truncated
271                 * @see <a href ="https://developers.google.com/protocol-buffers/docs/encoding#varints">Base 128 Varints</a>
272                 */
273                private boolean readMessageSize(DataBuffer input) {
274                        if (this.offset == 0) {
275                                if (input.readableByteCount() == 0) {
276                                        return false;
277                                }
278                                int firstByte = input.read();
279                                if ((firstByte & 0x80) == 0) {
280                                        this.messageBytesToRead = firstByte;
281                                        return true;
282                                }
283                                this.messageBytesToRead = firstByte & 0x7f;
284                                this.offset = 7;
285                        }
286
287                        if (this.offset < 32) {
288                                for (; this.offset < 32; this.offset += 7) {
289                                        if (input.readableByteCount() == 0) {
290                                                return false;
291                                        }
292                                        final int b = input.read();
293                                        this.messageBytesToRead |= (b & 0x7f) << offset;
294                                        if ((b & 0x80) == 0) {
295                                                this.offset = 0;
296                                                return true;
297                                        }
298                                }
299                        }
300                        // Keep reading up to 64 bits.
301                        for (; this.offset < 64; this.offset += 7) {
302                                if (input.readableByteCount() == 0) {
303                                        return false;
304                                }
305                                final int b = input.read();
306                                if ((b & 0x80) == 0) {
307                                        this.offset = 0;
308                                        return true;
309                                }
310                        }
311                        this.offset = 0;
312                        throw new DecodingException("Cannot parse message size: malformed varint");
313                }
314
315                public void discard() {
316                        if (this.output != null) {
317                                DataBufferUtils.release(this.output);
318                        }
319                }
320        }
321
322}