001/*
002 * Copyright 2002-2020 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.messaging.simp.stomp;
018
019import java.io.ByteArrayOutputStream;
020import java.nio.Buffer;
021import java.nio.ByteBuffer;
022import java.nio.charset.StandardCharsets;
023import java.util.ArrayList;
024import java.util.List;
025
026import org.apache.commons.logging.Log;
027
028import org.springframework.lang.Nullable;
029import org.springframework.messaging.Message;
030import org.springframework.messaging.simp.SimpLogging;
031import org.springframework.messaging.support.MessageBuilder;
032import org.springframework.messaging.support.MessageHeaderInitializer;
033import org.springframework.messaging.support.NativeMessageHeaderAccessor;
034import org.springframework.util.InvalidMimeTypeException;
035import org.springframework.util.MultiValueMap;
036import org.springframework.util.StreamUtils;
037
038/**
039 * Decodes one or more STOMP frames contained in a {@link ByteBuffer}.
040 *
041 * <p>An attempt is made to read all complete STOMP frames from the buffer, which
042 * could be zero, one, or more. If there is any left-over content, i.e. an incomplete
043 * STOMP frame, at the end the buffer is reset to point to the beginning of the
044 * partial content. The caller is then responsible for dealing with that
045 * incomplete content by buffering until there is more input available.
046 *
047 * @author Andy Wilkinson
048 * @author Rossen Stoyanchev
049 * @since 4.0
050 */
051public class StompDecoder {
052
053        static final byte[] HEARTBEAT_PAYLOAD = new byte[] {'\n'};
054
055        private static final Log logger = SimpLogging.forLogName(StompDecoder.class);
056
057        @Nullable
058        private MessageHeaderInitializer headerInitializer;
059
060
061        /**
062         * Configure a {@link MessageHeaderInitializer} to apply to the headers of
063         * {@link Message Messages} from decoded STOMP frames.
064         */
065        public void setHeaderInitializer(@Nullable MessageHeaderInitializer headerInitializer) {
066                this.headerInitializer = headerInitializer;
067        }
068
069        /**
070         * Return the configured {@code MessageHeaderInitializer}, if any.
071         */
072        @Nullable
073        public MessageHeaderInitializer getHeaderInitializer() {
074                return this.headerInitializer;
075        }
076
077
078        /**
079         * Decodes one or more STOMP frames from the given {@code ByteBuffer} into a
080         * list of {@link Message Messages}. If the input buffer contains partial STOMP frame
081         * content, or additional content with a partial STOMP frame, the buffer is
082         * reset and {@code null} is returned.
083         * @param byteBuffer the buffer to decode the STOMP frame from
084         * @return the decoded messages, or an empty list if none
085         * @throws StompConversionException raised in case of decoding issues
086         */
087        public List<Message<byte[]>> decode(ByteBuffer byteBuffer) {
088                return decode(byteBuffer, null);
089        }
090
091        /**
092         * Decodes one or more STOMP frames from the given {@code buffer} and returns
093         * a list of {@link Message Messages}.
094         * <p>If the given ByteBuffer contains only partial STOMP frame content and no
095         * complete STOMP frames, an empty list is returned, and the buffer is reset to
096         * to where it was.
097         * <p>If the buffer contains one ore more STOMP frames, those are returned and
098         * the buffer reset to point to the beginning of the unused partial content.
099         * <p>The output partialMessageHeaders map is used to store successfully parsed
100         * headers in case of partial content. The caller can then check if a
101         * "content-length" header was read, which helps to determine how much more
102         * content is needed before the next attempt to decode.
103         * @param byteBuffer the buffer to decode the STOMP frame from
104         * @param partialMessageHeaders an empty output map that will store the last
105         * successfully parsed partialMessageHeaders in case of partial message content
106         * in cases where the partial buffer ended with a partial STOMP frame
107         * @return the decoded messages, or an empty list if none
108         * @throws StompConversionException raised in case of decoding issues
109         */
110        public List<Message<byte[]>> decode(ByteBuffer byteBuffer,
111                        @Nullable MultiValueMap<String, String> partialMessageHeaders) {
112
113                List<Message<byte[]>> messages = new ArrayList<>();
114                while (byteBuffer.hasRemaining()) {
115                        Message<byte[]> message = decodeMessage(byteBuffer, partialMessageHeaders);
116                        if (message != null) {
117                                messages.add(message);
118                                skipEol(byteBuffer);
119                                if (!byteBuffer.hasRemaining()) {
120                                        break;
121                                }
122                        }
123                        else {
124                                break;
125                        }
126                }
127                return messages;
128        }
129
130        /**
131         * Decode a single STOMP frame from the given {@code buffer} into a {@link Message}.
132         */
133        @Nullable
134        private Message<byte[]> decodeMessage(ByteBuffer byteBuffer, @Nullable MultiValueMap<String, String> headers) {
135                Message<byte[]> decodedMessage = null;
136                skipEol(byteBuffer);
137
138                // Explicit mark/reset access via Buffer base type for compatibility
139                // with covariant return type on JDK 9's ByteBuffer...
140                Buffer buffer = byteBuffer;
141                buffer.mark();
142
143                String command = readCommand(byteBuffer);
144                if (command.length() > 0) {
145                        StompHeaderAccessor headerAccessor = null;
146                        byte[] payload = null;
147                        if (byteBuffer.remaining() > 0) {
148                                StompCommand stompCommand = StompCommand.valueOf(command);
149                                headerAccessor = StompHeaderAccessor.create(stompCommand);
150                                initHeaders(headerAccessor);
151                                readHeaders(byteBuffer, headerAccessor);
152                                payload = readPayload(byteBuffer, headerAccessor);
153                        }
154                        if (payload != null) {
155                                if (payload.length > 0) {
156                                        StompCommand stompCommand = headerAccessor.getCommand();
157                                        if (stompCommand != null && !stompCommand.isBodyAllowed()) {
158                                                throw new StompConversionException(stompCommand +
159                                                                " shouldn't have a payload: length=" + payload.length + ", headers=" + headers);
160                                        }
161                                }
162                                headerAccessor.updateSimpMessageHeadersFromStompHeaders();
163                                headerAccessor.setLeaveMutable(true);
164                                decodedMessage = MessageBuilder.createMessage(payload, headerAccessor.getMessageHeaders());
165                                if (logger.isTraceEnabled()) {
166                                        logger.trace("Decoded " + headerAccessor.getDetailedLogMessage(payload));
167                                }
168                        }
169                        else {
170                                logger.trace("Incomplete frame, resetting input buffer...");
171                                if (headers != null && headerAccessor != null) {
172                                        String name = NativeMessageHeaderAccessor.NATIVE_HEADERS;
173                                        @SuppressWarnings("unchecked")
174                                        MultiValueMap<String, String> map = (MultiValueMap<String, String>) headerAccessor.getHeader(name);
175                                        if (map != null) {
176                                                headers.putAll(map);
177                                        }
178                                }
179                                buffer.reset();
180                        }
181                }
182                else {
183                        StompHeaderAccessor headerAccessor = StompHeaderAccessor.createForHeartbeat();
184                        initHeaders(headerAccessor);
185                        headerAccessor.setLeaveMutable(true);
186                        decodedMessage = MessageBuilder.createMessage(HEARTBEAT_PAYLOAD, headerAccessor.getMessageHeaders());
187                        if (logger.isTraceEnabled()) {
188                                logger.trace("Decoded " + headerAccessor.getDetailedLogMessage(null));
189                        }
190                }
191
192                return decodedMessage;
193        }
194
195        private void initHeaders(StompHeaderAccessor headerAccessor) {
196                MessageHeaderInitializer initializer = getHeaderInitializer();
197                if (initializer != null) {
198                        initializer.initHeaders(headerAccessor);
199                }
200        }
201
202        /**
203         * Skip one ore more EOL characters at the start of the given ByteBuffer.
204         * STOMP, section 2.1 says: "The NULL octet can be optionally followed by
205         * multiple EOLs."
206         */
207        protected void skipEol(ByteBuffer byteBuffer) {
208                while (true) {
209                        if (!tryConsumeEndOfLine(byteBuffer)) {
210                                break;
211                        }
212                }
213        }
214
215        private String readCommand(ByteBuffer byteBuffer) {
216                ByteArrayOutputStream command = new ByteArrayOutputStream(256);
217                while (byteBuffer.remaining() > 0 && !tryConsumeEndOfLine(byteBuffer)) {
218                        command.write(byteBuffer.get());
219                }
220                return StreamUtils.copyToString(command, StandardCharsets.UTF_8);
221        }
222
223        private void readHeaders(ByteBuffer byteBuffer, StompHeaderAccessor headerAccessor) {
224                while (true) {
225                        ByteArrayOutputStream headerStream = new ByteArrayOutputStream(256);
226                        boolean headerComplete = false;
227                        while (byteBuffer.hasRemaining()) {
228                                if (tryConsumeEndOfLine(byteBuffer)) {
229                                        headerComplete = true;
230                                        break;
231                                }
232                                headerStream.write(byteBuffer.get());
233                        }
234                        if (headerStream.size() > 0 && headerComplete) {
235                                String header = StreamUtils.copyToString(headerStream, StandardCharsets.UTF_8);
236                                int colonIndex = header.indexOf(':');
237                                if (colonIndex <= 0) {
238                                        if (byteBuffer.remaining() > 0) {
239                                                throw new StompConversionException("Illegal header: '" + header +
240                                                                "'. A header must be of the form <name>:[<value>].");
241                                        }
242                                }
243                                else {
244                                        String headerName = unescape(header.substring(0, colonIndex));
245                                        String headerValue = unescape(header.substring(colonIndex + 1));
246                                        try {
247                                                headerAccessor.addNativeHeader(headerName, headerValue);
248                                        }
249                                        catch (InvalidMimeTypeException ex) {
250                                                if (byteBuffer.remaining() > 0) {
251                                                        throw ex;
252                                                }
253                                        }
254                                }
255                        }
256                        else {
257                                break;
258                        }
259                }
260        }
261
262        /**
263         * See STOMP Spec 1.2:
264         * <a href="https://stomp.github.io/stomp-specification-1.2.html#Value_Encoding">"Value Encoding"</a>.
265         */
266        private String unescape(String inString) {
267                StringBuilder sb = new StringBuilder(inString.length());
268                int pos = 0;  // position in the old string
269                int index = inString.indexOf('\\');
270
271                while (index >= 0) {
272                        sb.append(inString, pos, index);
273                        if (index + 1 >= inString.length()) {
274                                throw new StompConversionException("Illegal escape sequence at index " + index + ": " + inString);
275                        }
276                        char c = inString.charAt(index + 1);
277                        if (c == 'r') {
278                                sb.append('\r');
279                        }
280                        else if (c == 'n') {
281                                sb.append('\n');
282                        }
283                        else if (c == 'c') {
284                                sb.append(':');
285                        }
286                        else if (c == '\\') {
287                                sb.append('\\');
288                        }
289                        else {
290                                // should never happen
291                                throw new StompConversionException("Illegal escape sequence at index " + index + ": " + inString);
292                        }
293                        pos = index + 2;
294                        index = inString.indexOf('\\', pos);
295                }
296
297                sb.append(inString.substring(pos));
298                return sb.toString();
299        }
300
301        @Nullable
302        private byte[] readPayload(ByteBuffer byteBuffer, StompHeaderAccessor headerAccessor) {
303                Integer contentLength;
304                try {
305                        contentLength = headerAccessor.getContentLength();
306                }
307                catch (NumberFormatException ex) {
308                        if (logger.isDebugEnabled()) {
309                                logger.debug("Ignoring invalid content-length: '" + headerAccessor);
310                        }
311                        contentLength = null;
312                }
313
314                if (contentLength != null && contentLength >= 0) {
315                        if (byteBuffer.remaining() > contentLength) {
316                                byte[] payload = new byte[contentLength];
317                                byteBuffer.get(payload);
318                                if (byteBuffer.get() != 0) {
319                                        throw new StompConversionException("Frame must be terminated with a null octet");
320                                }
321                                return payload;
322                        }
323                        else {
324                                return null;
325                        }
326                }
327                else {
328                        ByteArrayOutputStream payload = new ByteArrayOutputStream(256);
329                        while (byteBuffer.remaining() > 0) {
330                                byte b = byteBuffer.get();
331                                if (b == 0) {
332                                        return payload.toByteArray();
333                                }
334                                else {
335                                        payload.write(b);
336                                }
337                        }
338                }
339                return null;
340        }
341
342        /**
343         * Try to read an EOL incrementing the buffer position if successful.
344         * @return whether an EOL was consumed
345         */
346        private boolean tryConsumeEndOfLine(ByteBuffer byteBuffer) {
347                if (byteBuffer.remaining() > 0) {
348                        byte b = byteBuffer.get();
349                        if (b == '\n') {
350                                return true;
351                        }
352                        else if (b == '\r') {
353                                if (byteBuffer.remaining() > 0 && byteBuffer.get() == '\n') {
354                                        return true;
355                                }
356                                else {
357                                        throw new StompConversionException("'\\r' must be followed by '\\n'");
358                                }
359                        }
360                        // Explicit cast for compatibility with covariant return type on JDK 9's ByteBuffer
361                        ((Buffer) byteBuffer).position(byteBuffer.position() - 1);
362                }
363                return false;
364        }
365
366}