001/* 002 * Copyright 2002-2017 the original author or authors. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * https://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.springframework.web.socket.messaging; 018 019import java.security.Principal; 020import java.util.HashSet; 021import java.util.Map; 022import java.util.Set; 023import java.util.concurrent.ConcurrentHashMap; 024 025import org.springframework.context.ApplicationEvent; 026import org.springframework.context.event.SmartApplicationListener; 027import org.springframework.core.Ordered; 028import org.springframework.messaging.Message; 029import org.springframework.messaging.simp.SimpMessageHeaderAccessor; 030import org.springframework.messaging.simp.user.DestinationUserNameProvider; 031import org.springframework.messaging.simp.user.SimpSession; 032import org.springframework.messaging.simp.user.SimpSubscription; 033import org.springframework.messaging.simp.user.SimpSubscriptionMatcher; 034import org.springframework.messaging.simp.user.SimpUser; 035import org.springframework.messaging.simp.user.SimpUserRegistry; 036import org.springframework.messaging.support.MessageHeaderAccessor; 037import org.springframework.util.Assert; 038 039/** 040 * A default implementation of {@link SimpUserRegistry} that relies on 041 * {@link AbstractSubProtocolEvent} application context events to keep track of 042 * connected users and their subscriptions. 043 * 044 * @author Rossen Stoyanchev 045 * @since 4.2 046 */ 047public class DefaultSimpUserRegistry implements SimpUserRegistry, SmartApplicationListener { 048 049 /* Primary lookup that holds all users and their sessions */ 050 private final Map<String, LocalSimpUser> users = new ConcurrentHashMap<String, LocalSimpUser>(); 051 052 /* Secondary lookup across all sessions by id */ 053 private final Map<String, LocalSimpSession> sessions = new ConcurrentHashMap<String, LocalSimpSession>(); 054 055 private final Object sessionLock = new Object(); 056 057 058 @Override 059 public int getOrder() { 060 return Ordered.LOWEST_PRECEDENCE; 061 } 062 063 064 // SmartApplicationListener methods 065 066 @Override 067 public boolean supportsEventType(Class<? extends ApplicationEvent> eventType) { 068 return AbstractSubProtocolEvent.class.isAssignableFrom(eventType); 069 } 070 071 @Override 072 public void onApplicationEvent(ApplicationEvent event) { 073 AbstractSubProtocolEvent subProtocolEvent = (AbstractSubProtocolEvent) event; 074 Message<?> message = subProtocolEvent.getMessage(); 075 SimpMessageHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, SimpMessageHeaderAccessor.class); 076 String sessionId = accessor.getSessionId(); 077 078 if (event instanceof SessionSubscribeEvent) { 079 LocalSimpSession session = this.sessions.get(sessionId); 080 if (session != null) { 081 String id = accessor.getSubscriptionId(); 082 String destination = accessor.getDestination(); 083 session.addSubscription(id, destination); 084 } 085 } 086 else if (event instanceof SessionConnectedEvent) { 087 Principal user = subProtocolEvent.getUser(); 088 if (user == null) { 089 return; 090 } 091 String name = user.getName(); 092 if (user instanceof DestinationUserNameProvider) { 093 name = ((DestinationUserNameProvider) user).getDestinationUserName(); 094 } 095 synchronized (this.sessionLock) { 096 LocalSimpUser simpUser = this.users.get(name); 097 if (simpUser == null) { 098 simpUser = new LocalSimpUser(name); 099 this.users.put(name, simpUser); 100 } 101 LocalSimpSession session = new LocalSimpSession(sessionId, simpUser); 102 simpUser.addSession(session); 103 this.sessions.put(sessionId, session); 104 } 105 } 106 else if (event instanceof SessionDisconnectEvent) { 107 synchronized (this.sessionLock) { 108 LocalSimpSession session = this.sessions.remove(sessionId); 109 if (session != null) { 110 LocalSimpUser user = session.getUser(); 111 user.removeSession(sessionId); 112 if (!user.hasSessions()) { 113 this.users.remove(user.getName()); 114 } 115 } 116 } 117 } 118 else if (event instanceof SessionUnsubscribeEvent) { 119 LocalSimpSession session = this.sessions.get(sessionId); 120 if (session != null) { 121 String subscriptionId = accessor.getSubscriptionId(); 122 session.removeSubscription(subscriptionId); 123 } 124 } 125 } 126 127 @Override 128 public boolean supportsSourceType(Class<?> sourceType) { 129 return true; 130 } 131 132 133 // SimpUserRegistry methods 134 135 @Override 136 public SimpUser getUser(String userName) { 137 return this.users.get(userName); 138 } 139 140 @Override 141 public Set<SimpUser> getUsers() { 142 return new HashSet<SimpUser>(this.users.values()); 143 } 144 145 @Override 146 public int getUserCount() { 147 return this.users.size(); 148 } 149 150 public Set<SimpSubscription> findSubscriptions(SimpSubscriptionMatcher matcher) { 151 Set<SimpSubscription> result = new HashSet<SimpSubscription>(); 152 for (LocalSimpSession session : this.sessions.values()) { 153 for (SimpSubscription subscription : session.subscriptions.values()) { 154 if (matcher.match(subscription)) { 155 result.add(subscription); 156 } 157 } 158 } 159 return result; 160 } 161 162 163 @Override 164 public String toString() { 165 return "users=" + this.users; 166 } 167 168 169 private static class LocalSimpUser implements SimpUser { 170 171 private final String name; 172 173 private final Map<String, SimpSession> userSessions = new ConcurrentHashMap<String, SimpSession>(1); 174 175 public LocalSimpUser(String userName) { 176 Assert.notNull(userName, "User name must not be null"); 177 this.name = userName; 178 } 179 180 @Override 181 public String getName() { 182 return this.name; 183 } 184 185 @Override 186 public boolean hasSessions() { 187 return !this.userSessions.isEmpty(); 188 } 189 190 @Override 191 public SimpSession getSession(String sessionId) { 192 return (sessionId != null ? this.userSessions.get(sessionId) : null); 193 } 194 195 @Override 196 public Set<SimpSession> getSessions() { 197 return new HashSet<SimpSession>(this.userSessions.values()); 198 } 199 200 void addSession(SimpSession session) { 201 this.userSessions.put(session.getId(), session); 202 } 203 204 void removeSession(String sessionId) { 205 this.userSessions.remove(sessionId); 206 } 207 208 @Override 209 public boolean equals(Object other) { 210 return (this == other || 211 (other instanceof SimpUser && this.name.equals(((SimpUser) other).getName()))); 212 } 213 214 @Override 215 public int hashCode() { 216 return this.name.hashCode(); 217 } 218 219 @Override 220 public String toString() { 221 return "name=" + this.name + ", sessions=" + this.userSessions; 222 } 223 } 224 225 226 private static class LocalSimpSession implements SimpSession { 227 228 private final String id; 229 230 private final LocalSimpUser user; 231 232 private final Map<String, SimpSubscription> subscriptions = new ConcurrentHashMap<String, SimpSubscription>(4); 233 234 public LocalSimpSession(String id, LocalSimpUser user) { 235 Assert.notNull(id, "Id must not be null"); 236 Assert.notNull(user, "User must not be null"); 237 this.id = id; 238 this.user = user; 239 } 240 241 @Override 242 public String getId() { 243 return this.id; 244 } 245 246 @Override 247 public LocalSimpUser getUser() { 248 return this.user; 249 } 250 251 @Override 252 public Set<SimpSubscription> getSubscriptions() { 253 return new HashSet<SimpSubscription>(this.subscriptions.values()); 254 } 255 256 void addSubscription(String id, String destination) { 257 this.subscriptions.put(id, new LocalSimpSubscription(id, destination, this)); 258 } 259 260 void removeSubscription(String id) { 261 this.subscriptions.remove(id); 262 } 263 264 @Override 265 public boolean equals(Object other) { 266 return (this == other || 267 (other instanceof SimpSubscription && this.id.equals(((SimpSubscription) other).getId()))); 268 } 269 270 @Override 271 public int hashCode() { 272 return this.id.hashCode(); 273 } 274 275 @Override 276 public String toString() { 277 return "id=" + this.id + ", subscriptions=" + this.subscriptions; 278 } 279 } 280 281 282 private static class LocalSimpSubscription implements SimpSubscription { 283 284 private final String id; 285 286 private final LocalSimpSession session; 287 288 private final String destination; 289 290 public LocalSimpSubscription(String id, String destination, LocalSimpSession session) { 291 Assert.notNull(id, "Id must not be null"); 292 Assert.hasText(destination, "Destination must not be empty"); 293 Assert.notNull(session, "Session must not be null"); 294 this.id = id; 295 this.destination = destination; 296 this.session = session; 297 } 298 299 @Override 300 public String getId() { 301 return this.id; 302 } 303 304 @Override 305 public LocalSimpSession getSession() { 306 return this.session; 307 } 308 309 @Override 310 public String getDestination() { 311 return this.destination; 312 } 313 314 @Override 315 public boolean equals(Object other) { 316 if (this == other) { 317 return true; 318 } 319 if (!(other instanceof SimpSubscription)) { 320 return false; 321 } 322 SimpSubscription otherSubscription = (SimpSubscription) other; 323 return (this.id.equals(otherSubscription.getId()) && 324 getSession().getId().equals(otherSubscription.getSession().getId())); 325 } 326 327 @Override 328 public int hashCode() { 329 return this.id.hashCode() * 31 + getSession().getId().hashCode(); 330 } 331 332 @Override 333 public String toString() { 334 return "destination=" + this.destination; 335 } 336 } 337 338}