001/* 002 * Copyright 2002-2019 the original author or authors. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * https://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.springframework.http.codec.protobuf; 018 019import java.io.IOException; 020import java.lang.reflect.Method; 021import java.nio.ByteBuffer; 022import java.util.ArrayList; 023import java.util.List; 024import java.util.Map; 025import java.util.concurrent.ConcurrentMap; 026import java.util.function.Function; 027 028import com.google.protobuf.CodedInputStream; 029import com.google.protobuf.ExtensionRegistry; 030import com.google.protobuf.Message; 031import org.reactivestreams.Publisher; 032import reactor.core.publisher.Flux; 033import reactor.core.publisher.Mono; 034 035import org.springframework.core.ResolvableType; 036import org.springframework.core.codec.Decoder; 037import org.springframework.core.codec.DecodingException; 038import org.springframework.core.io.buffer.DataBuffer; 039import org.springframework.core.io.buffer.DataBufferLimitException; 040import org.springframework.core.io.buffer.DataBufferUtils; 041import org.springframework.lang.Nullable; 042import org.springframework.util.Assert; 043import org.springframework.util.ConcurrentReferenceHashMap; 044import org.springframework.util.MimeType; 045 046/** 047 * A {@code Decoder} that reads {@link com.google.protobuf.Message}s using 048 * <a href="https://developers.google.com/protocol-buffers/">Google Protocol Buffers</a>. 049 * 050 * <p>Flux deserialized via 051 * {@link #decode(Publisher, ResolvableType, MimeType, Map)} are expected to use 052 * <a href="https://developers.google.com/protocol-buffers/docs/techniques?hl=en#streaming"> 053 * delimited Protobuf messages</a> with the size of each message specified before 054 * the message itself. Single values deserialized via 055 * {@link #decodeToMono(Publisher, ResolvableType, MimeType, Map)} are expected 056 * to use regular Protobuf message format (without the size prepended before 057 * the message). 058 * 059 * <p>Notice that default instance of Protobuf message produces empty byte 060 * array, so {@code Mono.just(Msg.getDefaultInstance())} sent over the network 061 * will be deserialized as an empty {@link Mono}. 062 * 063 * <p>To generate {@code Message} Java classes, you need to install the 064 * {@code protoc} binary. 065 * 066 * <p>This decoder requires Protobuf 3 or higher, and supports 067 * {@code "application/x-protobuf"} and {@code "application/octet-stream"} with 068 * the official {@code "com.google.protobuf:protobuf-java"} library. 069 * 070 * @author S茅bastien Deleuze 071 * @since 5.1 072 * @see ProtobufEncoder 073 */ 074public class ProtobufDecoder extends ProtobufCodecSupport implements Decoder<Message> { 075 076 /** The default max size for aggregating messages. */ 077 protected static final int DEFAULT_MESSAGE_MAX_SIZE = 256 * 1024; 078 079 private static final ConcurrentMap<Class<?>, Method> methodCache = new ConcurrentReferenceHashMap<>(); 080 081 082 private final ExtensionRegistry extensionRegistry; 083 084 private int maxMessageSize = DEFAULT_MESSAGE_MAX_SIZE; 085 086 087 /** 088 * Construct a new {@code ProtobufDecoder}. 089 */ 090 public ProtobufDecoder() { 091 this(ExtensionRegistry.newInstance()); 092 } 093 094 /** 095 * Construct a new {@code ProtobufDecoder} with an initializer that allows the 096 * registration of message extensions. 097 * @param extensionRegistry a message extension registry 098 */ 099 public ProtobufDecoder(ExtensionRegistry extensionRegistry) { 100 Assert.notNull(extensionRegistry, "ExtensionRegistry must not be null"); 101 this.extensionRegistry = extensionRegistry; 102 } 103 104 105 /** 106 * The max size allowed per message. 107 * <p>By default, this is set to 256K. 108 * @param maxMessageSize the max size per message, or -1 for unlimited 109 */ 110 public void setMaxMessageSize(int maxMessageSize) { 111 this.maxMessageSize = maxMessageSize; 112 } 113 114 /** 115 * Return the {@link #setMaxMessageSize configured} message size limit. 116 * @since 5.1.11 117 */ 118 public int getMaxMessageSize() { 119 return this.maxMessageSize; 120 } 121 122 123 @Override 124 public boolean canDecode(ResolvableType elementType, @Nullable MimeType mimeType) { 125 return Message.class.isAssignableFrom(elementType.toClass()) && supportsMimeType(mimeType); 126 } 127 128 @Override 129 public Flux<Message> decode(Publisher<DataBuffer> inputStream, ResolvableType elementType, 130 @Nullable MimeType mimeType, @Nullable Map<String, Object> hints) { 131 132 MessageDecoderFunction decoderFunction = 133 new MessageDecoderFunction(elementType, this.maxMessageSize); 134 135 return Flux.from(inputStream) 136 .flatMapIterable(decoderFunction) 137 .doOnTerminate(decoderFunction::discard); 138 } 139 140 @Override 141 public Mono<Message> decodeToMono(Publisher<DataBuffer> inputStream, ResolvableType elementType, 142 @Nullable MimeType mimeType, @Nullable Map<String, Object> hints) { 143 144 return DataBufferUtils.join(inputStream, this.maxMessageSize) 145 .map(dataBuffer -> decode(dataBuffer, elementType, mimeType, hints)); 146 } 147 148 @Override 149 public Message decode(DataBuffer dataBuffer, ResolvableType targetType, 150 @Nullable MimeType mimeType, @Nullable Map<String, Object> hints) throws DecodingException { 151 152 try { 153 Message.Builder builder = getMessageBuilder(targetType.toClass()); 154 ByteBuffer buffer = dataBuffer.asByteBuffer(); 155 builder.mergeFrom(CodedInputStream.newInstance(buffer), this.extensionRegistry); 156 return builder.build(); 157 } 158 catch (IOException ex) { 159 throw new DecodingException("I/O error while parsing input stream", ex); 160 } 161 catch (Exception ex) { 162 throw new DecodingException("Could not read Protobuf message: " + ex.getMessage(), ex); 163 } 164 finally { 165 DataBufferUtils.release(dataBuffer); 166 } 167 } 168 169 170 /** 171 * Create a new {@code Message.Builder} instance for the given class. 172 * <p>This method uses a ConcurrentHashMap for caching method lookups. 173 */ 174 private static Message.Builder getMessageBuilder(Class<?> clazz) throws Exception { 175 Method method = methodCache.get(clazz); 176 if (method == null) { 177 method = clazz.getMethod("newBuilder"); 178 methodCache.put(clazz, method); 179 } 180 return (Message.Builder) method.invoke(clazz); 181 } 182 183 @Override 184 public List<MimeType> getDecodableMimeTypes() { 185 return getMimeTypes(); 186 } 187 188 189 private class MessageDecoderFunction implements Function<DataBuffer, Iterable<? extends Message>> { 190 191 private final ResolvableType elementType; 192 193 private final int maxMessageSize; 194 195 @Nullable 196 private DataBuffer output; 197 198 private int messageBytesToRead; 199 200 private int offset; 201 202 203 public MessageDecoderFunction(ResolvableType elementType, int maxMessageSize) { 204 this.elementType = elementType; 205 this.maxMessageSize = maxMessageSize; 206 } 207 208 209 @Override 210 public Iterable<? extends Message> apply(DataBuffer input) { 211 try { 212 List<Message> messages = new ArrayList<>(); 213 int remainingBytesToRead; 214 int chunkBytesToRead; 215 216 do { 217 if (this.output == null) { 218 if (!readMessageSize(input)) { 219 return messages; 220 } 221 if (this.maxMessageSize > 0 && this.messageBytesToRead > this.maxMessageSize) { 222 throw new DataBufferLimitException( 223 "The number of bytes to read for message " + 224 "(" + this.messageBytesToRead + ") exceeds " + 225 "the configured limit (" + this.maxMessageSize + ")"); 226 } 227 this.output = input.factory().allocateBuffer(this.messageBytesToRead); 228 } 229 230 chunkBytesToRead = Math.min(this.messageBytesToRead, input.readableByteCount()); 231 remainingBytesToRead = input.readableByteCount() - chunkBytesToRead; 232 233 byte[] bytesToWrite = new byte[chunkBytesToRead]; 234 input.read(bytesToWrite, 0, chunkBytesToRead); 235 this.output.write(bytesToWrite); 236 this.messageBytesToRead -= chunkBytesToRead; 237 238 if (this.messageBytesToRead == 0) { 239 CodedInputStream stream = CodedInputStream.newInstance(this.output.asByteBuffer()); 240 DataBufferUtils.release(this.output); 241 this.output = null; 242 Message message = getMessageBuilder(this.elementType.toClass()) 243 .mergeFrom(stream, extensionRegistry) 244 .build(); 245 messages.add(message); 246 } 247 } while (remainingBytesToRead > 0); 248 return messages; 249 } 250 catch (DecodingException ex) { 251 throw ex; 252 } 253 catch (IOException ex) { 254 throw new DecodingException("I/O error while parsing input stream", ex); 255 } 256 catch (Exception ex) { 257 throw new DecodingException("Could not read Protobuf message: " + ex.getMessage(), ex); 258 } 259 finally { 260 DataBufferUtils.release(input); 261 } 262 } 263 264 /** 265 * Parse message size as a varint from the input stream, updating {@code messageBytesToRead} and 266 * {@code offset} fields if needed to allow processing of upcoming chunks. 267 * Inspired from {@link CodedInputStream#readRawVarint32(int, java.io.InputStream)} 268 * 269 * @return {code true} when the message size is parsed successfully, {code false} when the message size is 270 * truncated 271 * @see <a href ="https://developers.google.com/protocol-buffers/docs/encoding#varints">Base 128 Varints</a> 272 */ 273 private boolean readMessageSize(DataBuffer input) { 274 if (this.offset == 0) { 275 if (input.readableByteCount() == 0) { 276 return false; 277 } 278 int firstByte = input.read(); 279 if ((firstByte & 0x80) == 0) { 280 this.messageBytesToRead = firstByte; 281 return true; 282 } 283 this.messageBytesToRead = firstByte & 0x7f; 284 this.offset = 7; 285 } 286 287 if (this.offset < 32) { 288 for (; this.offset < 32; this.offset += 7) { 289 if (input.readableByteCount() == 0) { 290 return false; 291 } 292 final int b = input.read(); 293 this.messageBytesToRead |= (b & 0x7f) << offset; 294 if ((b & 0x80) == 0) { 295 this.offset = 0; 296 return true; 297 } 298 } 299 } 300 // Keep reading up to 64 bits. 301 for (; this.offset < 64; this.offset += 7) { 302 if (input.readableByteCount() == 0) { 303 return false; 304 } 305 final int b = input.read(); 306 if ((b & 0x80) == 0) { 307 this.offset = 0; 308 return true; 309 } 310 } 311 this.offset = 0; 312 throw new DecodingException("Cannot parse message size: malformed varint"); 313 } 314 315 public void discard() { 316 if (this.output != null) { 317 DataBufferUtils.release(this.output); 318 } 319 } 320 } 321 322}