001/*
002 * Copyright 2002-2017 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.messaging.simp.stomp;
018
019import java.nio.charset.Charset;
020import java.security.Principal;
021import java.util.Arrays;
022import java.util.Collections;
023import java.util.List;
024import java.util.Map;
025import java.util.Set;
026import java.util.concurrent.atomic.AtomicLong;
027
028import org.springframework.messaging.Message;
029import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
030import org.springframework.messaging.simp.SimpMessageType;
031import org.springframework.messaging.support.MessageHeaderAccessor;
032import org.springframework.util.ClassUtils;
033import org.springframework.util.CollectionUtils;
034import org.springframework.util.MimeType;
035import org.springframework.util.MimeTypeUtils;
036import org.springframework.util.StringUtils;
037
038/**
039 * A {@code MessageHeaderAccessor} to use when creating a {@code Message} from
040 * a decoded STOMP frame, or when encoding a {@code Message} to a STOMP frame.
041 *
042 * <p>When created from STOMP frame content, the actual STOMP headers are
043 * stored in the native header sub-map managed by the parent class
044 * {@link org.springframework.messaging.support.NativeMessageHeaderAccessor}
045 * while the parent class {@link SimpMessageHeaderAccessor} manages common
046 * processing headers some of which are based on STOMP headers
047 * (e.g. destination, content-type, etc).
048 *
049 * <p>An instance of this class can also be created by wrapping an existing
050 * {@code Message}. That message may have been created with the more generic
051 * {@link org.springframework.messaging.simp.SimpMessageHeaderAccessor} in
052 * which case STOMP headers are created from common processing headers.
053 * In this case it is also necessary to invoke either
054 * {@link #updateStompCommandAsClientMessage()} or
055 * {@link #updateStompCommandAsServerMessage()} if sending a message and
056 * depending on whether a message is sent to a client or the message broker.
057 *
058 * @author Rossen Stoyanchev
059 * @since 4.0
060 */
061public class StompHeaderAccessor extends SimpMessageHeaderAccessor {
062
063        private static final AtomicLong messageIdCounter = new AtomicLong();
064
065        private static final long[] DEFAULT_HEARTBEAT = new long[] {0, 0};
066
067
068        // STOMP header names
069
070        public static final String STOMP_ID_HEADER = "id";
071
072        public static final String STOMP_HOST_HEADER = "host";
073
074        public static final String STOMP_ACCEPT_VERSION_HEADER = "accept-version";
075
076        public static final String STOMP_MESSAGE_ID_HEADER = "message-id";
077
078        public static final String STOMP_RECEIPT_HEADER = "receipt"; // any client frame except CONNECT
079
080        public static final String STOMP_RECEIPT_ID_HEADER = "receipt-id"; // RECEIPT frame
081
082        public static final String STOMP_SUBSCRIPTION_HEADER = "subscription";
083
084        public static final String STOMP_VERSION_HEADER = "version";
085
086        public static final String STOMP_MESSAGE_HEADER = "message";
087
088        public static final String STOMP_ACK_HEADER = "ack";
089
090        public static final String STOMP_NACK_HEADER = "nack";
091
092        public static final String STOMP_LOGIN_HEADER = "login";
093
094        public static final String STOMP_PASSCODE_HEADER = "passcode";
095
096        public static final String STOMP_DESTINATION_HEADER = "destination";
097
098        public static final String STOMP_CONTENT_TYPE_HEADER = "content-type";
099
100        public static final String STOMP_CONTENT_LENGTH_HEADER = "content-length";
101
102        public static final String STOMP_HEARTBEAT_HEADER = "heart-beat";
103
104        // Other header names
105
106        private static final String COMMAND_HEADER = "stompCommand";
107
108        private static final String CREDENTIALS_HEADER = "stompCredentials";
109
110
111        /**
112         * A constructor for creating message headers from a parsed STOMP frame.
113         */
114        StompHeaderAccessor(StompCommand command, Map<String, List<String>> externalSourceHeaders) {
115                super(command.getMessageType(), externalSourceHeaders);
116                setHeader(COMMAND_HEADER, command);
117                updateSimpMessageHeadersFromStompHeaders();
118        }
119
120        /**
121         * A constructor for accessing and modifying existing message headers.
122         * Note that the message headers may not have been created from a STOMP frame
123         * but may have rather originated from using the more generic
124         * {@link org.springframework.messaging.simp.SimpMessageHeaderAccessor}.
125         */
126        StompHeaderAccessor(Message<?> message) {
127                super(message);
128                updateStompHeadersFromSimpMessageHeaders();
129        }
130
131        StompHeaderAccessor() {
132                super(SimpMessageType.HEARTBEAT, null);
133        }
134
135
136        void updateSimpMessageHeadersFromStompHeaders() {
137                if (getNativeHeaders() == null) {
138                        return;
139                }
140                String value = getFirstNativeHeader(STOMP_DESTINATION_HEADER);
141                if (value != null) {
142                        super.setDestination(value);
143                }
144                value = getFirstNativeHeader(STOMP_CONTENT_TYPE_HEADER);
145                if (value != null) {
146                        super.setContentType(MimeTypeUtils.parseMimeType(value));
147                }
148                StompCommand command = getCommand();
149                if (StompCommand.MESSAGE.equals(command)) {
150                        value = getFirstNativeHeader(STOMP_SUBSCRIPTION_HEADER);
151                        if (value != null) {
152                                super.setSubscriptionId(value);
153                        }
154                }
155                else if (StompCommand.SUBSCRIBE.equals(command) || StompCommand.UNSUBSCRIBE.equals(command)) {
156                        value = getFirstNativeHeader(STOMP_ID_HEADER);
157                        if (value != null) {
158                                super.setSubscriptionId(value);
159                        }
160                }
161                else if (StompCommand.CONNECT.equals(command)) {
162                        protectPasscode();
163                }
164        }
165
166        void updateStompHeadersFromSimpMessageHeaders() {
167                String destination = getDestination();
168                if (destination != null) {
169                        setNativeHeader(STOMP_DESTINATION_HEADER, destination);
170                }
171                MimeType contentType = getContentType();
172                if (contentType != null) {
173                        setNativeHeader(STOMP_CONTENT_TYPE_HEADER, contentType.toString());
174                }
175                trySetStompHeaderForSubscriptionId();
176        }
177
178
179        @Override
180        protected MessageHeaderAccessor createAccessor(Message<?> message) {
181                return wrap(message);
182        }
183
184        Map<String, List<String>> getNativeHeaders() {
185                @SuppressWarnings("unchecked")
186                Map<String, List<String>> map = (Map<String, List<String>>) getHeader(NATIVE_HEADERS);
187                return (map != null ? map : Collections.<String, List<String>>emptyMap());
188        }
189
190        public StompCommand updateStompCommandAsClientMessage() {
191                SimpMessageType messageType = getMessageType();
192                if (messageType != SimpMessageType.MESSAGE) {
193                        throw new IllegalStateException("Unexpected message type " + messageType);
194                }
195                StompCommand command = getCommand();
196                if (command == null) {
197                        command = StompCommand.SEND;
198                        setHeader(COMMAND_HEADER, command);
199                }
200                else if (!command.equals(StompCommand.SEND)) {
201                        throw new IllegalStateException("Unexpected STOMP command " + command);
202                }
203                return command;
204        }
205
206        public void updateStompCommandAsServerMessage() {
207                SimpMessageType messageType = getMessageType();
208                if (messageType != SimpMessageType.MESSAGE) {
209                        throw new IllegalStateException("Unexpected message type " + messageType);
210                }
211                StompCommand command = getCommand();
212                if ((command == null) || StompCommand.SEND.equals(command)) {
213                        setHeader(COMMAND_HEADER, StompCommand.MESSAGE);
214                }
215                else if (!StompCommand.MESSAGE.equals(command)) {
216                        throw new IllegalStateException("Unexpected STOMP command " + command);
217                }
218                trySetStompHeaderForSubscriptionId();
219                if (getMessageId() == null) {
220                        String messageId = getSessionId() + '-' + messageIdCounter.getAndIncrement();
221                        setNativeHeader(STOMP_MESSAGE_ID_HEADER, messageId);
222                }
223        }
224
225        /**
226         * Return the STOMP command, or {@code null} if not yet set.
227         */
228        public StompCommand getCommand() {
229                return (StompCommand) getHeader(COMMAND_HEADER);
230        }
231
232        public boolean isHeartbeat() {
233                return (SimpMessageType.HEARTBEAT == getMessageType());
234        }
235
236        public long[] getHeartbeat() {
237                String rawValue = getFirstNativeHeader(STOMP_HEARTBEAT_HEADER);
238                String[] rawValues = StringUtils.split(rawValue, ",");
239                if (rawValues == null) {
240                        return Arrays.copyOf(DEFAULT_HEARTBEAT, 2);
241                }
242                return new long[] {Long.valueOf(rawValues[0]), Long.valueOf(rawValues[1])};
243        }
244
245        public void setAcceptVersion(String acceptVersion) {
246                setNativeHeader(STOMP_ACCEPT_VERSION_HEADER, acceptVersion);
247        }
248
249        public Set<String> getAcceptVersion() {
250                String rawValue = getFirstNativeHeader(STOMP_ACCEPT_VERSION_HEADER);
251                return (rawValue != null ? StringUtils.commaDelimitedListToSet(rawValue) : Collections.<String>emptySet());
252        }
253
254        public void setHost(String host) {
255                setNativeHeader(STOMP_HOST_HEADER, host);
256        }
257
258        public String getHost() {
259                return getFirstNativeHeader(STOMP_HOST_HEADER);
260        }
261
262        @Override
263        public void setDestination(String destination) {
264                super.setDestination(destination);
265                setNativeHeader(STOMP_DESTINATION_HEADER, destination);
266        }
267
268        @Override
269        public void setContentType(MimeType contentType) {
270                super.setContentType(contentType);
271                setNativeHeader(STOMP_CONTENT_TYPE_HEADER, contentType.toString());
272        }
273
274        @Override
275        public void setSubscriptionId(String subscriptionId) {
276                super.setSubscriptionId(subscriptionId);
277                trySetStompHeaderForSubscriptionId();
278        }
279
280        private void trySetStompHeaderForSubscriptionId() {
281                String subscriptionId = getSubscriptionId();
282                if (subscriptionId != null) {
283                        StompCommand command = getCommand();
284                        if (command != null && StompCommand.MESSAGE.equals(command)) {
285                                setNativeHeader(STOMP_SUBSCRIPTION_HEADER, subscriptionId);
286                        }
287                        else {
288                                SimpMessageType messageType = getMessageType();
289                                if (SimpMessageType.SUBSCRIBE.equals(messageType) || SimpMessageType.UNSUBSCRIBE.equals(messageType)) {
290                                        setNativeHeader(STOMP_ID_HEADER, subscriptionId);
291                                }
292                        }
293                }
294        }
295
296        public Integer getContentLength() {
297                String header = getFirstNativeHeader(STOMP_CONTENT_LENGTH_HEADER);
298                return (header != null ? Integer.valueOf(header) : null);
299        }
300
301        public void setContentLength(int contentLength) {
302                setNativeHeader(STOMP_CONTENT_LENGTH_HEADER, String.valueOf(contentLength));
303        }
304
305        public void setHeartbeat(long cx, long cy) {
306                setNativeHeader(STOMP_HEARTBEAT_HEADER, cx + "," + cy);
307        }
308
309        public void setAck(String ack) {
310                setNativeHeader(STOMP_ACK_HEADER, ack);
311        }
312
313        public String getAck() {
314                return getFirstNativeHeader(STOMP_ACK_HEADER);
315        }
316
317        public void setNack(String nack) {
318                setNativeHeader(STOMP_NACK_HEADER, nack);
319        }
320
321        public String getNack() {
322                return getFirstNativeHeader(STOMP_NACK_HEADER);
323        }
324
325        public void setLogin(String login) {
326                setNativeHeader(STOMP_LOGIN_HEADER, login);
327        }
328
329        public String getLogin() {
330                return getFirstNativeHeader(STOMP_LOGIN_HEADER);
331        }
332
333        public void setPasscode(String passcode) {
334                setNativeHeader(STOMP_PASSCODE_HEADER, passcode);
335                protectPasscode();
336        }
337
338        private void protectPasscode() {
339                String value = getFirstNativeHeader(STOMP_PASSCODE_HEADER);
340                if (value != null && !"PROTECTED".equals(value)) {
341                        setHeader(CREDENTIALS_HEADER, new StompPasscode(value));
342                        setNativeHeader(STOMP_PASSCODE_HEADER, "PROTECTED");
343                }
344        }
345
346        /**
347         * Return the passcode header value, or {@code null} if not set.
348         */
349        public String getPasscode() {
350                StompPasscode credentials = (StompPasscode) getHeader(CREDENTIALS_HEADER);
351                return (credentials != null ? credentials.passcode : null);
352        }
353
354        public void setReceiptId(String receiptId) {
355                setNativeHeader(STOMP_RECEIPT_ID_HEADER, receiptId);
356        }
357
358        public String getReceiptId() {
359                return getFirstNativeHeader(STOMP_RECEIPT_ID_HEADER);
360        }
361
362        public void setReceipt(String receiptId) {
363                setNativeHeader(STOMP_RECEIPT_HEADER, receiptId);
364        }
365
366        public String getReceipt() {
367                return getFirstNativeHeader(STOMP_RECEIPT_HEADER);
368        }
369
370        public String getMessage() {
371                return getFirstNativeHeader(STOMP_MESSAGE_HEADER);
372        }
373
374        public void setMessage(String content) {
375                setNativeHeader(STOMP_MESSAGE_HEADER, content);
376        }
377
378        public String getMessageId() {
379                return getFirstNativeHeader(STOMP_MESSAGE_ID_HEADER);
380        }
381
382        public void setMessageId(String id) {
383                setNativeHeader(STOMP_MESSAGE_ID_HEADER, id);
384        }
385
386        public String getVersion() {
387                return getFirstNativeHeader(STOMP_VERSION_HEADER);
388        }
389
390        public void setVersion(String version) {
391                setNativeHeader(STOMP_VERSION_HEADER, version);
392        }
393
394
395        // Logging related
396
397        @Override
398        public String getShortLogMessage(Object payload) {
399                StompCommand command = getCommand();
400                if (StompCommand.SUBSCRIBE.equals(command)) {
401                        return "SUBSCRIBE " + getDestination() + " id=" + getSubscriptionId() + appendSession();
402                }
403                else if (StompCommand.UNSUBSCRIBE.equals(command)) {
404                        return "UNSUBSCRIBE id=" + getSubscriptionId() + appendSession();
405                }
406                else if (StompCommand.SEND.equals(command)) {
407                        return "SEND " + getDestination() + appendSession() + appendPayload(payload);
408                }
409                else if (StompCommand.CONNECT.equals(command)) {
410                        Principal user = getUser();
411                        return "CONNECT" + (user != null ? " user=" + user.getName() : "") + appendSession();
412                }
413                else if (StompCommand.CONNECTED.equals(command)) {
414                        return "CONNECTED heart-beat=" + Arrays.toString(getHeartbeat()) + appendSession();
415                }
416                else if (StompCommand.DISCONNECT.equals(command)) {
417                        String receipt = getReceipt();
418                        return "DISCONNECT" + (receipt != null ? " receipt=" + receipt : "") + appendSession();
419                }
420                else {
421                        return getDetailedLogMessage(payload);
422                }
423        }
424
425        @Override
426        public String getDetailedLogMessage(Object payload) {
427                if (isHeartbeat()) {
428                        String sessionId = getSessionId();
429                        return "heart-beat" + (sessionId != null ? " in session " + sessionId : "");
430                }
431                StompCommand command = getCommand();
432                if (command == null) {
433                        return super.getDetailedLogMessage(payload);
434                }
435                StringBuilder sb = new StringBuilder();
436                sb.append(command.name()).append(" ").append(getNativeHeaders()).append(appendSession());
437                if (getUser() != null) {
438                        sb.append(", user=").append(getUser().getName());
439                }
440                if (command.isBodyAllowed()) {
441                        sb.append(appendPayload(payload));
442                }
443                return sb.toString();
444        }
445
446        private String appendSession() {
447                return " session=" + getSessionId();
448        }
449
450        private String appendPayload(Object payload) {
451                if (payload.getClass() != byte[].class) {
452                        throw new IllegalStateException(
453                                        "Expected byte array payload but got: " + ClassUtils.getQualifiedName(payload.getClass()));
454                }
455                byte[] bytes = (byte[]) payload;
456                MimeType mimeType = getContentType();
457                String contentType = (mimeType != null ? " " + mimeType.toString() : "");
458                if (bytes.length == 0 || mimeType == null || !isReadableContentType()) {
459                        return contentType;
460                }
461                Charset charset = mimeType.getCharset();
462                charset = (charset != null ? charset : StompDecoder.UTF8_CHARSET);
463                return (bytes.length < 80) ?
464                                contentType + " payload=" + new String(bytes, charset) :
465                                contentType + " payload=" + new String(Arrays.copyOf(bytes, 80), charset) + "...(truncated)";
466        }
467
468
469        // Static factory methods and accessors
470
471        /**
472         * Create an instance for the given STOMP command.
473         */
474        public static StompHeaderAccessor create(StompCommand command) {
475                return new StompHeaderAccessor(command, null);
476        }
477
478        /**
479         * Create an instance for the given STOMP command and headers.
480         */
481        public static StompHeaderAccessor create(StompCommand command, Map<String, List<String>> headers) {
482                return new StompHeaderAccessor(command, headers);
483        }
484
485        /**
486         * Create headers for a heartbeat. While a STOMP heartbeat frame does not
487         * have headers, a session id is needed for processing purposes at a minimum.
488         */
489        public static StompHeaderAccessor createForHeartbeat() {
490                return new StompHeaderAccessor();
491        }
492
493        /**
494         * Create an instance from the payload and headers of the given Message.
495         */
496        public static StompHeaderAccessor wrap(Message<?> message) {
497                return new StompHeaderAccessor(message);
498        }
499
500        /**
501         * Return the STOMP command from the given headers, or {@code null} if not set.
502         */
503        public static StompCommand getCommand(Map<String, Object> headers) {
504                return (StompCommand) headers.get(COMMAND_HEADER);
505        }
506
507        /**
508         * Return the passcode header value, or {@code null} if not set.
509         */
510        public static String getPasscode(Map<String, Object> headers) {
511                StompPasscode credentials = (StompPasscode) headers.get(CREDENTIALS_HEADER);
512                return (credentials != null ? credentials.passcode : null);
513        }
514
515        public static Integer getContentLength(Map<String, List<String>> nativeHeaders) {
516                List<String> values = nativeHeaders.get(STOMP_CONTENT_LENGTH_HEADER);
517                return (!CollectionUtils.isEmpty(values) ? Integer.valueOf(values.get(0)) : null);
518        }
519
520
521        private static class StompPasscode {
522
523                private final String passcode;
524
525                public StompPasscode(String passcode) {
526                        this.passcode = passcode;
527                }
528
529                @Override
530                public String toString() {
531                        return "[PROTECTED]";
532                }
533        }
534
535}