001/*
002 * Copyright 2002-2020 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.web.socket.sockjs.transport;
018
019import java.io.IOException;
020import java.security.Principal;
021import java.util.ArrayList;
022import java.util.Arrays;
023import java.util.Collection;
024import java.util.Collections;
025import java.util.EnumMap;
026import java.util.HashMap;
027import java.util.List;
028import java.util.Map;
029import java.util.concurrent.ConcurrentHashMap;
030import java.util.concurrent.ScheduledFuture;
031
032import org.springframework.context.Lifecycle;
033import org.springframework.http.HttpMethod;
034import org.springframework.http.HttpStatus;
035import org.springframework.http.server.ServerHttpRequest;
036import org.springframework.http.server.ServerHttpResponse;
037import org.springframework.http.server.ServletServerHttpResponse;
038import org.springframework.lang.Nullable;
039import org.springframework.scheduling.TaskScheduler;
040import org.springframework.util.Assert;
041import org.springframework.util.ClassUtils;
042import org.springframework.util.CollectionUtils;
043import org.springframework.web.socket.WebSocketHandler;
044import org.springframework.web.socket.server.HandshakeFailureException;
045import org.springframework.web.socket.server.HandshakeHandler;
046import org.springframework.web.socket.server.HandshakeInterceptor;
047import org.springframework.web.socket.server.support.HandshakeInterceptorChain;
048import org.springframework.web.socket.sockjs.SockJsException;
049import org.springframework.web.socket.sockjs.frame.Jackson2SockJsMessageCodec;
050import org.springframework.web.socket.sockjs.frame.SockJsMessageCodec;
051import org.springframework.web.socket.sockjs.support.AbstractSockJsService;
052
053/**
054 * A basic implementation of {@link org.springframework.web.socket.sockjs.SockJsService}
055 * with support for SPI-based transport handling and session management.
056 *
057 * <p>Based on the {@link TransportHandler} SPI. {@code TransportHandlers} may
058 * additionally implement the {@link SockJsSessionFactory} and {@link HandshakeHandler} interfaces.
059 *
060 * <p>See the {@link AbstractSockJsService} base class for important details on request mapping.
061 *
062 * @author Rossen Stoyanchev
063 * @author Juergen Hoeller
064 * @author Sebastien Deleuze
065 * @since 4.0
066 */
067public class TransportHandlingSockJsService extends AbstractSockJsService implements SockJsServiceConfig, Lifecycle {
068
069        private static final boolean jackson2Present = ClassUtils.isPresent(
070                        "com.fasterxml.jackson.databind.ObjectMapper", TransportHandlingSockJsService.class.getClassLoader());
071
072
073        private final Map<TransportType, TransportHandler> handlers = new EnumMap<>(TransportType.class);
074
075        @Nullable
076        private SockJsMessageCodec messageCodec;
077
078        private final List<HandshakeInterceptor> interceptors = new ArrayList<>();
079
080        private final Map<String, SockJsSession> sessions = new ConcurrentHashMap<>();
081
082        @Nullable
083        private ScheduledFuture<?> sessionCleanupTask;
084
085        private volatile boolean running;
086
087
088        /**
089         * Create a TransportHandlingSockJsService with given {@link TransportHandler handler} types.
090         * @param scheduler a task scheduler for heart-beat messages and removing timed-out sessions;
091         * the provided TaskScheduler should be declared as a Spring bean to ensure it gets
092         * initialized at start-up and shuts down when the application stops
093         * @param handlers one or more {@link TransportHandler} implementations to use
094         */
095        public TransportHandlingSockJsService(TaskScheduler scheduler, TransportHandler... handlers) {
096                this(scheduler, Arrays.asList(handlers));
097        }
098
099        /**
100         * Create a TransportHandlingSockJsService with given {@link TransportHandler handler} types.
101         * @param scheduler a task scheduler for heart-beat messages and removing timed-out sessions;
102         * the provided TaskScheduler should be declared as a Spring bean to ensure it gets
103         * initialized at start-up and shuts down when the application stops
104         * @param handlers one or more {@link TransportHandler} implementations to use
105         */
106        public TransportHandlingSockJsService(TaskScheduler scheduler, Collection<TransportHandler> handlers) {
107                super(scheduler);
108
109                if (CollectionUtils.isEmpty(handlers)) {
110                        logger.warn("No transport handlers specified for TransportHandlingSockJsService");
111                }
112                else {
113                        for (TransportHandler handler : handlers) {
114                                handler.initialize(this);
115                                this.handlers.put(handler.getTransportType(), handler);
116                        }
117                }
118
119                if (jackson2Present) {
120                        this.messageCodec = new Jackson2SockJsMessageCodec();
121                }
122        }
123
124
125        /**
126         * Return the registered handlers per transport type.
127         */
128        public Map<TransportType, TransportHandler> getTransportHandlers() {
129                return Collections.unmodifiableMap(this.handlers);
130        }
131
132        /**
133         * The codec to use for encoding and decoding SockJS messages.
134         */
135        public void setMessageCodec(SockJsMessageCodec messageCodec) {
136                this.messageCodec = messageCodec;
137        }
138
139        @Override
140        public SockJsMessageCodec getMessageCodec() {
141                Assert.state(this.messageCodec != null, "A SockJsMessageCodec is required but not available: " +
142                                "Add Jackson to the classpath, or configure a custom SockJsMessageCodec.");
143                return this.messageCodec;
144        }
145
146        /**
147         * Configure one or more WebSocket handshake request interceptors.
148         */
149        public void setHandshakeInterceptors(@Nullable List<HandshakeInterceptor> interceptors) {
150                this.interceptors.clear();
151                if (interceptors != null) {
152                        this.interceptors.addAll(interceptors);
153                }
154        }
155
156        /**
157         * Return the configured WebSocket handshake request interceptors.
158         */
159        public List<HandshakeInterceptor> getHandshakeInterceptors() {
160                return this.interceptors;
161        }
162
163
164        @Override
165        public void start() {
166                if (!isRunning()) {
167                        this.running = true;
168                        for (TransportHandler handler : this.handlers.values()) {
169                                if (handler instanceof Lifecycle) {
170                                        ((Lifecycle) handler).start();
171                                }
172                        }
173                }
174        }
175
176        @Override
177        public void stop() {
178                if (isRunning()) {
179                        this.running = false;
180                        for (TransportHandler handler : this.handlers.values()) {
181                                if (handler instanceof Lifecycle) {
182                                        ((Lifecycle) handler).stop();
183                                }
184                        }
185                }
186        }
187
188        @Override
189        public boolean isRunning() {
190                return this.running;
191        }
192
193
194        @Override
195        protected void handleRawWebSocketRequest(ServerHttpRequest request, ServerHttpResponse response,
196                        WebSocketHandler handler) throws IOException {
197
198                TransportHandler transportHandler = this.handlers.get(TransportType.WEBSOCKET);
199                if (!(transportHandler instanceof HandshakeHandler)) {
200                        logger.error("No handler configured for raw WebSocket messages");
201                        response.setStatusCode(HttpStatus.NOT_FOUND);
202                        return;
203                }
204
205                HandshakeInterceptorChain chain = new HandshakeInterceptorChain(this.interceptors, handler);
206                HandshakeFailureException failure = null;
207
208                try {
209                        Map<String, Object> attributes = new HashMap<>();
210                        if (!chain.applyBeforeHandshake(request, response, attributes)) {
211                                return;
212                        }
213                        ((HandshakeHandler) transportHandler).doHandshake(request, response, handler, attributes);
214                        chain.applyAfterHandshake(request, response, null);
215                }
216                catch (HandshakeFailureException ex) {
217                        failure = ex;
218                }
219                catch (Exception ex) {
220                        failure = new HandshakeFailureException("Uncaught failure for request " + request.getURI(), ex);
221                }
222                finally {
223                        if (failure != null) {
224                                chain.applyAfterHandshake(request, response, failure);
225                                throw failure;
226                        }
227                }
228        }
229
230        @Override
231        protected void handleTransportRequest(ServerHttpRequest request, ServerHttpResponse response,
232                        WebSocketHandler handler, String sessionId, String transport) throws SockJsException {
233
234                TransportType transportType = TransportType.fromValue(transport);
235                if (transportType == null) {
236                        if (logger.isWarnEnabled()) {
237                                logger.warn("Unknown transport type for " + request.getURI());
238                        }
239                        response.setStatusCode(HttpStatus.NOT_FOUND);
240                        return;
241                }
242
243                TransportHandler transportHandler = this.handlers.get(transportType);
244                if (transportHandler == null) {
245                        if (logger.isWarnEnabled()) {
246                                logger.warn("No TransportHandler for " + request.getURI());
247                        }
248                        response.setStatusCode(HttpStatus.NOT_FOUND);
249                        return;
250                }
251
252                SockJsException failure = null;
253                HandshakeInterceptorChain chain = new HandshakeInterceptorChain(this.interceptors, handler);
254
255                try {
256                        HttpMethod supportedMethod = transportType.getHttpMethod();
257                        if (supportedMethod != request.getMethod()) {
258                                if (request.getMethod() == HttpMethod.OPTIONS && transportType.supportsCors()) {
259                                        if (checkOrigin(request, response, HttpMethod.OPTIONS, supportedMethod)) {
260                                                response.setStatusCode(HttpStatus.NO_CONTENT);
261                                                addCacheHeaders(response);
262                                        }
263                                }
264                                else if (transportType.supportsCors()) {
265                                        sendMethodNotAllowed(response, supportedMethod, HttpMethod.OPTIONS);
266                                }
267                                else {
268                                        sendMethodNotAllowed(response, supportedMethod);
269                                }
270                                return;
271                        }
272
273                        SockJsSession session = this.sessions.get(sessionId);
274                        boolean isNewSession = false;
275                        if (session == null) {
276                                if (transportHandler instanceof SockJsSessionFactory) {
277                                        Map<String, Object> attributes = new HashMap<>();
278                                        if (!chain.applyBeforeHandshake(request, response, attributes)) {
279                                                return;
280                                        }
281                                        SockJsSessionFactory sessionFactory = (SockJsSessionFactory) transportHandler;
282                                        session = createSockJsSession(sessionId, sessionFactory, handler, attributes);
283                                        isNewSession = true;
284                                }
285                                else {
286                                        response.setStatusCode(HttpStatus.NOT_FOUND);
287                                        if (logger.isDebugEnabled()) {
288                                                logger.debug("Session not found, sessionId=" + sessionId +
289                                                                ". The session may have been closed " +
290                                                                "(e.g. missed heart-beat) while a message was coming in.");
291                                        }
292                                        return;
293                                }
294                        }
295                        else {
296                                Principal principal = session.getPrincipal();
297                                if (principal != null && !principal.equals(request.getPrincipal())) {
298                                        logger.debug("The user for the session does not match the user for the request.");
299                                        response.setStatusCode(HttpStatus.NOT_FOUND);
300                                        return;
301                                }
302                                if (!transportHandler.checkSessionType(session)) {
303                                        logger.debug("Session type does not match the transport type for the request.");
304                                        response.setStatusCode(HttpStatus.NOT_FOUND);
305                                        return;
306                                }
307                        }
308
309                        if (transportType.sendsNoCacheInstruction()) {
310                                addNoCacheHeaders(response);
311                        }
312                        if (transportType.supportsCors() && !checkOrigin(request, response)) {
313                                return;
314                        }
315
316                        transportHandler.handleRequest(request, response, handler, session);
317
318                        if (isNewSession && (response instanceof ServletServerHttpResponse)) {
319                                int status = ((ServletServerHttpResponse) response).getServletResponse().getStatus();
320                                if (HttpStatus.valueOf(status).is4xxClientError()) {
321                                        this.sessions.remove(sessionId);
322                                }
323                        }
324
325                        chain.applyAfterHandshake(request, response, null);
326                }
327                catch (SockJsException ex) {
328                        failure = ex;
329                }
330                catch (Exception ex) {
331                        failure = new SockJsException("Uncaught failure for request " + request.getURI(), sessionId, ex);
332                }
333                finally {
334                        if (failure != null) {
335                                chain.applyAfterHandshake(request, response, failure);
336                                throw failure;
337                        }
338                }
339        }
340
341        @Override
342        protected boolean validateRequest(String serverId, String sessionId, String transport) {
343                if (!super.validateRequest(serverId, sessionId, transport)) {
344                        return false;
345                }
346
347                if (!this.allowedOrigins.contains("*")) {
348                        TransportType transportType = TransportType.fromValue(transport);
349                        if (transportType == null || !transportType.supportsOrigin()) {
350                                if (logger.isWarnEnabled()) {
351                                        logger.warn("Origin check enabled but transport '" + transport + "' does not support it.");
352                                }
353                                return false;
354                        }
355                }
356
357                return true;
358        }
359
360        private SockJsSession createSockJsSession(String sessionId, SockJsSessionFactory sessionFactory,
361                        WebSocketHandler handler, Map<String, Object> attributes) {
362
363                SockJsSession session = this.sessions.get(sessionId);
364                if (session != null) {
365                        return session;
366                }
367                if (this.sessionCleanupTask == null) {
368                        scheduleSessionTask();
369                }
370                session = sessionFactory.createSession(sessionId, handler, attributes);
371                this.sessions.put(sessionId, session);
372                return session;
373        }
374
375        private void scheduleSessionTask() {
376                synchronized (this.sessions) {
377                        if (this.sessionCleanupTask != null) {
378                                return;
379                        }
380                        this.sessionCleanupTask = getTaskScheduler().scheduleAtFixedRate(() -> {
381                                List<String> removedIds = new ArrayList<>();
382                                for (SockJsSession session : this.sessions.values()) {
383                                        try {
384                                                if (session.getTimeSinceLastActive() > getDisconnectDelay()) {
385                                                        this.sessions.remove(session.getId());
386                                                        removedIds.add(session.getId());
387                                                        session.close();
388                                                }
389                                        }
390                                        catch (Throwable ex) {
391                                                // Could be part of normal workflow (e.g. browser tab closed)
392                                                logger.debug("Failed to close " + session, ex);
393                                        }
394                                }
395                                if (logger.isDebugEnabled() && !removedIds.isEmpty()) {
396                                        logger.debug("Closed " + removedIds.size() + " sessions: " + removedIds);
397                                }
398                        }, getDisconnectDelay());
399                }
400        }
401
402}