001/*
002 * Copyright 2002-2017 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.web.socket.messaging;
018
019import java.security.Principal;
020import java.util.HashSet;
021import java.util.Map;
022import java.util.Set;
023import java.util.concurrent.ConcurrentHashMap;
024
025import org.springframework.context.ApplicationEvent;
026import org.springframework.context.event.SmartApplicationListener;
027import org.springframework.core.Ordered;
028import org.springframework.messaging.Message;
029import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
030import org.springframework.messaging.simp.user.DestinationUserNameProvider;
031import org.springframework.messaging.simp.user.SimpSession;
032import org.springframework.messaging.simp.user.SimpSubscription;
033import org.springframework.messaging.simp.user.SimpSubscriptionMatcher;
034import org.springframework.messaging.simp.user.SimpUser;
035import org.springframework.messaging.simp.user.SimpUserRegistry;
036import org.springframework.messaging.support.MessageHeaderAccessor;
037import org.springframework.util.Assert;
038
039/**
040 * A default implementation of {@link SimpUserRegistry} that relies on
041 * {@link AbstractSubProtocolEvent} application context events to keep track of
042 * connected users and their subscriptions.
043 *
044 * @author Rossen Stoyanchev
045 * @since 4.2
046 */
047public class DefaultSimpUserRegistry implements SimpUserRegistry, SmartApplicationListener {
048
049        /* Primary lookup that holds all users and their sessions */
050        private final Map<String, LocalSimpUser> users = new ConcurrentHashMap<String, LocalSimpUser>();
051
052        /* Secondary lookup across all sessions by id */
053        private final Map<String, LocalSimpSession> sessions = new ConcurrentHashMap<String, LocalSimpSession>();
054
055        private final Object sessionLock = new Object();
056
057
058        @Override
059        public int getOrder() {
060                return Ordered.LOWEST_PRECEDENCE;
061        }
062
063
064        // SmartApplicationListener methods
065
066        @Override
067        public boolean supportsEventType(Class<? extends ApplicationEvent> eventType) {
068                return AbstractSubProtocolEvent.class.isAssignableFrom(eventType);
069        }
070
071        @Override
072        public void onApplicationEvent(ApplicationEvent event) {
073                AbstractSubProtocolEvent subProtocolEvent = (AbstractSubProtocolEvent) event;
074                Message<?> message = subProtocolEvent.getMessage();
075                SimpMessageHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, SimpMessageHeaderAccessor.class);
076                String sessionId = accessor.getSessionId();
077
078                if (event instanceof SessionSubscribeEvent) {
079                        LocalSimpSession session = this.sessions.get(sessionId);
080                        if (session != null) {
081                                String id = accessor.getSubscriptionId();
082                                String destination = accessor.getDestination();
083                                session.addSubscription(id, destination);
084                        }
085                }
086                else if (event instanceof SessionConnectedEvent) {
087                        Principal user = subProtocolEvent.getUser();
088                        if (user == null) {
089                                return;
090                        }
091                        String name = user.getName();
092                        if (user instanceof DestinationUserNameProvider) {
093                                name = ((DestinationUserNameProvider) user).getDestinationUserName();
094                        }
095                        synchronized (this.sessionLock) {
096                                LocalSimpUser simpUser = this.users.get(name);
097                                if (simpUser == null) {
098                                        simpUser = new LocalSimpUser(name);
099                                        this.users.put(name, simpUser);
100                                }
101                                LocalSimpSession session = new LocalSimpSession(sessionId, simpUser);
102                                simpUser.addSession(session);
103                                this.sessions.put(sessionId, session);
104                        }
105                }
106                else if (event instanceof SessionDisconnectEvent) {
107                        synchronized (this.sessionLock) {
108                                LocalSimpSession session = this.sessions.remove(sessionId);
109                                if (session != null) {
110                                        LocalSimpUser user = session.getUser();
111                                        user.removeSession(sessionId);
112                                        if (!user.hasSessions()) {
113                                                this.users.remove(user.getName());
114                                        }
115                                }
116                        }
117                }
118                else if (event instanceof SessionUnsubscribeEvent) {
119                        LocalSimpSession session = this.sessions.get(sessionId);
120                        if (session != null) {
121                                String subscriptionId = accessor.getSubscriptionId();
122                                session.removeSubscription(subscriptionId);
123                        }
124                }
125        }
126
127        @Override
128        public boolean supportsSourceType(Class<?> sourceType) {
129                return true;
130        }
131
132
133        // SimpUserRegistry methods
134
135        @Override
136        public SimpUser getUser(String userName) {
137                return this.users.get(userName);
138        }
139
140        @Override
141        public Set<SimpUser> getUsers() {
142                return new HashSet<SimpUser>(this.users.values());
143        }
144
145        @Override
146        public int getUserCount() {
147                return this.users.size();
148        }
149
150        public Set<SimpSubscription> findSubscriptions(SimpSubscriptionMatcher matcher) {
151                Set<SimpSubscription> result = new HashSet<SimpSubscription>();
152                for (LocalSimpSession session : this.sessions.values()) {
153                        for (SimpSubscription subscription : session.subscriptions.values()) {
154                                if (matcher.match(subscription)) {
155                                        result.add(subscription);
156                                }
157                        }
158                }
159                return result;
160        }
161
162
163        @Override
164        public String toString() {
165                return "users=" + this.users;
166        }
167
168
169        private static class LocalSimpUser implements SimpUser {
170
171                private final String name;
172
173                private final Map<String, SimpSession> userSessions = new ConcurrentHashMap<String, SimpSession>(1);
174
175                public LocalSimpUser(String userName) {
176                        Assert.notNull(userName, "User name must not be null");
177                        this.name = userName;
178                }
179
180                @Override
181                public String getName() {
182                        return this.name;
183                }
184
185                @Override
186                public boolean hasSessions() {
187                        return !this.userSessions.isEmpty();
188                }
189
190                @Override
191                public SimpSession getSession(String sessionId) {
192                        return (sessionId != null ? this.userSessions.get(sessionId) : null);
193                }
194
195                @Override
196                public Set<SimpSession> getSessions() {
197                        return new HashSet<SimpSession>(this.userSessions.values());
198                }
199
200                void addSession(SimpSession session) {
201                        this.userSessions.put(session.getId(), session);
202                }
203
204                void removeSession(String sessionId) {
205                        this.userSessions.remove(sessionId);
206                }
207
208                @Override
209                public boolean equals(Object other) {
210                        return (this == other ||
211                                        (other instanceof SimpUser && this.name.equals(((SimpUser) other).getName())));
212                }
213
214                @Override
215                public int hashCode() {
216                        return this.name.hashCode();
217                }
218
219                @Override
220                public String toString() {
221                        return "name=" + this.name + ", sessions=" + this.userSessions;
222                }
223        }
224
225
226        private static class LocalSimpSession implements SimpSession {
227
228                private final String id;
229
230                private final LocalSimpUser user;
231
232                private final Map<String, SimpSubscription> subscriptions = new ConcurrentHashMap<String, SimpSubscription>(4);
233
234                public LocalSimpSession(String id, LocalSimpUser user) {
235                        Assert.notNull(id, "Id must not be null");
236                        Assert.notNull(user, "User must not be null");
237                        this.id = id;
238                        this.user = user;
239                }
240
241                @Override
242                public String getId() {
243                        return this.id;
244                }
245
246                @Override
247                public LocalSimpUser getUser() {
248                        return this.user;
249                }
250
251                @Override
252                public Set<SimpSubscription> getSubscriptions() {
253                        return new HashSet<SimpSubscription>(this.subscriptions.values());
254                }
255
256                void addSubscription(String id, String destination) {
257                        this.subscriptions.put(id, new LocalSimpSubscription(id, destination, this));
258                }
259
260                void removeSubscription(String id) {
261                        this.subscriptions.remove(id);
262                }
263
264                @Override
265                public boolean equals(Object other) {
266                        return (this == other ||
267                                        (other instanceof SimpSubscription && this.id.equals(((SimpSubscription) other).getId())));
268                }
269
270                @Override
271                public int hashCode() {
272                        return this.id.hashCode();
273                }
274
275                @Override
276                public String toString() {
277                        return "id=" + this.id + ", subscriptions=" + this.subscriptions;
278                }
279        }
280
281
282        private static class LocalSimpSubscription implements SimpSubscription {
283
284                private final String id;
285
286                private final LocalSimpSession session;
287
288                private final String destination;
289
290                public LocalSimpSubscription(String id, String destination, LocalSimpSession session) {
291                        Assert.notNull(id, "Id must not be null");
292                        Assert.hasText(destination, "Destination must not be empty");
293                        Assert.notNull(session, "Session must not be null");
294                        this.id = id;
295                        this.destination = destination;
296                        this.session = session;
297                }
298
299                @Override
300                public String getId() {
301                        return this.id;
302                }
303
304                @Override
305                public LocalSimpSession getSession() {
306                        return this.session;
307                }
308
309                @Override
310                public String getDestination() {
311                        return this.destination;
312                }
313
314                @Override
315                public boolean equals(Object other) {
316                        if (this == other) {
317                                return true;
318                        }
319                        if (!(other instanceof SimpSubscription)) {
320                                return false;
321                        }
322                        SimpSubscription otherSubscription = (SimpSubscription) other;
323                        return (this.id.equals(otherSubscription.getId()) &&
324                                        getSession().getId().equals(otherSubscription.getSession().getId()));
325                }
326
327                @Override
328                public int hashCode() {
329                        return this.id.hashCode() * 31 + getSession().getId().hashCode();
330                }
331
332                @Override
333                public String toString() {
334                        return "destination=" + this.destination;
335                }
336        }
337
338}