001/* 002 * Copyright 2002-2018 the original author or authors. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * https://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.springframework.web.socket.sockjs.transport; 018 019import java.io.IOException; 020import java.util.ArrayList; 021import java.util.Arrays; 022import java.util.Collection; 023import java.util.Collections; 024import java.util.EnumMap; 025import java.util.HashMap; 026import java.util.List; 027import java.util.Map; 028import java.util.concurrent.ConcurrentHashMap; 029import java.util.concurrent.ScheduledFuture; 030 031import org.springframework.context.Lifecycle; 032import org.springframework.http.HttpMethod; 033import org.springframework.http.HttpStatus; 034import org.springframework.http.server.ServerHttpRequest; 035import org.springframework.http.server.ServerHttpResponse; 036import org.springframework.scheduling.TaskScheduler; 037import org.springframework.util.Assert; 038import org.springframework.util.ClassUtils; 039import org.springframework.util.CollectionUtils; 040import org.springframework.web.socket.WebSocketHandler; 041import org.springframework.web.socket.server.HandshakeFailureException; 042import org.springframework.web.socket.server.HandshakeHandler; 043import org.springframework.web.socket.server.HandshakeInterceptor; 044import org.springframework.web.socket.server.support.HandshakeInterceptorChain; 045import org.springframework.web.socket.sockjs.SockJsException; 046import org.springframework.web.socket.sockjs.frame.Jackson2SockJsMessageCodec; 047import org.springframework.web.socket.sockjs.frame.SockJsMessageCodec; 048import org.springframework.web.socket.sockjs.support.AbstractSockJsService; 049 050/** 051 * A basic implementation of {@link org.springframework.web.socket.sockjs.SockJsService} 052 * with support for SPI-based transport handling and session management. 053 * 054 * <p>Based on the {@link TransportHandler} SPI. {@link TransportHandler}s may additionally 055 * implement the {@link SockJsSessionFactory} and {@link HandshakeHandler} interfaces. 056 * 057 * <p>See the {@link AbstractSockJsService} base class for important details on request mapping. 058 * 059 * @author Rossen Stoyanchev 060 * @author Juergen Hoeller 061 * @author Sebastien Deleuze 062 * @since 4.0 063 */ 064public class TransportHandlingSockJsService extends AbstractSockJsService implements SockJsServiceConfig, Lifecycle { 065 066 private static final boolean jackson2Present = ClassUtils.isPresent( 067 "com.fasterxml.jackson.databind.ObjectMapper", TransportHandlingSockJsService.class.getClassLoader()); 068 069 070 private final Map<TransportType, TransportHandler> handlers = 071 new EnumMap<TransportType, TransportHandler>(TransportType.class); 072 073 private SockJsMessageCodec messageCodec; 074 075 private final List<HandshakeInterceptor> interceptors = new ArrayList<HandshakeInterceptor>(); 076 077 private final Map<String, SockJsSession> sessions = new ConcurrentHashMap<String, SockJsSession>(); 078 079 private ScheduledFuture<?> sessionCleanupTask; 080 081 private volatile boolean running; 082 083 084 /** 085 * Create a TransportHandlingSockJsService with given {@link TransportHandler handler} types. 086 * @param scheduler a task scheduler for heart-beat messages and removing timed-out sessions; 087 * the provided TaskScheduler should be declared as a Spring bean to ensure it gets 088 * initialized at start-up and shuts down when the application stops 089 * @param handlers one or more {@link TransportHandler} implementations to use 090 */ 091 public TransportHandlingSockJsService(TaskScheduler scheduler, TransportHandler... handlers) { 092 this(scheduler, Arrays.asList(handlers)); 093 } 094 095 /** 096 * Create a TransportHandlingSockJsService with given {@link TransportHandler handler} types. 097 * @param scheduler a task scheduler for heart-beat messages and removing timed-out sessions; 098 * the provided TaskScheduler should be declared as a Spring bean to ensure it gets 099 * initialized at start-up and shuts down when the application stops 100 * @param handlers one or more {@link TransportHandler} implementations to use 101 */ 102 public TransportHandlingSockJsService(TaskScheduler scheduler, Collection<TransportHandler> handlers) { 103 super(scheduler); 104 105 if (CollectionUtils.isEmpty(handlers)) { 106 logger.warn("No transport handlers specified for TransportHandlingSockJsService"); 107 } 108 else { 109 for (TransportHandler handler : handlers) { 110 handler.initialize(this); 111 this.handlers.put(handler.getTransportType(), handler); 112 } 113 } 114 115 if (jackson2Present) { 116 this.messageCodec = new Jackson2SockJsMessageCodec(); 117 } 118 } 119 120 121 /** 122 * Return the registered handlers per transport type. 123 */ 124 public Map<TransportType, TransportHandler> getTransportHandlers() { 125 return Collections.unmodifiableMap(this.handlers); 126 } 127 128 /** 129 * The codec to use for encoding and decoding SockJS messages. 130 */ 131 public void setMessageCodec(SockJsMessageCodec messageCodec) { 132 this.messageCodec = messageCodec; 133 } 134 135 public SockJsMessageCodec getMessageCodec() { 136 Assert.state(this.messageCodec != null, "A SockJsMessageCodec is required but not available: " + 137 "Add Jackson to the classpath, or configure a custom SockJsMessageCodec."); 138 return this.messageCodec; 139 } 140 141 /** 142 * Configure one or more WebSocket handshake request interceptors. 143 */ 144 public void setHandshakeInterceptors(List<HandshakeInterceptor> interceptors) { 145 this.interceptors.clear(); 146 if (interceptors != null) { 147 this.interceptors.addAll(interceptors); 148 } 149 } 150 151 /** 152 * Return the configured WebSocket handshake request interceptors. 153 */ 154 public List<HandshakeInterceptor> getHandshakeInterceptors() { 155 return this.interceptors; 156 } 157 158 159 @Override 160 public void start() { 161 if (!isRunning()) { 162 this.running = true; 163 for (TransportHandler handler : this.handlers.values()) { 164 if (handler instanceof Lifecycle) { 165 ((Lifecycle) handler).start(); 166 } 167 } 168 } 169 } 170 171 @Override 172 public void stop() { 173 if (isRunning()) { 174 this.running = false; 175 for (TransportHandler handler : this.handlers.values()) { 176 if (handler instanceof Lifecycle) { 177 ((Lifecycle) handler).stop(); 178 } 179 } 180 } 181 } 182 183 @Override 184 public boolean isRunning() { 185 return this.running; 186 } 187 188 189 @Override 190 protected void handleRawWebSocketRequest(ServerHttpRequest request, ServerHttpResponse response, 191 WebSocketHandler handler) throws IOException { 192 193 TransportHandler transportHandler = this.handlers.get(TransportType.WEBSOCKET); 194 if (!(transportHandler instanceof HandshakeHandler)) { 195 logger.error("No handler configured for raw WebSocket messages"); 196 response.setStatusCode(HttpStatus.NOT_FOUND); 197 return; 198 } 199 200 HandshakeInterceptorChain chain = new HandshakeInterceptorChain(this.interceptors, handler); 201 HandshakeFailureException failure = null; 202 203 try { 204 Map<String, Object> attributes = new HashMap<String, Object>(); 205 if (!chain.applyBeforeHandshake(request, response, attributes)) { 206 return; 207 } 208 ((HandshakeHandler) transportHandler).doHandshake(request, response, handler, attributes); 209 chain.applyAfterHandshake(request, response, null); 210 } 211 catch (HandshakeFailureException ex) { 212 failure = ex; 213 } 214 catch (Throwable ex) { 215 failure = new HandshakeFailureException("Uncaught failure for request " + request.getURI(), ex); 216 } 217 finally { 218 if (failure != null) { 219 chain.applyAfterHandshake(request, response, failure); 220 throw failure; 221 } 222 } 223 } 224 225 @Override 226 protected void handleTransportRequest(ServerHttpRequest request, ServerHttpResponse response, 227 WebSocketHandler handler, String sessionId, String transport) throws SockJsException { 228 229 TransportType transportType = TransportType.fromValue(transport); 230 if (transportType == null) { 231 if (logger.isWarnEnabled()) { 232 logger.warn("Unknown transport type for " + request.getURI()); 233 } 234 response.setStatusCode(HttpStatus.NOT_FOUND); 235 return; 236 } 237 238 TransportHandler transportHandler = this.handlers.get(transportType); 239 if (transportHandler == null) { 240 if (logger.isWarnEnabled()) { 241 logger.warn("No TransportHandler for " + request.getURI()); 242 } 243 response.setStatusCode(HttpStatus.NOT_FOUND); 244 return; 245 } 246 247 SockJsException failure = null; 248 HandshakeInterceptorChain chain = new HandshakeInterceptorChain(this.interceptors, handler); 249 250 try { 251 HttpMethod supportedMethod = transportType.getHttpMethod(); 252 if (supportedMethod != request.getMethod()) { 253 if (request.getMethod() == HttpMethod.OPTIONS && transportType.supportsCors()) { 254 if (checkOrigin(request, response, HttpMethod.OPTIONS, supportedMethod)) { 255 response.setStatusCode(HttpStatus.NO_CONTENT); 256 addCacheHeaders(response); 257 } 258 } 259 else if (transportType.supportsCors()) { 260 sendMethodNotAllowed(response, supportedMethod, HttpMethod.OPTIONS); 261 } 262 else { 263 sendMethodNotAllowed(response, supportedMethod); 264 } 265 return; 266 } 267 268 SockJsSession session = this.sessions.get(sessionId); 269 if (session == null) { 270 if (transportHandler instanceof SockJsSessionFactory) { 271 Map<String, Object> attributes = new HashMap<String, Object>(); 272 if (!chain.applyBeforeHandshake(request, response, attributes)) { 273 return; 274 } 275 SockJsSessionFactory sessionFactory = (SockJsSessionFactory) transportHandler; 276 session = createSockJsSession(sessionId, sessionFactory, handler, attributes); 277 } 278 else { 279 response.setStatusCode(HttpStatus.NOT_FOUND); 280 if (logger.isDebugEnabled()) { 281 logger.debug("Session not found, sessionId=" + sessionId + 282 ". The session may have been closed " + 283 "(e.g. missed heart-beat) while a message was coming in."); 284 } 285 return; 286 } 287 } 288 else { 289 if (session.getPrincipal() != null) { 290 if (!session.getPrincipal().equals(request.getPrincipal())) { 291 logger.debug("The user for the session does not match the user for the request."); 292 response.setStatusCode(HttpStatus.NOT_FOUND); 293 return; 294 } 295 } 296 if (!transportHandler.checkSessionType(session)) { 297 logger.debug("Session type does not match the transport type for the request."); 298 response.setStatusCode(HttpStatus.NOT_FOUND); 299 return; 300 } 301 } 302 303 if (transportType.sendsNoCacheInstruction()) { 304 addNoCacheHeaders(response); 305 } 306 307 if (transportType.supportsCors()) { 308 if (!checkOrigin(request, response)) { 309 return; 310 } 311 } 312 313 314 transportHandler.handleRequest(request, response, handler, session); 315 316 317 chain.applyAfterHandshake(request, response, null); 318 } 319 catch (SockJsException ex) { 320 failure = ex; 321 } 322 catch (Throwable ex) { 323 failure = new SockJsException("Uncaught failure for request " + request.getURI(), sessionId, ex); 324 } 325 finally { 326 if (failure != null) { 327 chain.applyAfterHandshake(request, response, failure); 328 throw failure; 329 } 330 } 331 } 332 333 @Override 334 protected boolean validateRequest(String serverId, String sessionId, String transport) { 335 if (!super.validateRequest(serverId, sessionId, transport)) { 336 return false; 337 } 338 339 if (!this.allowedOrigins.contains("*")) { 340 TransportType transportType = TransportType.fromValue(transport); 341 if (transportType == null || !transportType.supportsOrigin()) { 342 if (logger.isWarnEnabled()) { 343 logger.warn("Origin check enabled but transport '" + transport + "' does not support it."); 344 } 345 return false; 346 } 347 } 348 349 return true; 350 } 351 352 private SockJsSession createSockJsSession(String sessionId, SockJsSessionFactory sessionFactory, 353 WebSocketHandler handler, Map<String, Object> attributes) { 354 355 SockJsSession session = this.sessions.get(sessionId); 356 if (session != null) { 357 return session; 358 } 359 if (this.sessionCleanupTask == null) { 360 scheduleSessionTask(); 361 } 362 session = sessionFactory.createSession(sessionId, handler, attributes); 363 this.sessions.put(sessionId, session); 364 return session; 365 } 366 367 private void scheduleSessionTask() { 368 synchronized (this.sessions) { 369 if (this.sessionCleanupTask != null) { 370 return; 371 } 372 this.sessionCleanupTask = getTaskScheduler().scheduleAtFixedRate(new Runnable() { 373 @Override 374 public void run() { 375 List<String> removedIds = new ArrayList<String>(); 376 for (SockJsSession session : sessions.values()) { 377 try { 378 if (session.getTimeSinceLastActive() > getDisconnectDelay()) { 379 sessions.remove(session.getId()); 380 removedIds.add(session.getId()); 381 session.close(); 382 } 383 } 384 catch (Throwable ex) { 385 // Could be part of normal workflow (e.g. browser tab closed) 386 logger.debug("Failed to close " + session, ex); 387 } 388 } 389 if (logger.isDebugEnabled() && !removedIds.isEmpty()) { 390 logger.debug("Closed " + removedIds.size() + " sessions: " + removedIds); 391 } 392 } 393 }, getDisconnectDelay()); 394 } 395 } 396 397}