001/*
002 * Copyright 2002-2020 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.messaging.simp.stomp;
018
019import java.io.ByteArrayOutputStream;
020import java.nio.ByteBuffer;
021import java.nio.charset.Charset;
022import java.util.ArrayList;
023import java.util.List;
024
025import org.apache.commons.logging.Log;
026import org.apache.commons.logging.LogFactory;
027
028import org.springframework.messaging.Message;
029import org.springframework.messaging.support.MessageBuilder;
030import org.springframework.messaging.support.MessageHeaderInitializer;
031import org.springframework.messaging.support.NativeMessageHeaderAccessor;
032import org.springframework.util.InvalidMimeTypeException;
033import org.springframework.util.MultiValueMap;
034
035/**
036 * Decodes one or more STOMP frames contained in a {@link ByteBuffer}.
037 *
038 * <p>An attempt is made to read all complete STOMP frames from the buffer, which
039 * could be zero, one, or more. If there is any left-over content, i.e. an incomplete
040 * STOMP frame, at the end the buffer is reset to point to the beginning of the
041 * partial content. The caller is then responsible for dealing with that
042 * incomplete content by buffering until there is more input available.
043 *
044 * @author Andy Wilkinson
045 * @author Rossen Stoyanchev
046 * @since 4.0
047 */
048public class StompDecoder {
049
050        static final Charset UTF8_CHARSET = Charset.forName("UTF-8");
051
052        static final byte[] HEARTBEAT_PAYLOAD = new byte[] {'\n'};
053
054        private static final Log logger = LogFactory.getLog(StompDecoder.class);
055
056        private MessageHeaderInitializer headerInitializer;
057
058
059        /**
060         * Configure a {@link MessageHeaderInitializer} to apply to the headers of
061         * {@link Message}s from decoded STOMP frames.
062         */
063        public void setHeaderInitializer(MessageHeaderInitializer headerInitializer) {
064                this.headerInitializer = headerInitializer;
065        }
066
067        /**
068         * Return the configured {@code MessageHeaderInitializer}, if any.
069         */
070        public MessageHeaderInitializer getHeaderInitializer() {
071                return this.headerInitializer;
072        }
073
074
075        /**
076         * Decodes one or more STOMP frames from the given {@code ByteBuffer} into a
077         * list of {@link Message}s. If the input buffer contains partial STOMP frame
078         * content, or additional content with a partial STOMP frame, the buffer is
079         * reset and {@code null} is returned.
080         * @param buffer the buffer to decode the STOMP frame from
081         * @return the decoded messages, or an empty list if none
082         * @throws StompConversionException raised in case of decoding issues
083         */
084        public List<Message<byte[]>> decode(ByteBuffer buffer) {
085                return decode(buffer, null);
086        }
087
088        /**
089         * Decodes one or more STOMP frames from the given {@code buffer} and returns
090         * a list of {@link Message}s.
091         * <p>If the given ByteBuffer contains only partial STOMP frame content and no
092         * complete STOMP frames, an empty list is returned, and the buffer is reset to
093         * to where it was.
094         * <p>If the buffer contains one ore more STOMP frames, those are returned and
095         * the buffer reset to point to the beginning of the unused partial content.
096         * <p>The output partialMessageHeaders map is used to store successfully parsed
097         * headers in case of partial content. The caller can then check if a
098         * "content-length" header was read, which helps to determine how much more
099         * content is needed before the next attempt to decode.
100         * @param buffer the buffer to decode the STOMP frame from
101         * @param partialMessageHeaders an empty output map that will store the last
102         * successfully parsed partialMessageHeaders in case of partial message content
103         * in cases where the partial buffer ended with a partial STOMP frame
104         * @return the decoded messages, or an empty list if none
105         * @throws StompConversionException raised in case of decoding issues
106         */
107        public List<Message<byte[]>> decode(ByteBuffer buffer, MultiValueMap<String, String> partialMessageHeaders) {
108                List<Message<byte[]>> messages = new ArrayList<Message<byte[]>>();
109                while (buffer.hasRemaining()) {
110                        Message<byte[]> message = decodeMessage(buffer, partialMessageHeaders);
111                        if (message != null) {
112                                messages.add(message);
113                        }
114                        else {
115                                break;
116                        }
117                }
118                return messages;
119        }
120
121        /**
122         * Decode a single STOMP frame from the given {@code buffer} into a {@link Message}.
123         */
124        private Message<byte[]> decodeMessage(ByteBuffer buffer, MultiValueMap<String, String> headers) {
125                Message<byte[]> decodedMessage = null;
126                skipLeadingEol(buffer);
127                buffer.mark();
128
129                String command = readCommand(buffer);
130                if (command.length() > 0) {
131                        StompHeaderAccessor headerAccessor = null;
132                        byte[] payload = null;
133                        if (buffer.remaining() > 0) {
134                                StompCommand stompCommand = StompCommand.valueOf(command);
135                                headerAccessor = StompHeaderAccessor.create(stompCommand);
136                                initHeaders(headerAccessor);
137                                readHeaders(buffer, headerAccessor);
138                                payload = readPayload(buffer, headerAccessor);
139                        }
140                        if (payload != null) {
141                                if (payload.length > 0) {
142                                        StompCommand stompCommand = headerAccessor.getCommand();
143                                        if (stompCommand != null && !stompCommand.isBodyAllowed()) {
144                                                throw new StompConversionException(stompCommand +
145                                                                " shouldn't have a payload: length=" + payload.length + ", headers=" + headers);
146                                        }
147                                }
148                                headerAccessor.updateSimpMessageHeadersFromStompHeaders();
149                                headerAccessor.setLeaveMutable(true);
150                                decodedMessage = MessageBuilder.createMessage(payload, headerAccessor.getMessageHeaders());
151                                if (logger.isTraceEnabled()) {
152                                        logger.trace("Decoded " + headerAccessor.getDetailedLogMessage(payload));
153                                }
154                        }
155                        else {
156                                logger.trace("Incomplete frame, resetting input buffer...");
157                                if (headers != null && headerAccessor != null) {
158                                        String name = NativeMessageHeaderAccessor.NATIVE_HEADERS;
159                                        @SuppressWarnings("unchecked")
160                                        MultiValueMap<String, String> map = (MultiValueMap<String, String>) headerAccessor.getHeader(name);
161                                        if (map != null) {
162                                                headers.putAll(map);
163                                        }
164                                }
165                                buffer.reset();
166                        }
167                }
168                else {
169                        StompHeaderAccessor headerAccessor = StompHeaderAccessor.createForHeartbeat();
170                        initHeaders(headerAccessor);
171                        headerAccessor.setLeaveMutable(true);
172                        decodedMessage = MessageBuilder.createMessage(HEARTBEAT_PAYLOAD, headerAccessor.getMessageHeaders());
173                        if (logger.isTraceEnabled()) {
174                                logger.trace("Decoded " + headerAccessor.getDetailedLogMessage(null));
175                        }
176                }
177
178                return decodedMessage;
179        }
180
181        private void initHeaders(StompHeaderAccessor headerAccessor) {
182                MessageHeaderInitializer initializer = getHeaderInitializer();
183                if (initializer != null) {
184                        initializer.initHeaders(headerAccessor);
185                }
186        }
187
188        /**
189         * Skip one ore more EOL characters at the start of the given ByteBuffer.
190         * Those are STOMP heartbeat frames.
191         */
192        protected void skipLeadingEol(ByteBuffer buffer) {
193                while (true) {
194                        if (!tryConsumeEndOfLine(buffer)) {
195                                break;
196                        }
197                }
198        }
199
200        private String readCommand(ByteBuffer buffer) {
201                ByteArrayOutputStream command = new ByteArrayOutputStream(256);
202                while (buffer.remaining() > 0 && !tryConsumeEndOfLine(buffer)) {
203                        command.write(buffer.get());
204                }
205                return new String(command.toByteArray(), UTF8_CHARSET);
206        }
207
208        private void readHeaders(ByteBuffer buffer, StompHeaderAccessor headerAccessor) {
209                while (true) {
210                        ByteArrayOutputStream headerStream = new ByteArrayOutputStream(256);
211                        boolean headerComplete = false;
212                        while (buffer.hasRemaining()) {
213                                if (tryConsumeEndOfLine(buffer)) {
214                                        headerComplete = true;
215                                        break;
216                                }
217                                headerStream.write(buffer.get());
218                        }
219                        if (headerStream.size() > 0 && headerComplete) {
220                                String header = new String(headerStream.toByteArray(), UTF8_CHARSET);
221                                int colonIndex = header.indexOf(':');
222                                if (colonIndex <= 0) {
223                                        if (buffer.remaining() > 0) {
224                                                throw new StompConversionException("Illegal header: '" + header +
225                                                                "'. A header must be of the form <name>:[<value>].");
226                                        }
227                                }
228                                else {
229                                        String headerName = unescape(header.substring(0, colonIndex));
230                                        String headerValue = unescape(header.substring(colonIndex + 1));
231                                        try {
232                                                headerAccessor.addNativeHeader(headerName, headerValue);
233                                        }
234                                        catch (InvalidMimeTypeException ex) {
235                                                if (buffer.remaining() > 0) {
236                                                        throw ex;
237                                                }
238                                        }
239                                }
240                        }
241                        else {
242                                break;
243                        }
244                }
245        }
246
247        /**
248         * See STOMP Spec 1.2:
249         * <a href="https://stomp.github.io/stomp-specification-1.2.html#Value_Encoding">"Value Encoding"</a>.
250         */
251        private String unescape(String inString) {
252                StringBuilder sb = new StringBuilder(inString.length());
253                int pos = 0;  // position in the old string
254                int index = inString.indexOf('\\');
255
256                while (index >= 0) {
257                        sb.append(inString, pos, index);
258                        if (index + 1 >= inString.length()) {
259                                throw new StompConversionException("Illegal escape sequence at index " + index + ": " + inString);
260                        }
261                        char c = inString.charAt(index + 1);
262                        if (c == 'r') {
263                                sb.append('\r');
264                        }
265                        else if (c == 'n') {
266                                sb.append('\n');
267                        }
268                        else if (c == 'c') {
269                                sb.append(':');
270                        }
271                        else if (c == '\\') {
272                                sb.append('\\');
273                        }
274                        else {
275                                // should never happen
276                                throw new StompConversionException("Illegal escape sequence at index " + index + ": " + inString);
277                        }
278                        pos = index + 2;
279                        index = inString.indexOf('\\', pos);
280                }
281
282                sb.append(inString.substring(pos));
283                return sb.toString();
284        }
285
286        private byte[] readPayload(ByteBuffer buffer, StompHeaderAccessor headerAccessor) {
287                Integer contentLength;
288                try {
289                        contentLength = headerAccessor.getContentLength();
290                }
291                catch (NumberFormatException ex) {
292                        if (logger.isWarnEnabled()) {
293                                logger.warn("Ignoring invalid content-length: '" + headerAccessor);
294                        }
295                        contentLength = null;
296                }
297
298                if (contentLength != null && contentLength >= 0) {
299                        if (buffer.remaining() > contentLength) {
300                                byte[] payload = new byte[contentLength];
301                                buffer.get(payload);
302                                if (buffer.get() != 0) {
303                                        throw new StompConversionException("Frame must be terminated with a null octet");
304                                }
305                                return payload;
306                        }
307                        else {
308                                return null;
309                        }
310                }
311                else {
312                        ByteArrayOutputStream payload = new ByteArrayOutputStream(256);
313                        while (buffer.remaining() > 0) {
314                                byte b = buffer.get();
315                                if (b == 0) {
316                                        return payload.toByteArray();
317                                }
318                                else {
319                                        payload.write(b);
320                                }
321                        }
322                }
323                return null;
324        }
325
326        /**
327         * Try to read an EOL incrementing the buffer position if successful.
328         * @return whether an EOL was consumed
329         */
330        private boolean tryConsumeEndOfLine(ByteBuffer buffer) {
331                if (buffer.remaining() > 0) {
332                        byte b = buffer.get();
333                        if (b == '\n') {
334                                return true;
335                        }
336                        else if (b == '\r') {
337                                if (buffer.remaining() > 0 && buffer.get() == '\n') {
338                                        return true;
339                                }
340                                else {
341                                        throw new StompConversionException("'\\r' must be followed by '\\n'");
342                                }
343                        }
344                        buffer.position(buffer.position() - 1);
345                }
346                return false;
347        }
348
349}