001/* 002 * Copyright 2002-2020 the original author or authors. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * https://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.springframework.web.socket.sockjs.transport; 018 019import java.io.IOException; 020import java.security.Principal; 021import java.util.ArrayList; 022import java.util.Arrays; 023import java.util.Collection; 024import java.util.Collections; 025import java.util.EnumMap; 026import java.util.HashMap; 027import java.util.List; 028import java.util.Map; 029import java.util.concurrent.ConcurrentHashMap; 030import java.util.concurrent.ScheduledFuture; 031 032import org.springframework.context.Lifecycle; 033import org.springframework.http.HttpMethod; 034import org.springframework.http.HttpStatus; 035import org.springframework.http.server.ServerHttpRequest; 036import org.springframework.http.server.ServerHttpResponse; 037import org.springframework.http.server.ServletServerHttpResponse; 038import org.springframework.lang.Nullable; 039import org.springframework.scheduling.TaskScheduler; 040import org.springframework.util.Assert; 041import org.springframework.util.ClassUtils; 042import org.springframework.util.CollectionUtils; 043import org.springframework.web.socket.WebSocketHandler; 044import org.springframework.web.socket.server.HandshakeFailureException; 045import org.springframework.web.socket.server.HandshakeHandler; 046import org.springframework.web.socket.server.HandshakeInterceptor; 047import org.springframework.web.socket.server.support.HandshakeInterceptorChain; 048import org.springframework.web.socket.sockjs.SockJsException; 049import org.springframework.web.socket.sockjs.frame.Jackson2SockJsMessageCodec; 050import org.springframework.web.socket.sockjs.frame.SockJsMessageCodec; 051import org.springframework.web.socket.sockjs.support.AbstractSockJsService; 052 053/** 054 * A basic implementation of {@link org.springframework.web.socket.sockjs.SockJsService} 055 * with support for SPI-based transport handling and session management. 056 * 057 * <p>Based on the {@link TransportHandler} SPI. {@code TransportHandlers} may 058 * additionally implement the {@link SockJsSessionFactory} and {@link HandshakeHandler} interfaces. 059 * 060 * <p>See the {@link AbstractSockJsService} base class for important details on request mapping. 061 * 062 * @author Rossen Stoyanchev 063 * @author Juergen Hoeller 064 * @author Sebastien Deleuze 065 * @since 4.0 066 */ 067public class TransportHandlingSockJsService extends AbstractSockJsService implements SockJsServiceConfig, Lifecycle { 068 069 private static final boolean jackson2Present = ClassUtils.isPresent( 070 "com.fasterxml.jackson.databind.ObjectMapper", TransportHandlingSockJsService.class.getClassLoader()); 071 072 073 private final Map<TransportType, TransportHandler> handlers = new EnumMap<>(TransportType.class); 074 075 @Nullable 076 private SockJsMessageCodec messageCodec; 077 078 private final List<HandshakeInterceptor> interceptors = new ArrayList<>(); 079 080 private final Map<String, SockJsSession> sessions = new ConcurrentHashMap<>(); 081 082 @Nullable 083 private ScheduledFuture<?> sessionCleanupTask; 084 085 private volatile boolean running; 086 087 088 /** 089 * Create a TransportHandlingSockJsService with given {@link TransportHandler handler} types. 090 * @param scheduler a task scheduler for heart-beat messages and removing timed-out sessions; 091 * the provided TaskScheduler should be declared as a Spring bean to ensure it gets 092 * initialized at start-up and shuts down when the application stops 093 * @param handlers one or more {@link TransportHandler} implementations to use 094 */ 095 public TransportHandlingSockJsService(TaskScheduler scheduler, TransportHandler... handlers) { 096 this(scheduler, Arrays.asList(handlers)); 097 } 098 099 /** 100 * Create a TransportHandlingSockJsService with given {@link TransportHandler handler} types. 101 * @param scheduler a task scheduler for heart-beat messages and removing timed-out sessions; 102 * the provided TaskScheduler should be declared as a Spring bean to ensure it gets 103 * initialized at start-up and shuts down when the application stops 104 * @param handlers one or more {@link TransportHandler} implementations to use 105 */ 106 public TransportHandlingSockJsService(TaskScheduler scheduler, Collection<TransportHandler> handlers) { 107 super(scheduler); 108 109 if (CollectionUtils.isEmpty(handlers)) { 110 logger.warn("No transport handlers specified for TransportHandlingSockJsService"); 111 } 112 else { 113 for (TransportHandler handler : handlers) { 114 handler.initialize(this); 115 this.handlers.put(handler.getTransportType(), handler); 116 } 117 } 118 119 if (jackson2Present) { 120 this.messageCodec = new Jackson2SockJsMessageCodec(); 121 } 122 } 123 124 125 /** 126 * Return the registered handlers per transport type. 127 */ 128 public Map<TransportType, TransportHandler> getTransportHandlers() { 129 return Collections.unmodifiableMap(this.handlers); 130 } 131 132 /** 133 * The codec to use for encoding and decoding SockJS messages. 134 */ 135 public void setMessageCodec(SockJsMessageCodec messageCodec) { 136 this.messageCodec = messageCodec; 137 } 138 139 @Override 140 public SockJsMessageCodec getMessageCodec() { 141 Assert.state(this.messageCodec != null, "A SockJsMessageCodec is required but not available: " + 142 "Add Jackson to the classpath, or configure a custom SockJsMessageCodec."); 143 return this.messageCodec; 144 } 145 146 /** 147 * Configure one or more WebSocket handshake request interceptors. 148 */ 149 public void setHandshakeInterceptors(@Nullable List<HandshakeInterceptor> interceptors) { 150 this.interceptors.clear(); 151 if (interceptors != null) { 152 this.interceptors.addAll(interceptors); 153 } 154 } 155 156 /** 157 * Return the configured WebSocket handshake request interceptors. 158 */ 159 public List<HandshakeInterceptor> getHandshakeInterceptors() { 160 return this.interceptors; 161 } 162 163 164 @Override 165 public void start() { 166 if (!isRunning()) { 167 this.running = true; 168 for (TransportHandler handler : this.handlers.values()) { 169 if (handler instanceof Lifecycle) { 170 ((Lifecycle) handler).start(); 171 } 172 } 173 } 174 } 175 176 @Override 177 public void stop() { 178 if (isRunning()) { 179 this.running = false; 180 for (TransportHandler handler : this.handlers.values()) { 181 if (handler instanceof Lifecycle) { 182 ((Lifecycle) handler).stop(); 183 } 184 } 185 } 186 } 187 188 @Override 189 public boolean isRunning() { 190 return this.running; 191 } 192 193 194 @Override 195 protected void handleRawWebSocketRequest(ServerHttpRequest request, ServerHttpResponse response, 196 WebSocketHandler handler) throws IOException { 197 198 TransportHandler transportHandler = this.handlers.get(TransportType.WEBSOCKET); 199 if (!(transportHandler instanceof HandshakeHandler)) { 200 logger.error("No handler configured for raw WebSocket messages"); 201 response.setStatusCode(HttpStatus.NOT_FOUND); 202 return; 203 } 204 205 HandshakeInterceptorChain chain = new HandshakeInterceptorChain(this.interceptors, handler); 206 HandshakeFailureException failure = null; 207 208 try { 209 Map<String, Object> attributes = new HashMap<>(); 210 if (!chain.applyBeforeHandshake(request, response, attributes)) { 211 return; 212 } 213 ((HandshakeHandler) transportHandler).doHandshake(request, response, handler, attributes); 214 chain.applyAfterHandshake(request, response, null); 215 } 216 catch (HandshakeFailureException ex) { 217 failure = ex; 218 } 219 catch (Exception ex) { 220 failure = new HandshakeFailureException("Uncaught failure for request " + request.getURI(), ex); 221 } 222 finally { 223 if (failure != null) { 224 chain.applyAfterHandshake(request, response, failure); 225 throw failure; 226 } 227 } 228 } 229 230 @Override 231 protected void handleTransportRequest(ServerHttpRequest request, ServerHttpResponse response, 232 WebSocketHandler handler, String sessionId, String transport) throws SockJsException { 233 234 TransportType transportType = TransportType.fromValue(transport); 235 if (transportType == null) { 236 if (logger.isWarnEnabled()) { 237 logger.warn("Unknown transport type for " + request.getURI()); 238 } 239 response.setStatusCode(HttpStatus.NOT_FOUND); 240 return; 241 } 242 243 TransportHandler transportHandler = this.handlers.get(transportType); 244 if (transportHandler == null) { 245 if (logger.isWarnEnabled()) { 246 logger.warn("No TransportHandler for " + request.getURI()); 247 } 248 response.setStatusCode(HttpStatus.NOT_FOUND); 249 return; 250 } 251 252 SockJsException failure = null; 253 HandshakeInterceptorChain chain = new HandshakeInterceptorChain(this.interceptors, handler); 254 255 try { 256 HttpMethod supportedMethod = transportType.getHttpMethod(); 257 if (supportedMethod != request.getMethod()) { 258 if (request.getMethod() == HttpMethod.OPTIONS && transportType.supportsCors()) { 259 if (checkOrigin(request, response, HttpMethod.OPTIONS, supportedMethod)) { 260 response.setStatusCode(HttpStatus.NO_CONTENT); 261 addCacheHeaders(response); 262 } 263 } 264 else if (transportType.supportsCors()) { 265 sendMethodNotAllowed(response, supportedMethod, HttpMethod.OPTIONS); 266 } 267 else { 268 sendMethodNotAllowed(response, supportedMethod); 269 } 270 return; 271 } 272 273 SockJsSession session = this.sessions.get(sessionId); 274 boolean isNewSession = false; 275 if (session == null) { 276 if (transportHandler instanceof SockJsSessionFactory) { 277 Map<String, Object> attributes = new HashMap<>(); 278 if (!chain.applyBeforeHandshake(request, response, attributes)) { 279 return; 280 } 281 SockJsSessionFactory sessionFactory = (SockJsSessionFactory) transportHandler; 282 session = createSockJsSession(sessionId, sessionFactory, handler, attributes); 283 isNewSession = true; 284 } 285 else { 286 response.setStatusCode(HttpStatus.NOT_FOUND); 287 if (logger.isDebugEnabled()) { 288 logger.debug("Session not found, sessionId=" + sessionId + 289 ". The session may have been closed " + 290 "(e.g. missed heart-beat) while a message was coming in."); 291 } 292 return; 293 } 294 } 295 else { 296 Principal principal = session.getPrincipal(); 297 if (principal != null && !principal.equals(request.getPrincipal())) { 298 logger.debug("The user for the session does not match the user for the request."); 299 response.setStatusCode(HttpStatus.NOT_FOUND); 300 return; 301 } 302 if (!transportHandler.checkSessionType(session)) { 303 logger.debug("Session type does not match the transport type for the request."); 304 response.setStatusCode(HttpStatus.NOT_FOUND); 305 return; 306 } 307 } 308 309 if (transportType.sendsNoCacheInstruction()) { 310 addNoCacheHeaders(response); 311 } 312 if (transportType.supportsCors() && !checkOrigin(request, response)) { 313 return; 314 } 315 316 transportHandler.handleRequest(request, response, handler, session); 317 318 if (isNewSession && (response instanceof ServletServerHttpResponse)) { 319 int status = ((ServletServerHttpResponse) response).getServletResponse().getStatus(); 320 if (HttpStatus.valueOf(status).is4xxClientError()) { 321 this.sessions.remove(sessionId); 322 } 323 } 324 325 chain.applyAfterHandshake(request, response, null); 326 } 327 catch (SockJsException ex) { 328 failure = ex; 329 } 330 catch (Exception ex) { 331 failure = new SockJsException("Uncaught failure for request " + request.getURI(), sessionId, ex); 332 } 333 finally { 334 if (failure != null) { 335 chain.applyAfterHandshake(request, response, failure); 336 throw failure; 337 } 338 } 339 } 340 341 @Override 342 protected boolean validateRequest(String serverId, String sessionId, String transport) { 343 if (!super.validateRequest(serverId, sessionId, transport)) { 344 return false; 345 } 346 347 if (!this.allowedOrigins.contains("*")) { 348 TransportType transportType = TransportType.fromValue(transport); 349 if (transportType == null || !transportType.supportsOrigin()) { 350 if (logger.isWarnEnabled()) { 351 logger.warn("Origin check enabled but transport '" + transport + "' does not support it."); 352 } 353 return false; 354 } 355 } 356 357 return true; 358 } 359 360 private SockJsSession createSockJsSession(String sessionId, SockJsSessionFactory sessionFactory, 361 WebSocketHandler handler, Map<String, Object> attributes) { 362 363 SockJsSession session = this.sessions.get(sessionId); 364 if (session != null) { 365 return session; 366 } 367 if (this.sessionCleanupTask == null) { 368 scheduleSessionTask(); 369 } 370 session = sessionFactory.createSession(sessionId, handler, attributes); 371 this.sessions.put(sessionId, session); 372 return session; 373 } 374 375 private void scheduleSessionTask() { 376 synchronized (this.sessions) { 377 if (this.sessionCleanupTask != null) { 378 return; 379 } 380 this.sessionCleanupTask = getTaskScheduler().scheduleAtFixedRate(() -> { 381 List<String> removedIds = new ArrayList<>(); 382 for (SockJsSession session : this.sessions.values()) { 383 try { 384 if (session.getTimeSinceLastActive() > getDisconnectDelay()) { 385 this.sessions.remove(session.getId()); 386 removedIds.add(session.getId()); 387 session.close(); 388 } 389 } 390 catch (Throwable ex) { 391 // Could be part of normal workflow (e.g. browser tab closed) 392 logger.debug("Failed to close " + session, ex); 393 } 394 } 395 if (logger.isDebugEnabled() && !removedIds.isEmpty()) { 396 logger.debug("Closed " + removedIds.size() + " sessions: " + removedIds); 397 } 398 }, getDisconnectDelay()); 399 } 400 } 401 402}