001/* 002 * Copyright 2002-2018 the original author or authors. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * https://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.springframework.web.socket.messaging; 018 019import java.security.Principal; 020import java.util.HashSet; 021import java.util.Map; 022import java.util.Set; 023import java.util.concurrent.ConcurrentHashMap; 024 025import org.springframework.context.ApplicationEvent; 026import org.springframework.context.event.SmartApplicationListener; 027import org.springframework.core.Ordered; 028import org.springframework.lang.Nullable; 029import org.springframework.messaging.Message; 030import org.springframework.messaging.MessageHeaders; 031import org.springframework.messaging.simp.SimpMessageHeaderAccessor; 032import org.springframework.messaging.simp.user.DestinationUserNameProvider; 033import org.springframework.messaging.simp.user.SimpSession; 034import org.springframework.messaging.simp.user.SimpSubscription; 035import org.springframework.messaging.simp.user.SimpSubscriptionMatcher; 036import org.springframework.messaging.simp.user.SimpUser; 037import org.springframework.messaging.simp.user.SimpUserRegistry; 038import org.springframework.util.Assert; 039 040/** 041 * A default implementation of {@link SimpUserRegistry} that relies on 042 * {@link AbstractSubProtocolEvent} application context events to keep 043 * track of connected users and their subscriptions. 044 * 045 * @author Rossen Stoyanchev 046 * @since 4.2 047 */ 048public class DefaultSimpUserRegistry implements SimpUserRegistry, SmartApplicationListener { 049 050 private int order = Ordered.LOWEST_PRECEDENCE; 051 052 /* Primary lookup that holds all users and their sessions */ 053 private final Map<String, LocalSimpUser> users = new ConcurrentHashMap<>(); 054 055 /* Secondary lookup across all sessions by id */ 056 private final Map<String, LocalSimpSession> sessions = new ConcurrentHashMap<>(); 057 058 private final Object sessionLock = new Object(); 059 060 061 /** 062 * Specify the order value for this registry. 063 * <p>Default is {@link Ordered#LOWEST_PRECEDENCE}. 064 * @since 5.0.8 065 */ 066 public void setOrder(int order) { 067 this.order = order; 068 } 069 070 @Override 071 public int getOrder() { 072 return this.order; 073 } 074 075 076 // SmartApplicationListener methods 077 078 @Override 079 public boolean supportsEventType(Class<? extends ApplicationEvent> eventType) { 080 return AbstractSubProtocolEvent.class.isAssignableFrom(eventType); 081 } 082 083 @Override 084 public void onApplicationEvent(ApplicationEvent event) { 085 AbstractSubProtocolEvent subProtocolEvent = (AbstractSubProtocolEvent) event; 086 Message<?> message = subProtocolEvent.getMessage(); 087 MessageHeaders headers = message.getHeaders(); 088 089 String sessionId = SimpMessageHeaderAccessor.getSessionId(headers); 090 Assert.state(sessionId != null, "No session id"); 091 092 if (event instanceof SessionSubscribeEvent) { 093 LocalSimpSession session = this.sessions.get(sessionId); 094 if (session != null) { 095 String id = SimpMessageHeaderAccessor.getSubscriptionId(headers); 096 String destination = SimpMessageHeaderAccessor.getDestination(headers); 097 if (id != null && destination != null) { 098 session.addSubscription(id, destination); 099 } 100 } 101 } 102 else if (event instanceof SessionConnectedEvent) { 103 Principal user = subProtocolEvent.getUser(); 104 if (user == null) { 105 return; 106 } 107 String name = user.getName(); 108 if (user instanceof DestinationUserNameProvider) { 109 name = ((DestinationUserNameProvider) user).getDestinationUserName(); 110 } 111 synchronized (this.sessionLock) { 112 LocalSimpUser simpUser = this.users.get(name); 113 if (simpUser == null) { 114 simpUser = new LocalSimpUser(name); 115 this.users.put(name, simpUser); 116 } 117 LocalSimpSession session = new LocalSimpSession(sessionId, simpUser); 118 simpUser.addSession(session); 119 this.sessions.put(sessionId, session); 120 } 121 } 122 else if (event instanceof SessionDisconnectEvent) { 123 synchronized (this.sessionLock) { 124 LocalSimpSession session = this.sessions.remove(sessionId); 125 if (session != null) { 126 LocalSimpUser user = session.getUser(); 127 user.removeSession(sessionId); 128 if (!user.hasSessions()) { 129 this.users.remove(user.getName()); 130 } 131 } 132 } 133 } 134 else if (event instanceof SessionUnsubscribeEvent) { 135 LocalSimpSession session = this.sessions.get(sessionId); 136 if (session != null) { 137 String subscriptionId = SimpMessageHeaderAccessor.getSubscriptionId(headers); 138 if (subscriptionId != null) { 139 session.removeSubscription(subscriptionId); 140 } 141 } 142 } 143 } 144 145 @Override 146 public boolean supportsSourceType(@Nullable Class<?> sourceType) { 147 return true; 148 } 149 150 151 // SimpUserRegistry methods 152 153 @Override 154 @Nullable 155 public SimpUser getUser(String userName) { 156 return this.users.get(userName); 157 } 158 159 @Override 160 public Set<SimpUser> getUsers() { 161 return new HashSet<>(this.users.values()); 162 } 163 164 @Override 165 public int getUserCount() { 166 return this.users.size(); 167 } 168 169 @Override 170 public Set<SimpSubscription> findSubscriptions(SimpSubscriptionMatcher matcher) { 171 Set<SimpSubscription> result = new HashSet<>(); 172 for (LocalSimpSession session : this.sessions.values()) { 173 for (SimpSubscription subscription : session.subscriptions.values()) { 174 if (matcher.match(subscription)) { 175 result.add(subscription); 176 } 177 } 178 } 179 return result; 180 } 181 182 183 @Override 184 public String toString() { 185 return "users=" + this.users; 186 } 187 188 189 private static class LocalSimpUser implements SimpUser { 190 191 private final String name; 192 193 private final Map<String, SimpSession> userSessions = new ConcurrentHashMap<>(1); 194 195 public LocalSimpUser(String userName) { 196 Assert.notNull(userName, "User name must not be null"); 197 this.name = userName; 198 } 199 200 @Override 201 public String getName() { 202 return this.name; 203 } 204 205 @Override 206 public boolean hasSessions() { 207 return !this.userSessions.isEmpty(); 208 } 209 210 @Override 211 @Nullable 212 public SimpSession getSession(@Nullable String sessionId) { 213 return (sessionId != null ? this.userSessions.get(sessionId) : null); 214 } 215 216 @Override 217 public Set<SimpSession> getSessions() { 218 return new HashSet<>(this.userSessions.values()); 219 } 220 221 void addSession(SimpSession session) { 222 this.userSessions.put(session.getId(), session); 223 } 224 225 void removeSession(String sessionId) { 226 this.userSessions.remove(sessionId); 227 } 228 229 @Override 230 public boolean equals(@Nullable Object other) { 231 return (this == other || 232 (other instanceof SimpUser && getName().equals(((SimpUser) other).getName()))); 233 } 234 235 @Override 236 public int hashCode() { 237 return getName().hashCode(); 238 } 239 240 @Override 241 public String toString() { 242 return "name=" + getName() + ", sessions=" + this.userSessions; 243 } 244 } 245 246 247 private static class LocalSimpSession implements SimpSession { 248 249 private final String id; 250 251 private final LocalSimpUser user; 252 253 private final Map<String, SimpSubscription> subscriptions = new ConcurrentHashMap<>(4); 254 255 public LocalSimpSession(String id, LocalSimpUser user) { 256 Assert.notNull(id, "Id must not be null"); 257 Assert.notNull(user, "User must not be null"); 258 this.id = id; 259 this.user = user; 260 } 261 262 @Override 263 public String getId() { 264 return this.id; 265 } 266 267 @Override 268 public LocalSimpUser getUser() { 269 return this.user; 270 } 271 272 @Override 273 public Set<SimpSubscription> getSubscriptions() { 274 return new HashSet<>(this.subscriptions.values()); 275 } 276 277 void addSubscription(String id, String destination) { 278 this.subscriptions.put(id, new LocalSimpSubscription(id, destination, this)); 279 } 280 281 void removeSubscription(String id) { 282 this.subscriptions.remove(id); 283 } 284 285 @Override 286 public boolean equals(@Nullable Object other) { 287 return (this == other || 288 (other instanceof SimpSubscription && getId().equals(((SimpSubscription) other).getId()))); 289 } 290 291 @Override 292 public int hashCode() { 293 return getId().hashCode(); 294 } 295 296 @Override 297 public String toString() { 298 return "id=" + getId() + ", subscriptions=" + this.subscriptions; 299 } 300 } 301 302 303 private static class LocalSimpSubscription implements SimpSubscription { 304 305 private final String id; 306 307 private final LocalSimpSession session; 308 309 private final String destination; 310 311 public LocalSimpSubscription(String id, String destination, LocalSimpSession session) { 312 Assert.notNull(id, "Id must not be null"); 313 Assert.hasText(destination, "Destination must not be empty"); 314 Assert.notNull(session, "Session must not be null"); 315 this.id = id; 316 this.destination = destination; 317 this.session = session; 318 } 319 320 @Override 321 public String getId() { 322 return this.id; 323 } 324 325 @Override 326 public LocalSimpSession getSession() { 327 return this.session; 328 } 329 330 @Override 331 public String getDestination() { 332 return this.destination; 333 } 334 335 @Override 336 public boolean equals(@Nullable Object other) { 337 if (this == other) { 338 return true; 339 } 340 if (!(other instanceof SimpSubscription)) { 341 return false; 342 } 343 SimpSubscription otherSubscription = (SimpSubscription) other; 344 return (getId().equals(otherSubscription.getId()) && 345 getSession().getId().equals(otherSubscription.getSession().getId())); 346 } 347 348 @Override 349 public int hashCode() { 350 return getId().hashCode() * 31 + getSession().getId().hashCode(); 351 } 352 353 @Override 354 public String toString() { 355 return "destination=" + this.destination; 356 } 357 } 358 359}