001/*
002 * Copyright 2002-2018 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.web.socket.messaging;
018
019import java.security.Principal;
020import java.util.HashSet;
021import java.util.Map;
022import java.util.Set;
023import java.util.concurrent.ConcurrentHashMap;
024
025import org.springframework.context.ApplicationEvent;
026import org.springframework.context.event.SmartApplicationListener;
027import org.springframework.core.Ordered;
028import org.springframework.lang.Nullable;
029import org.springframework.messaging.Message;
030import org.springframework.messaging.MessageHeaders;
031import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
032import org.springframework.messaging.simp.user.DestinationUserNameProvider;
033import org.springframework.messaging.simp.user.SimpSession;
034import org.springframework.messaging.simp.user.SimpSubscription;
035import org.springframework.messaging.simp.user.SimpSubscriptionMatcher;
036import org.springframework.messaging.simp.user.SimpUser;
037import org.springframework.messaging.simp.user.SimpUserRegistry;
038import org.springframework.util.Assert;
039
040/**
041 * A default implementation of {@link SimpUserRegistry} that relies on
042 * {@link AbstractSubProtocolEvent} application context events to keep
043 * track of connected users and their subscriptions.
044 *
045 * @author Rossen Stoyanchev
046 * @since 4.2
047 */
048public class DefaultSimpUserRegistry implements SimpUserRegistry, SmartApplicationListener {
049
050        private int order = Ordered.LOWEST_PRECEDENCE;
051
052        /* Primary lookup that holds all users and their sessions */
053        private final Map<String, LocalSimpUser> users = new ConcurrentHashMap<>();
054
055        /* Secondary lookup across all sessions by id */
056        private final Map<String, LocalSimpSession> sessions = new ConcurrentHashMap<>();
057
058        private final Object sessionLock = new Object();
059
060
061        /**
062         * Specify the order value for this registry.
063         * <p>Default is {@link Ordered#LOWEST_PRECEDENCE}.
064         * @since 5.0.8
065         */
066        public void setOrder(int order) {
067                this.order = order;
068        }
069
070        @Override
071        public int getOrder() {
072                return this.order;
073        }
074
075
076        // SmartApplicationListener methods
077
078        @Override
079        public boolean supportsEventType(Class<? extends ApplicationEvent> eventType) {
080                return AbstractSubProtocolEvent.class.isAssignableFrom(eventType);
081        }
082
083        @Override
084        public void onApplicationEvent(ApplicationEvent event) {
085                AbstractSubProtocolEvent subProtocolEvent = (AbstractSubProtocolEvent) event;
086                Message<?> message = subProtocolEvent.getMessage();
087                MessageHeaders headers = message.getHeaders();
088
089                String sessionId = SimpMessageHeaderAccessor.getSessionId(headers);
090                Assert.state(sessionId != null, "No session id");
091
092                if (event instanceof SessionSubscribeEvent) {
093                        LocalSimpSession session = this.sessions.get(sessionId);
094                        if (session != null) {
095                                String id = SimpMessageHeaderAccessor.getSubscriptionId(headers);
096                                String destination = SimpMessageHeaderAccessor.getDestination(headers);
097                                if (id != null && destination != null) {
098                                        session.addSubscription(id, destination);
099                                }
100                        }
101                }
102                else if (event instanceof SessionConnectedEvent) {
103                        Principal user = subProtocolEvent.getUser();
104                        if (user == null) {
105                                return;
106                        }
107                        String name = user.getName();
108                        if (user instanceof DestinationUserNameProvider) {
109                                name = ((DestinationUserNameProvider) user).getDestinationUserName();
110                        }
111                        synchronized (this.sessionLock) {
112                                LocalSimpUser simpUser = this.users.get(name);
113                                if (simpUser == null) {
114                                        simpUser = new LocalSimpUser(name);
115                                        this.users.put(name, simpUser);
116                                }
117                                LocalSimpSession session = new LocalSimpSession(sessionId, simpUser);
118                                simpUser.addSession(session);
119                                this.sessions.put(sessionId, session);
120                        }
121                }
122                else if (event instanceof SessionDisconnectEvent) {
123                        synchronized (this.sessionLock) {
124                                LocalSimpSession session = this.sessions.remove(sessionId);
125                                if (session != null) {
126                                        LocalSimpUser user = session.getUser();
127                                        user.removeSession(sessionId);
128                                        if (!user.hasSessions()) {
129                                                this.users.remove(user.getName());
130                                        }
131                                }
132                        }
133                }
134                else if (event instanceof SessionUnsubscribeEvent) {
135                        LocalSimpSession session = this.sessions.get(sessionId);
136                        if (session != null) {
137                                String subscriptionId = SimpMessageHeaderAccessor.getSubscriptionId(headers);
138                                if (subscriptionId != null) {
139                                        session.removeSubscription(subscriptionId);
140                                }
141                        }
142                }
143        }
144
145        @Override
146        public boolean supportsSourceType(@Nullable Class<?> sourceType) {
147                return true;
148        }
149
150
151        // SimpUserRegistry methods
152
153        @Override
154        @Nullable
155        public SimpUser getUser(String userName) {
156                return this.users.get(userName);
157        }
158
159        @Override
160        public Set<SimpUser> getUsers() {
161                return new HashSet<>(this.users.values());
162        }
163
164        @Override
165        public int getUserCount() {
166                return this.users.size();
167        }
168
169        @Override
170        public Set<SimpSubscription> findSubscriptions(SimpSubscriptionMatcher matcher) {
171                Set<SimpSubscription> result = new HashSet<>();
172                for (LocalSimpSession session : this.sessions.values()) {
173                        for (SimpSubscription subscription : session.subscriptions.values()) {
174                                if (matcher.match(subscription)) {
175                                        result.add(subscription);
176                                }
177                        }
178                }
179                return result;
180        }
181
182
183        @Override
184        public String toString() {
185                return "users=" + this.users;
186        }
187
188
189        private static class LocalSimpUser implements SimpUser {
190
191                private final String name;
192
193                private final Map<String, SimpSession> userSessions = new ConcurrentHashMap<>(1);
194
195                public LocalSimpUser(String userName) {
196                        Assert.notNull(userName, "User name must not be null");
197                        this.name = userName;
198                }
199
200                @Override
201                public String getName() {
202                        return this.name;
203                }
204
205                @Override
206                public boolean hasSessions() {
207                        return !this.userSessions.isEmpty();
208                }
209
210                @Override
211                @Nullable
212                public SimpSession getSession(@Nullable String sessionId) {
213                        return (sessionId != null ? this.userSessions.get(sessionId) : null);
214                }
215
216                @Override
217                public Set<SimpSession> getSessions() {
218                        return new HashSet<>(this.userSessions.values());
219                }
220
221                void addSession(SimpSession session) {
222                        this.userSessions.put(session.getId(), session);
223                }
224
225                void removeSession(String sessionId) {
226                        this.userSessions.remove(sessionId);
227                }
228
229                @Override
230                public boolean equals(@Nullable Object other) {
231                        return (this == other ||
232                                        (other instanceof SimpUser && getName().equals(((SimpUser) other).getName())));
233                }
234
235                @Override
236                public int hashCode() {
237                        return getName().hashCode();
238                }
239
240                @Override
241                public String toString() {
242                        return "name=" + getName() + ", sessions=" + this.userSessions;
243                }
244        }
245
246
247        private static class LocalSimpSession implements SimpSession {
248
249                private final String id;
250
251                private final LocalSimpUser user;
252
253                private final Map<String, SimpSubscription> subscriptions = new ConcurrentHashMap<>(4);
254
255                public LocalSimpSession(String id, LocalSimpUser user) {
256                        Assert.notNull(id, "Id must not be null");
257                        Assert.notNull(user, "User must not be null");
258                        this.id = id;
259                        this.user = user;
260                }
261
262                @Override
263                public String getId() {
264                        return this.id;
265                }
266
267                @Override
268                public LocalSimpUser getUser() {
269                        return this.user;
270                }
271
272                @Override
273                public Set<SimpSubscription> getSubscriptions() {
274                        return new HashSet<>(this.subscriptions.values());
275                }
276
277                void addSubscription(String id, String destination) {
278                        this.subscriptions.put(id, new LocalSimpSubscription(id, destination, this));
279                }
280
281                void removeSubscription(String id) {
282                        this.subscriptions.remove(id);
283                }
284
285                @Override
286                public boolean equals(@Nullable Object other) {
287                        return (this == other ||
288                                        (other instanceof SimpSubscription && getId().equals(((SimpSubscription) other).getId())));
289                }
290
291                @Override
292                public int hashCode() {
293                        return getId().hashCode();
294                }
295
296                @Override
297                public String toString() {
298                        return "id=" + getId() + ", subscriptions=" + this.subscriptions;
299                }
300        }
301
302
303        private static class LocalSimpSubscription implements SimpSubscription {
304
305                private final String id;
306
307                private final LocalSimpSession session;
308
309                private final String destination;
310
311                public LocalSimpSubscription(String id, String destination, LocalSimpSession session) {
312                        Assert.notNull(id, "Id must not be null");
313                        Assert.hasText(destination, "Destination must not be empty");
314                        Assert.notNull(session, "Session must not be null");
315                        this.id = id;
316                        this.destination = destination;
317                        this.session = session;
318                }
319
320                @Override
321                public String getId() {
322                        return this.id;
323                }
324
325                @Override
326
327                        return this.session;
328                }
329
330                @Override
331                public String getDestination() {
332                        return this.destination;
333                }
334
335                @Override
336                public boolean equals(@Nullable Object other) {
337                        if (this == other) {
338                                return true;
339                        }
340                        if (!(other instanceof SimpSubscription)) {
341                                return false;
342                        }
343                        SimpSubscription otherSubscription = (SimpSubscription) other;
344                        return (getId().equals(otherSubscription.getId()) &&
345                                        getSession().getId().equals(otherSubscription.getSession().getId()));
346                }
347
348                @Override
349                public int hashCode() {
350                        return getId().hashCode() * 31 + getSession().getId().hashCode();
351                }
352
353                @Override
354                public String toString() {
355                        return "destination=" + this.destination;
356                }
357        }
358
359}