001/* 002 * Copyright 2002-2020 the original author or authors. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * https://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.springframework.messaging.simp.stomp; 018 019import java.nio.charset.StandardCharsets; 020import java.util.ArrayList; 021import java.util.Collections; 022import java.util.LinkedHashMap; 023import java.util.List; 024import java.util.Map; 025import java.util.Map.Entry; 026import java.util.concurrent.ConcurrentHashMap; 027 028import org.apache.commons.logging.Log; 029 030import org.springframework.lang.Nullable; 031import org.springframework.messaging.Message; 032import org.springframework.messaging.simp.SimpLogging; 033import org.springframework.messaging.simp.SimpMessageHeaderAccessor; 034import org.springframework.messaging.simp.SimpMessageType; 035import org.springframework.messaging.support.NativeMessageHeaderAccessor; 036import org.springframework.util.Assert; 037 038/** 039 * An encoder for STOMP frames. 040 * 041 * @author Andy Wilkinson 042 * @author Rossen Stoyanchev 043 * @since 4.0 044 * @see StompDecoder 045 */ 046public class StompEncoder { 047 048 private static final Byte LINE_FEED_BYTE = '\n'; 049 050 private static final Byte COLON_BYTE = ':'; 051 052 private static final Log logger = SimpLogging.forLogName(StompEncoder.class); 053 054 private static final int HEADER_KEY_CACHE_LIMIT = 32; 055 056 057 private final Map<String, byte[]> headerKeyAccessCache = new ConcurrentHashMap<>(HEADER_KEY_CACHE_LIMIT); 058 059 @SuppressWarnings("serial") 060 private final Map<String, byte[]> headerKeyUpdateCache = 061 new LinkedHashMap<String, byte[]>(HEADER_KEY_CACHE_LIMIT, 0.75f, true) { 062 @Override 063 protected boolean removeEldestEntry(Map.Entry<String, byte[]> eldest) { 064 if (size() > HEADER_KEY_CACHE_LIMIT) { 065 headerKeyAccessCache.remove(eldest.getKey()); 066 return true; 067 } 068 else { 069 return false; 070 } 071 } 072 }; 073 074 075 /** 076 * Encodes the given STOMP {@code message} into a {@code byte[]}. 077 * @param message the message to encode 078 * @return the encoded message 079 */ 080 public byte[] encode(Message<byte[]> message) { 081 return encode(message.getHeaders(), message.getPayload()); 082 } 083 084 /** 085 * Encodes the given payload and headers into a {@code byte[]}. 086 * @param headers the headers 087 * @param payload the payload 088 * @return the encoded message 089 */ 090 public byte[] encode(Map<String, Object> headers, byte[] payload) { 091 Assert.notNull(headers, "'headers' is required"); 092 Assert.notNull(payload, "'payload' is required"); 093 094 if (SimpMessageType.HEARTBEAT.equals(SimpMessageHeaderAccessor.getMessageType(headers))) { 095 logger.trace("Encoding heartbeat"); 096 return StompDecoder.HEARTBEAT_PAYLOAD; 097 } 098 099 StompCommand command = StompHeaderAccessor.getCommand(headers); 100 if (command == null) { 101 throw new IllegalStateException("Missing STOMP command: " + headers); 102 } 103 104 Result result = new DefaultResult(); 105 result.add(command.toString().getBytes(StandardCharsets.UTF_8)); 106 result.add(LINE_FEED_BYTE); 107 writeHeaders(command, headers, payload, result); 108 result.add(LINE_FEED_BYTE); 109 result.add(payload); 110 result.add((byte) 0); 111 return result.toByteArray(); 112 } 113 114 private void writeHeaders( 115 StompCommand command, Map<String, Object> headers, byte[] payload, Result result) { 116 117 @SuppressWarnings("unchecked") 118 Map<String,List<String>> nativeHeaders = 119 (Map<String, List<String>>) headers.get(NativeMessageHeaderAccessor.NATIVE_HEADERS); 120 121 if (logger.isTraceEnabled()) { 122 logger.trace("Encoding STOMP " + command + ", headers=" + nativeHeaders); 123 } 124 125 if (nativeHeaders == null) { 126 return; 127 } 128 129 boolean shouldEscape = (command != StompCommand.CONNECT && command != StompCommand.STOMP 130 && command != StompCommand.CONNECTED); 131 132 for (Entry<String, List<String>> entry : nativeHeaders.entrySet()) { 133 if (command.requiresContentLength() && "content-length".equals(entry.getKey())) { 134 continue; 135 } 136 137 List<String> values = entry.getValue(); 138 if ((StompCommand.CONNECT.equals(command) || StompCommand.STOMP.equals(command)) && 139 StompHeaderAccessor.STOMP_PASSCODE_HEADER.equals(entry.getKey())) { 140 values = Collections.singletonList(StompHeaderAccessor.getPasscode(headers)); 141 } 142 143 byte[] encodedKey = encodeHeaderKey(entry.getKey(), shouldEscape); 144 for (String value : values) { 145 result.add(encodedKey); 146 result.add(COLON_BYTE); 147 result.add(encodeHeaderValue(value, shouldEscape)); 148 result.add(LINE_FEED_BYTE); 149 } 150 } 151 152 if (command.requiresContentLength()) { 153 int contentLength = payload.length; 154 result.add("content-length:".getBytes(StandardCharsets.UTF_8)); 155 result.add(Integer.toString(contentLength).getBytes(StandardCharsets.UTF_8)); 156 result.add(LINE_FEED_BYTE); 157 } 158 } 159 160 private byte[] encodeHeaderKey(String input, boolean escape) { 161 String inputToUse = (escape ? escape(input) : input); 162 if (this.headerKeyAccessCache.containsKey(inputToUse)) { 163 return this.headerKeyAccessCache.get(inputToUse); 164 } 165 synchronized (this.headerKeyUpdateCache) { 166 byte[] bytes = this.headerKeyUpdateCache.get(inputToUse); 167 if (bytes == null) { 168 bytes = inputToUse.getBytes(StandardCharsets.UTF_8); 169 this.headerKeyAccessCache.put(inputToUse, bytes); 170 this.headerKeyUpdateCache.put(inputToUse, bytes); 171 } 172 return bytes; 173 } 174 } 175 176 private byte[] encodeHeaderValue(String input, boolean escape) { 177 String inputToUse = (escape ? escape(input) : input); 178 return inputToUse.getBytes(StandardCharsets.UTF_8); 179 } 180 181 /** 182 * See STOMP Spec 1.2: 183 * <a href="https://stomp.github.io/stomp-specification-1.2.html#Value_Encoding">"Value Encoding"</a>. 184 */ 185 private String escape(String inString) { 186 StringBuilder sb = null; 187 for (int i = 0; i < inString.length(); i++) { 188 char c = inString.charAt(i); 189 if (c == '\\') { 190 sb = getStringBuilder(sb, inString, i); 191 sb.append("\\\\"); 192 } 193 else if (c == ':') { 194 sb = getStringBuilder(sb, inString, i); 195 sb.append("\\c"); 196 } 197 else if (c == '\n') { 198 sb = getStringBuilder(sb, inString, i); 199 sb.append("\\n"); 200 } 201 else if (c == '\r') { 202 sb = getStringBuilder(sb, inString, i); 203 sb.append("\\r"); 204 } 205 else if (sb != null){ 206 sb.append(c); 207 } 208 } 209 return (sb != null ? sb.toString() : inString); 210 } 211 212 private StringBuilder getStringBuilder(@Nullable StringBuilder sb, String inString, int i) { 213 if (sb == null) { 214 sb = new StringBuilder(inString.length()); 215 sb.append(inString, 0, i); 216 } 217 return sb; 218 } 219 220 221 /** 222 * Accumulates byte content and returns an aggregated byte[] at the end. 223 */ 224 private interface Result { 225 226 void add(byte[] bytes); 227 228 void add(byte b); 229 230 byte[] toByteArray(); 231 } 232 233 234 @SuppressWarnings("serial") 235 private static class DefaultResult extends ArrayList<Object> implements Result { 236 237 private int size; 238 239 public void add(byte[] bytes) { 240 this.size += bytes.length; 241 super.add(bytes); 242 } 243 244 public void add(byte b) { 245 this.size++; 246 super.add(b); 247 } 248 249 public byte[] toByteArray() { 250 byte[] result = new byte[this.size]; 251 int position = 0; 252 for (Object o : this) { 253 if (o instanceof byte[]) { 254 byte[] src = (byte[]) o; 255 System.arraycopy(src, 0, result, position, src.length); 256 position += src.length; 257 } 258 else { 259 result[position++] = (Byte) o; 260 } 261 } 262 return result; 263 } 264 } 265 266}