001/*
002 * Copyright 2002-2020 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.messaging.simp.stomp;
018
019import java.io.ByteArrayOutputStream;
020import java.io.DataOutputStream;
021import java.io.IOException;
022import java.util.Collections;
023import java.util.LinkedHashMap;
024import java.util.List;
025import java.util.Map;
026import java.util.Map.Entry;
027import java.util.concurrent.ConcurrentHashMap;
028
029import org.apache.commons.logging.Log;
030import org.apache.commons.logging.LogFactory;
031
032import org.springframework.messaging.Message;
033import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
034import org.springframework.messaging.simp.SimpMessageType;
035import org.springframework.messaging.support.NativeMessageHeaderAccessor;
036import org.springframework.util.Assert;
037
038/**
039 * An encoder for STOMP frames.
040 *
041 * @author Andy Wilkinson
042 * @author Rossen Stoyanchev
043 * @since 4.0
044 * @see StompDecoder
045 */
046public class StompEncoder  {
047
048        private static final byte LF = '\n';
049
050        private static final byte COLON = ':';
051
052        private static final Log logger = LogFactory.getLog(StompEncoder.class);
053
054        private static final int HEADER_KEY_CACHE_LIMIT = 32;
055
056
057        private final Map<String, byte[]> headerKeyAccessCache =
058                        new ConcurrentHashMap<String, byte[]>(HEADER_KEY_CACHE_LIMIT);
059
060        @SuppressWarnings("serial")
061        private final Map<String, byte[]> headerKeyUpdateCache =
062                        new LinkedHashMap<String, byte[]>(HEADER_KEY_CACHE_LIMIT, 0.75f, true) {
063                                @Override
064                                protected boolean removeEldestEntry(Map.Entry<String, byte[]> eldest) {
065                                        if (size() > HEADER_KEY_CACHE_LIMIT) {
066                                                headerKeyAccessCache.remove(eldest.getKey());
067                                                return true;
068                                        }
069                                        else {
070                                                return false;
071                                        }
072                                }
073                        };
074
075
076        /**
077         * Encodes the given STOMP {@code message} into a {@code byte[]}
078         * @param message the message to encode
079         * @return the encoded message
080         */
081        public byte[] encode(Message<byte[]> message) {
082                return encode(message.getHeaders(), message.getPayload());
083        }
084
085        /**
086         * Encodes the given payload and headers into a {@code byte[]}.
087         * @param headers the headers
088         * @param payload the payload
089         * @return the encoded message
090         */
091        public byte[] encode(Map<String, Object> headers, byte[] payload) {
092                Assert.notNull(headers, "'headers' is required");
093                Assert.notNull(payload, "'payload' is required");
094
095                try {
096                        ByteArrayOutputStream baos = new ByteArrayOutputStream(128 + payload.length);
097                        DataOutputStream output = new DataOutputStream(baos);
098
099                        if (SimpMessageType.HEARTBEAT.equals(SimpMessageHeaderAccessor.getMessageType(headers))) {
100                                logger.trace("Encoding heartbeat");
101                                output.write(StompDecoder.HEARTBEAT_PAYLOAD);
102                        }
103
104                        else {
105                                StompCommand command = StompHeaderAccessor.getCommand(headers);
106                                if (command == null) {
107                                        throw new IllegalStateException("Missing STOMP command: " + headers);
108                                }
109
110                                output.write(command.toString().getBytes(StompDecoder.UTF8_CHARSET));
111                                output.write(LF);
112                                writeHeaders(command, headers, payload, output);
113                                output.write(LF);
114                                writeBody(payload, output);
115                                output.write((byte) 0);
116                        }
117
118                        return baos.toByteArray();
119                }
120                catch (IOException ex) {
121                        throw new StompConversionException("Failed to encode STOMP frame, headers=" + headers,  ex);
122                }
123        }
124
125        private void writeHeaders(StompCommand command, Map<String, Object> headers, byte[] payload,
126                        DataOutputStream output) throws IOException {
127
128                @SuppressWarnings("unchecked")
129                Map<String,List<String>> nativeHeaders =
130                                (Map<String, List<String>>) headers.get(NativeMessageHeaderAccessor.NATIVE_HEADERS);
131
132                if (logger.isTraceEnabled()) {
133                        logger.trace("Encoding STOMP " + command + ", headers=" + nativeHeaders);
134                }
135
136                if (nativeHeaders == null) {
137                        return;
138                }
139
140                boolean shouldEscape = (command != StompCommand.CONNECT && command != StompCommand.CONNECTED);
141
142                for (Entry<String, List<String>> entry : nativeHeaders.entrySet()) {
143                        if (command.requiresContentLength() && "content-length".equals(entry.getKey())) {
144                                continue;
145                        }
146
147                        List<String> values = entry.getValue();
148                        if (StompCommand.CONNECT.equals(command) &&
149                                        StompHeaderAccessor.STOMP_PASSCODE_HEADER.equals(entry.getKey())) {
150                                values = Collections.singletonList(StompHeaderAccessor.getPasscode(headers));
151                        }
152
153                        byte[] encodedKey = encodeHeaderKey(entry.getKey(), shouldEscape);
154                        for (String value : values) {
155                                output.write(encodedKey);
156                                output.write(COLON);
157                                output.write(encodeHeaderValue(value, shouldEscape));
158                                output.write(LF);
159                        }
160                }
161
162                if (command.requiresContentLength()) {
163                        int contentLength = payload.length;
164                        output.write("content-length:".getBytes(StompDecoder.UTF8_CHARSET));
165                        output.write(Integer.toString(contentLength).getBytes(StompDecoder.UTF8_CHARSET));
166                        output.write(LF);
167                }
168        }
169
170        private byte[] encodeHeaderKey(String input, boolean escape) {
171                String inputToUse = (escape ? escape(input) : input);
172                if (this.headerKeyAccessCache.containsKey(inputToUse)) {
173                        return this.headerKeyAccessCache.get(inputToUse);
174                }
175                synchronized (this.headerKeyUpdateCache) {
176                        byte[] bytes = this.headerKeyUpdateCache.get(inputToUse);
177                        if (bytes == null) {
178                                bytes = inputToUse.getBytes(StompDecoder.UTF8_CHARSET);
179                                this.headerKeyAccessCache.put(inputToUse, bytes);
180                                this.headerKeyUpdateCache.put(inputToUse, bytes);
181                        }
182                        return bytes;
183                }
184        }
185
186        private byte[] encodeHeaderValue(String input, boolean escape) {
187                String inputToUse = (escape ? escape(input) : input);
188                return inputToUse.getBytes(StompDecoder.UTF8_CHARSET);
189        }
190
191        /**
192         * See STOMP Spec 1.2:
193         * <a href="https://stomp.github.io/stomp-specification-1.2.html#Value_Encoding">"Value Encoding"</a>.
194         */
195        private String escape(String inString) {
196                StringBuilder sb = null;
197                for (int i = 0; i < inString.length(); i++) {
198                        char c = inString.charAt(i);
199                        if (c == '\\') {
200                                sb = getStringBuilder(sb, inString, i);
201                                sb.append("\\\\");
202                        }
203                        else if (c == ':') {
204                                sb = getStringBuilder(sb, inString, i);
205                                sb.append("\\c");
206                        }
207                        else if (c == '\n') {
208                                sb = getStringBuilder(sb, inString, i);
209                                sb.append("\\n");
210                        }
211                        else if (c == '\r') {
212                                sb = getStringBuilder(sb, inString, i);
213                                sb.append("\\r");
214                        }
215                        else if (sb != null){
216                                sb.append(c);
217                        }
218                }
219                return (sb != null ? sb.toString() : inString);
220        }
221
222        private StringBuilder getStringBuilder(StringBuilder sb, String inString, int i) {
223                if (sb == null) {
224                        sb = new StringBuilder(inString.length());
225                        sb.append(inString, 0, i);
226                }
227                return sb;
228        }
229
230        private void writeBody(byte[] payload, DataOutputStream output) throws IOException {
231                output.write(payload);
232        }
233
234}