001/*
002 * Copyright 2002-2019 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.messaging.simp;
018
019import java.security.Principal;
020import java.util.List;
021import java.util.Map;
022import java.util.function.Consumer;
023
024import org.springframework.lang.Nullable;
025import org.springframework.messaging.Message;
026import org.springframework.messaging.support.IdTimestampMessageHeaderInitializer;
027import org.springframework.messaging.support.MessageHeaderAccessor;
028import org.springframework.messaging.support.NativeMessageHeaderAccessor;
029import org.springframework.util.Assert;
030import org.springframework.util.CollectionUtils;
031
032/**
 * A base class for working with message headers in simple messaging protocols that
 * support basic messaging patterns. Provides uniform access to specific values common
 * across protocols such as a destination, message type (e.g. publish, subscribe, etc.),
 * session id, and others.
037 *
 * <p>Use one of the static factory methods in this class, then call getters and setters,
 * and at the end, if necessary, call {@link #toMap()} to obtain the updated headers.
040 *
041 * @author Rossen Stoyanchev
042 * @since 4.0
043 */
044public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor {
045
046        private static final IdTimestampMessageHeaderInitializer headerInitializer;
047
048        static {
049                headerInitializer = new IdTimestampMessageHeaderInitializer();
050                headerInitializer.setDisableIdGeneration();
051                headerInitializer.setEnableTimestamp(false);
052        }
053
054        // SiMP header names
055
056        public static final String DESTINATION_HEADER = "simpDestination";
057
058        public static final String MESSAGE_TYPE_HEADER = "simpMessageType";
059
060        public static final String SESSION_ID_HEADER = "simpSessionId";
061
062        public static final String SESSION_ATTRIBUTES = "simpSessionAttributes";
063
064        public static final String SUBSCRIPTION_ID_HEADER = "simpSubscriptionId";
065
066        public static final String USER_HEADER = "simpUser";
067
068        public static final String CONNECT_MESSAGE_HEADER = "simpConnectMessage";
069
070        public static final String DISCONNECT_MESSAGE_HEADER = "simpDisconnectMessage";
071
072        public static final String HEART_BEAT_HEADER = "simpHeartbeat";
073
074
075        /**
076         * A header for internal use with "user" destinations where we need to
077         * restore the destination prior to sending messages to clients.
078         */
079        public static final String ORIGINAL_DESTINATION = "simpOrigDestination";
080
081        /**
082         * A header that indicates to the broker that the sender will ignore errors.
083         * The header is simply checked for presence or absence.
084         */
085        public static final String IGNORE_ERROR = "simpIgnoreError";
086
087
088        @Nullable
089        private Consumer<Principal> userCallback;
090
091
092        /**
093         * A constructor for creating new message headers.
094         * This constructor is protected. See factory methods in this and sub-classes.
095         */
096        protected SimpMessageHeaderAccessor(SimpMessageType messageType,
097                        @Nullable Map<String, List<String>> externalSourceHeaders) {
098
099                super(externalSourceHeaders);
100                Assert.notNull(messageType, "MessageType must not be null");
101                setHeader(MESSAGE_TYPE_HEADER, messageType);
102                headerInitializer.initHeaders(this);
103        }
104
105        /**
106         * A constructor for accessing and modifying existing message headers. This
107         * constructor is protected. See factory methods in this and sub-classes.
108         */
109        protected SimpMessageHeaderAccessor(Message<?> message) {
110                super(message);
111                headerInitializer.initHeaders(this);
112        }
113
114
115        @Override
116        protected MessageHeaderAccessor createAccessor(Message<?> message) {
117                return wrap(message);
118        }
119
120        public void setMessageTypeIfNotSet(SimpMessageType messageType) {
121                if (getMessageType() == null) {
122                        setHeader(MESSAGE_TYPE_HEADER, messageType);
123                }
124        }
125
126        @Nullable
127        public SimpMessageType getMessageType() {
128                return (SimpMessageType) getHeader(MESSAGE_TYPE_HEADER);
129        }
130
131        public void setDestination(@Nullable String destination) {
132                setHeader(DESTINATION_HEADER, destination);
133        }
134
135        @Nullable
136        public String getDestination() {
137                return (String) getHeader(DESTINATION_HEADER);
138        }
139
140        public void setSubscriptionId(@Nullable String subscriptionId) {
141                setHeader(SUBSCRIPTION_ID_HEADER, subscriptionId);
142        }
143
144        @Nullable
145        public String getSubscriptionId() {
146                return (String) getHeader(SUBSCRIPTION_ID_HEADER);
147        }
148
149        public void setSessionId(@Nullable String sessionId) {
150                setHeader(SESSION_ID_HEADER, sessionId);
151        }
152
153        /**
154         * Return the id of the current session.
155         */
156        @Nullable
157        public String getSessionId() {
158                return (String) getHeader(SESSION_ID_HEADER);
159        }
160
161        /**
162         * A static alternative for access to the session attributes header.
163         */
164        public void setSessionAttributes(@Nullable Map<String, Object> attributes) {
165                setHeader(SESSION_ATTRIBUTES, attributes);
166        }
167
168        /**
169         * Return the attributes associated with the current session.
170         */
171        @SuppressWarnings("unchecked")
172        @Nullable
173        public Map<String, Object> getSessionAttributes() {
174                return (Map<String, Object>) getHeader(SESSION_ATTRIBUTES);
175        }
176
177        public void setUser(@Nullable Principal principal) {
178                setHeader(USER_HEADER, principal);
179                if (this.userCallback != null) {
180                        this.userCallback.accept(principal);
181                }
182        }
183
184        /**
185         * Return the user associated with the current session.
186         */
187        @Nullable
188        public Principal getUser() {
189                return (Principal) getHeader(USER_HEADER);
190        }
191
192        /**
193         * Provide a callback to be invoked if and when {@link #setUser(Principal)}
194         * is called. This is used internally on the inbound channel to detect
195         * token-based authentications through an interceptor.
196         * @param callback the callback to invoke
197         * @since 5.1.9
198         */
199        public void setUserChangeCallback(Consumer<Principal> callback) {
200                Assert.notNull(callback, "'callback' is required");
201                this.userCallback = this.userCallback != null ? this.userCallback.andThen(callback) : callback;
202        }
203
204        @Override
205        public String getShortLogMessage(Object payload) {
206                if (getMessageType() == null) {
207                        return super.getDetailedLogMessage(payload);
208                }
209                StringBuilder sb = getBaseLogMessage();
210                if (!CollectionUtils.isEmpty(getSessionAttributes())) {
211                        sb.append(" attributes[").append(getSessionAttributes().size()).append("]");
212                }
213                sb.append(getShortPayloadLogMessage(payload));
214                return sb.toString();
215        }
216
217        @SuppressWarnings("unchecked")
218        @Override
219        public String getDetailedLogMessage(@Nullable Object payload) {
220                if (getMessageType() == null) {
221                        return super.getDetailedLogMessage(payload);
222                }
223                StringBuilder sb = getBaseLogMessage();
224                if (!CollectionUtils.isEmpty(getSessionAttributes())) {
225                        sb.append(" attributes=").append(getSessionAttributes());
226                }
227                if (!CollectionUtils.isEmpty((Map<String, List<String>>) getHeader(NATIVE_HEADERS))) {
228                        sb.append(" nativeHeaders=").append(getHeader(NATIVE_HEADERS));
229                }
230                sb.append(getDetailedPayloadLogMessage(payload));
231                return sb.toString();
232        }
233
234        private StringBuilder getBaseLogMessage() {
235                StringBuilder sb = new StringBuilder();
236                SimpMessageType messageType = getMessageType();
237                sb.append(messageType != null ? messageType.name() : SimpMessageType.OTHER);
238                String destination = getDestination();
239                if (destination != null) {
240                        sb.append(" destination=").append(destination);
241                }
242                String subscriptionId = getSubscriptionId();
243                if (subscriptionId != null) {
244                        sb.append(" subscriptionId=").append(subscriptionId);
245                }
246                sb.append(" session=").append(getSessionId());
247                Principal user = getUser();
248                if (user != null) {
249                        sb.append(" user=").append(user.getName());
250                }
251                return sb;
252        }
253
254
255        // Static factory methods and accessors
256
257        /**
258         * Create an instance with
259         * {@link org.springframework.messaging.simp.SimpMessageType} {@code MESSAGE}.
260         */
261        public static SimpMessageHeaderAccessor create() {
262                return new SimpMessageHeaderAccessor(SimpMessageType.MESSAGE, null);
263        }
264
265        /**
266         * Create an instance with the given
267         * {@link org.springframework.messaging.simp.SimpMessageType}.
268         */
269        public static SimpMessageHeaderAccessor create(SimpMessageType messageType) {
270                return new SimpMessageHeaderAccessor(messageType, null);
271        }
272
273        /**
274         * Create an instance from the payload and headers of the given Message.
275         */
276        public static SimpMessageHeaderAccessor wrap(Message<?> message) {
277                return new SimpMessageHeaderAccessor(message);
278        }
279
280        @Nullable
281        public static SimpMessageType getMessageType(Map<String, Object> headers) {
282                return (SimpMessageType) headers.get(MESSAGE_TYPE_HEADER);
283        }
284
285        @Nullable
286        public static String getDestination(Map<String, Object> headers) {
287                return (String) headers.get(DESTINATION_HEADER);
288        }
289
290        @Nullable
291        public static String getSubscriptionId(Map<String, Object> headers) {
292                return (String) headers.get(SUBSCRIPTION_ID_HEADER);
293        }
294
295        @Nullable
296        public static String getSessionId(Map<String, Object> headers) {
297                return (String) headers.get(SESSION_ID_HEADER);
298        }
299
300        @SuppressWarnings("unchecked")
301        @Nullable
302        public static Map<String, Object> getSessionAttributes(Map<String, Object> headers) {
303                return (Map<String, Object>) headers.get(SESSION_ATTRIBUTES);
304        }
305
306        @Nullable
307        public static Principal getUser(Map<String, Object> headers) {
308                return (Principal) headers.get(USER_HEADER);
309        }
310
311        @Nullable
312        public static long[] getHeartbeat(Map<String, Object> headers) {
313                return (long[]) headers.get(HEART_BEAT_HEADER);
314        }
315
316}