001/*
002 * Copyright 2002-2020 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.messaging.simp.stomp;
018
019import java.nio.charset.StandardCharsets;
020import java.util.ArrayList;
021import java.util.Collections;
022import java.util.LinkedHashMap;
023import java.util.List;
024import java.util.Map;
025import java.util.Map.Entry;
026import java.util.concurrent.ConcurrentHashMap;
027
028import org.apache.commons.logging.Log;
029
030import org.springframework.lang.Nullable;
031import org.springframework.messaging.Message;
032import org.springframework.messaging.simp.SimpLogging;
033import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
034import org.springframework.messaging.simp.SimpMessageType;
035import org.springframework.messaging.support.NativeMessageHeaderAccessor;
036import org.springframework.util.Assert;
037
038/**
039 * An encoder for STOMP frames.
040 *
041 * @author Andy Wilkinson
042 * @author Rossen Stoyanchev
043 * @since 4.0
044 * @see StompDecoder
045 */
046public class StompEncoder  {
047
048        private static final Byte LINE_FEED_BYTE = '\n';
049
050        private static final Byte COLON_BYTE = ':';
051
052        private static final Log logger = SimpLogging.forLogName(StompEncoder.class);
053
054        private static final int HEADER_KEY_CACHE_LIMIT = 32;
055
056
057        private final Map<String, byte[]> headerKeyAccessCache = new ConcurrentHashMap<>(HEADER_KEY_CACHE_LIMIT);
058
059        @SuppressWarnings("serial")
060        private final Map<String, byte[]> headerKeyUpdateCache =
061                        new LinkedHashMap<String, byte[]>(HEADER_KEY_CACHE_LIMIT, 0.75f, true) {
062                                @Override
063                                protected boolean removeEldestEntry(Map.Entry<String, byte[]> eldest) {
064                                        if (size() > HEADER_KEY_CACHE_LIMIT) {
065                                                headerKeyAccessCache.remove(eldest.getKey());
066                                                return true;
067                                        }
068                                        else {
069                                                return false;
070                                        }
071                                }
072                        };
073
074
075        /**
076         * Encodes the given STOMP {@code message} into a {@code byte[]}.
077         * @param message the message to encode
078         * @return the encoded message
079         */
080        public byte[] encode(Message<byte[]> message) {
081                return encode(message.getHeaders(), message.getPayload());
082        }
083
084        /**
085         * Encodes the given payload and headers into a {@code byte[]}.
086         * @param headers the headers
087         * @param payload the payload
088         * @return the encoded message
089         */
090        public byte[] encode(Map<String, Object> headers, byte[] payload) {
091                Assert.notNull(headers, "'headers' is required");
092                Assert.notNull(payload, "'payload' is required");
093
094                if (SimpMessageType.HEARTBEAT.equals(SimpMessageHeaderAccessor.getMessageType(headers))) {
095                        logger.trace("Encoding heartbeat");
096                        return StompDecoder.HEARTBEAT_PAYLOAD;
097                }
098
099                StompCommand command = StompHeaderAccessor.getCommand(headers);
100                if (command == null) {
101                        throw new IllegalStateException("Missing STOMP command: " + headers);
102                }
103
104                Result result = new DefaultResult();
105                result.add(command.toString().getBytes(StandardCharsets.UTF_8));
106                result.add(LINE_FEED_BYTE);
107                writeHeaders(command, headers, payload, result);
108                result.add(LINE_FEED_BYTE);
109                result.add(payload);
110                result.add((byte) 0);
111                return result.toByteArray();
112        }
113
114        private void writeHeaders(
115                        StompCommand command, Map<String, Object> headers, byte[] payload, Result result) {
116
117                @SuppressWarnings("unchecked")
118                Map<String,List<String>> nativeHeaders =
119                                (Map<String, List<String>>) headers.get(NativeMessageHeaderAccessor.NATIVE_HEADERS);
120
121                if (logger.isTraceEnabled()) {
122                        logger.trace("Encoding STOMP " + command + ", headers=" + nativeHeaders);
123                }
124
125                if (nativeHeaders == null) {
126                        return;
127                }
128
129                boolean shouldEscape = (command != StompCommand.CONNECT && command != StompCommand.STOMP
130                                && command != StompCommand.CONNECTED);
131
132                for (Entry<String, List<String>> entry : nativeHeaders.entrySet()) {
133                        if (command.requiresContentLength() && "content-length".equals(entry.getKey())) {
134                                continue;
135                        }
136
137                        List<String> values = entry.getValue();
138                        if ((StompCommand.CONNECT.equals(command) || StompCommand.STOMP.equals(command)) &&
139                                        StompHeaderAccessor.STOMP_PASSCODE_HEADER.equals(entry.getKey())) {
140                                values = Collections.singletonList(StompHeaderAccessor.getPasscode(headers));
141                        }
142
143                        byte[] encodedKey = encodeHeaderKey(entry.getKey(), shouldEscape);
144                        for (String value : values) {
145                                result.add(encodedKey);
146                                result.add(COLON_BYTE);
147                                result.add(encodeHeaderValue(value, shouldEscape));
148                                result.add(LINE_FEED_BYTE);
149                        }
150                }
151
152                if (command.requiresContentLength()) {
153                        int contentLength = payload.length;
154                        result.add("content-length:".getBytes(StandardCharsets.UTF_8));
155                        result.add(Integer.toString(contentLength).getBytes(StandardCharsets.UTF_8));
156                        result.add(LINE_FEED_BYTE);
157                }
158        }
159
160        private byte[] encodeHeaderKey(String input, boolean escape) {
161                String inputToUse = (escape ? escape(input) : input);
162                if (this.headerKeyAccessCache.containsKey(inputToUse)) {
163                        return this.headerKeyAccessCache.get(inputToUse);
164                }
165                synchronized (this.headerKeyUpdateCache) {
166                        byte[] bytes = this.headerKeyUpdateCache.get(inputToUse);
167                        if (bytes == null) {
168                                bytes = inputToUse.getBytes(StandardCharsets.UTF_8);
169                                this.headerKeyAccessCache.put(inputToUse, bytes);
170                                this.headerKeyUpdateCache.put(inputToUse, bytes);
171                        }
172                        return bytes;
173                }
174        }
175
176        private byte[] encodeHeaderValue(String input, boolean escape) {
177                String inputToUse = (escape ? escape(input) : input);
178                return inputToUse.getBytes(StandardCharsets.UTF_8);
179        }
180
181        /**
182         * See STOMP Spec 1.2:
183         * <a href="https://stomp.github.io/stomp-specification-1.2.html#Value_Encoding">"Value Encoding"</a>.
184         */
185        private String escape(String inString) {
186                StringBuilder sb = null;
187                for (int i = 0; i < inString.length(); i++) {
188                        char c = inString.charAt(i);
189                        if (c == '\\') {
190                                sb = getStringBuilder(sb, inString, i);
191                                sb.append("\\\\");
192                        }
193                        else if (c == ':') {
194                                sb = getStringBuilder(sb, inString, i);
195                                sb.append("\\c");
196                        }
197                        else if (c == '\n') {
198                                sb = getStringBuilder(sb, inString, i);
199                                sb.append("\\n");
200                        }
201                        else if (c == '\r') {
202                                sb = getStringBuilder(sb, inString, i);
203                                sb.append("\\r");
204                        }
205                        else if (sb != null){
206                                sb.append(c);
207                        }
208                }
209                return (sb != null ? sb.toString() : inString);
210        }
211
212        private StringBuilder getStringBuilder(@Nullable StringBuilder sb, String inString, int i) {
213                if (sb == null) {
214                        sb = new StringBuilder(inString.length());
215                        sb.append(inString, 0, i);
216                }
217                return sb;
218        }
219
220
221        /**
222         * Accumulates byte content and returns an aggregated byte[] at the end.
223         */
224        private interface Result {
225
226                void add(byte[] bytes);
227
228                void add(byte b);
229
230                byte[] toByteArray();
231        }
232
233
234        @SuppressWarnings("serial")
235        private static class DefaultResult extends ArrayList<Object> implements Result {
236
237                private int size;
238
239                public void add(byte[] bytes) {
240                        this.size += bytes.length;
241                        super.add(bytes);
242                }
243
244                public void add(byte b) {
245                        this.size++;
246                        super.add(b);
247                }
248
249                public byte[] toByteArray() {
250                        byte[] result = new byte[this.size];
251                        int position = 0;
252                        for (Object o : this) {
253                                if (o instanceof byte[]) {
254                                        byte[] src = (byte[]) o;
255                                        System.arraycopy(src, 0, result, position, src.length);
256                                        position += src.length;
257                                }
258                                else {
259                                        result[position++] = (Byte) o;
260                                }
261                        }
262                        return result;
263                }
264        }
265
266}