001/*
002 * Copyright 2002-2018 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.web.socket.sockjs.transport;
018
019import java.io.IOException;
020import java.util.ArrayList;
021import java.util.Arrays;
022import java.util.Collection;
023import java.util.Collections;
024import java.util.EnumMap;
025import java.util.HashMap;
026import java.util.List;
027import java.util.Map;
028import java.util.concurrent.ConcurrentHashMap;
029import java.util.concurrent.ScheduledFuture;
030
031import org.springframework.context.Lifecycle;
032import org.springframework.http.HttpMethod;
033import org.springframework.http.HttpStatus;
034import org.springframework.http.server.ServerHttpRequest;
035import org.springframework.http.server.ServerHttpResponse;
036import org.springframework.scheduling.TaskScheduler;
037import org.springframework.util.Assert;
038import org.springframework.util.ClassUtils;
039import org.springframework.util.CollectionUtils;
040import org.springframework.web.socket.WebSocketHandler;
041import org.springframework.web.socket.server.HandshakeFailureException;
042import org.springframework.web.socket.server.HandshakeHandler;
043import org.springframework.web.socket.server.HandshakeInterceptor;
044import org.springframework.web.socket.server.support.HandshakeInterceptorChain;
045import org.springframework.web.socket.sockjs.SockJsException;
046import org.springframework.web.socket.sockjs.frame.Jackson2SockJsMessageCodec;
047import org.springframework.web.socket.sockjs.frame.SockJsMessageCodec;
048import org.springframework.web.socket.sockjs.support.AbstractSockJsService;
049
050/**
051 * A basic implementation of {@link org.springframework.web.socket.sockjs.SockJsService}
052 * with support for SPI-based transport handling and session management.
053 *
054 * <p>Based on the {@link TransportHandler} SPI. {@link TransportHandler}s may additionally
055 * implement the {@link SockJsSessionFactory} and {@link HandshakeHandler} interfaces.
056 *
057 * <p>See the {@link AbstractSockJsService} base class for important details on request mapping.
058 *
059 * @author Rossen Stoyanchev
060 * @author Juergen Hoeller
061 * @author Sebastien Deleuze
062 * @since 4.0
063 */
064public class TransportHandlingSockJsService extends AbstractSockJsService implements SockJsServiceConfig, Lifecycle {
065
066        private static final boolean jackson2Present = ClassUtils.isPresent(
067                        "com.fasterxml.jackson.databind.ObjectMapper", TransportHandlingSockJsService.class.getClassLoader());
068
069
070        private final Map<TransportType, TransportHandler> handlers =
071                        new EnumMap<TransportType, TransportHandler>(TransportType.class);
072
073        private SockJsMessageCodec messageCodec;
074
075        private final List<HandshakeInterceptor> interceptors = new ArrayList<HandshakeInterceptor>();
076
077        private final Map<String, SockJsSession> sessions = new ConcurrentHashMap<String, SockJsSession>();
078
079        private ScheduledFuture<?> sessionCleanupTask;
080
081        private volatile boolean running;
082
083
084        /**
085         * Create a TransportHandlingSockJsService with given {@link TransportHandler handler} types.
086         * @param scheduler a task scheduler for heart-beat messages and removing timed-out sessions;
087         * the provided TaskScheduler should be declared as a Spring bean to ensure it gets
088         * initialized at start-up and shuts down when the application stops
089         * @param handlers one or more {@link TransportHandler} implementations to use
090         */
091        public TransportHandlingSockJsService(TaskScheduler scheduler, TransportHandler... handlers) {
092                this(scheduler, Arrays.asList(handlers));
093        }
094
095        /**
096         * Create a TransportHandlingSockJsService with given {@link TransportHandler handler} types.
097         * @param scheduler a task scheduler for heart-beat messages and removing timed-out sessions;
098         * the provided TaskScheduler should be declared as a Spring bean to ensure it gets
099         * initialized at start-up and shuts down when the application stops
100         * @param handlers one or more {@link TransportHandler} implementations to use
101         */
102        public TransportHandlingSockJsService(TaskScheduler scheduler, Collection<TransportHandler> handlers) {
103                super(scheduler);
104
105                if (CollectionUtils.isEmpty(handlers)) {
106                        logger.warn("No transport handlers specified for TransportHandlingSockJsService");
107                }
108                else {
109                        for (TransportHandler handler : handlers) {
110                                handler.initialize(this);
111                                this.handlers.put(handler.getTransportType(), handler);
112                        }
113                }
114
115                if (jackson2Present) {
116                        this.messageCodec = new Jackson2SockJsMessageCodec();
117                }
118        }
119
120
121        /**
122         * Return the registered handlers per transport type.
123         */
124        public Map<TransportType, TransportHandler> getTransportHandlers() {
125                return Collections.unmodifiableMap(this.handlers);
126        }
127
128        /**
129         * The codec to use for encoding and decoding SockJS messages.
130         */
131        public void setMessageCodec(SockJsMessageCodec messageCodec) {
132                this.messageCodec = messageCodec;
133        }
134
135        public SockJsMessageCodec getMessageCodec() {
136                Assert.state(this.messageCodec != null, "A SockJsMessageCodec is required but not available: " +
137                                "Add Jackson to the classpath, or configure a custom SockJsMessageCodec.");
138                return this.messageCodec;
139        }
140
141        /**
142         * Configure one or more WebSocket handshake request interceptors.
143         */
144        public void setHandshakeInterceptors(List<HandshakeInterceptor> interceptors) {
145                this.interceptors.clear();
146                if (interceptors != null) {
147                        this.interceptors.addAll(interceptors);
148                }
149        }
150
151        /**
152         * Return the configured WebSocket handshake request interceptors.
153         */
154        public List<HandshakeInterceptor> getHandshakeInterceptors() {
155                return this.interceptors;
156        }
157
158
159        @Override
160        public void start() {
161                if (!isRunning()) {
162                        this.running = true;
163                        for (TransportHandler handler : this.handlers.values()) {
164                                if (handler instanceof Lifecycle) {
165                                        ((Lifecycle) handler).start();
166                                }
167                        }
168                }
169        }
170
171        @Override
172        public void stop() {
173                if (isRunning()) {
174                        this.running = false;
175                        for (TransportHandler handler : this.handlers.values()) {
176                                if (handler instanceof Lifecycle) {
177                                        ((Lifecycle) handler).stop();
178                                }
179                        }
180                }
181        }
182
183        @Override
184        public boolean isRunning() {
185                return this.running;
186        }
187
188
189        @Override
190        protected void handleRawWebSocketRequest(ServerHttpRequest request, ServerHttpResponse response,
191                        WebSocketHandler handler) throws IOException {
192
193                TransportHandler transportHandler = this.handlers.get(TransportType.WEBSOCKET);
194                if (!(transportHandler instanceof HandshakeHandler)) {
195                        logger.error("No handler configured for raw WebSocket messages");
196                        response.setStatusCode(HttpStatus.NOT_FOUND);
197                        return;
198                }
199
200                HandshakeInterceptorChain chain = new HandshakeInterceptorChain(this.interceptors, handler);
201                HandshakeFailureException failure = null;
202
203                try {
204                        Map<String, Object> attributes = new HashMap<String, Object>();
205                        if (!chain.applyBeforeHandshake(request, response, attributes)) {
206                                return;
207                        }
208                        ((HandshakeHandler) transportHandler).doHandshake(request, response, handler, attributes);
209                        chain.applyAfterHandshake(request, response, null);
210                }
211                catch (HandshakeFailureException ex) {
212                        failure = ex;
213                }
214                catch (Throwable ex) {
215                        failure = new HandshakeFailureException("Uncaught failure for request " + request.getURI(), ex);
216                }
217                        finally {
218                        if (failure != null) {
219                                chain.applyAfterHandshake(request, response, failure);
220                                throw failure;
221                        }
222                }
223        }
224
225        @Override
226        protected void handleTransportRequest(ServerHttpRequest request, ServerHttpResponse response,
227                        WebSocketHandler handler, String sessionId, String transport) throws SockJsException {
228
229                TransportType transportType = TransportType.fromValue(transport);
230                if (transportType == null) {
231                        if (logger.isWarnEnabled()) {
232                                logger.warn("Unknown transport type for " + request.getURI());
233                        }
234                        response.setStatusCode(HttpStatus.NOT_FOUND);
235                        return;
236                }
237
238                TransportHandler transportHandler = this.handlers.get(transportType);
239                if (transportHandler == null) {
240                        if (logger.isWarnEnabled()) {
241                                logger.warn("No TransportHandler for " + request.getURI());
242                        }
243                        response.setStatusCode(HttpStatus.NOT_FOUND);
244                        return;
245                }
246
247                SockJsException failure = null;
248                HandshakeInterceptorChain chain = new HandshakeInterceptorChain(this.interceptors, handler);
249
250                try {
251                        HttpMethod supportedMethod = transportType.getHttpMethod();
252                        if (supportedMethod != request.getMethod()) {
253                                if (request.getMethod() == HttpMethod.OPTIONS && transportType.supportsCors()) {
254                                        if (checkOrigin(request, response, HttpMethod.OPTIONS, supportedMethod)) {
255                                                response.setStatusCode(HttpStatus.NO_CONTENT);
256                                                addCacheHeaders(response);
257                                        }
258                                }
259                                else if (transportType.supportsCors()) {
260                                        sendMethodNotAllowed(response, supportedMethod, HttpMethod.OPTIONS);
261                                }
262                                else {
263                                        sendMethodNotAllowed(response, supportedMethod);
264                                }
265                                return;
266                        }
267
268                        SockJsSession session = this.sessions.get(sessionId);
269                        if (session == null) {
270                                if (transportHandler instanceof SockJsSessionFactory) {
271                                        Map<String, Object> attributes = new HashMap<String, Object>();
272                                        if (!chain.applyBeforeHandshake(request, response, attributes)) {
273                                                return;
274                                        }
275                                        SockJsSessionFactory sessionFactory = (SockJsSessionFactory) transportHandler;
276                                        session = createSockJsSession(sessionId, sessionFactory, handler, attributes);
277                                }
278                                else {
279                                        response.setStatusCode(HttpStatus.NOT_FOUND);
280                                        if (logger.isDebugEnabled()) {
281                                                logger.debug("Session not found, sessionId=" + sessionId +
282                                                                ". The session may have been closed " +
283                                                                "(e.g. missed heart-beat) while a message was coming in.");
284                                        }
285                                        return;
286                                }
287                        }
288                        else {
289                                if (session.getPrincipal() != null) {
290                                        if (!session.getPrincipal().equals(request.getPrincipal())) {
291                                                logger.debug("The user for the session does not match the user for the request.");
292                                                response.setStatusCode(HttpStatus.NOT_FOUND);
293                                                return;
294                                        }
295                                }
296                                if (!transportHandler.checkSessionType(session)) {
297                                        logger.debug("Session type does not match the transport type for the request.");
298                                        response.setStatusCode(HttpStatus.NOT_FOUND);
299                                        return;
300                                }
301                        }
302
303                        if (transportType.sendsNoCacheInstruction()) {
304                                addNoCacheHeaders(response);
305                        }
306
307                        if (transportType.supportsCors()) {
308                                if (!checkOrigin(request, response)) {
309                                        return;
310                                }
311                        }
312
313
314                        transportHandler.handleRequest(request, response, handler, session);
315
316
317                        chain.applyAfterHandshake(request, response, null);
318                }
319                catch (SockJsException ex) {
320                        failure = ex;
321                }
322                catch (Throwable ex) {
323                        failure = new SockJsException("Uncaught failure for request " + request.getURI(), sessionId, ex);
324                }
325                finally {
326                        if (failure != null) {
327                                chain.applyAfterHandshake(request, response, failure);
328                                throw failure;
329                        }
330                }
331        }
332
333        @Override
334        protected boolean validateRequest(String serverId, String sessionId, String transport) {
335                if (!super.validateRequest(serverId, sessionId, transport)) {
336                        return false;
337                }
338
339                if (!this.allowedOrigins.contains("*")) {
340                        TransportType transportType = TransportType.fromValue(transport);
341                        if (transportType == null || !transportType.supportsOrigin()) {
342                                if (logger.isWarnEnabled()) {
343                                        logger.warn("Origin check enabled but transport '" + transport + "' does not support it.");
344                                }
345                                return false;
346                        }
347                }
348
349                return true;
350        }
351
352        private SockJsSession createSockJsSession(String sessionId, SockJsSessionFactory sessionFactory,
353                        WebSocketHandler handler, Map<String, Object> attributes) {
354
355                SockJsSession session = this.sessions.get(sessionId);
356                if (session != null) {
357                        return session;
358                }
359                if (this.sessionCleanupTask == null) {
360                        scheduleSessionTask();
361                }
362                session = sessionFactory.createSession(sessionId, handler, attributes);
363                this.sessions.put(sessionId, session);
364                return session;
365        }
366
367        private void scheduleSessionTask() {
368                synchronized (this.sessions) {
369                        if (this.sessionCleanupTask != null) {
370                                return;
371                        }
372                        this.sessionCleanupTask = getTaskScheduler().scheduleAtFixedRate(new Runnable() {
373                                @Override
374                                public void run() {
375                                        List<String> removedIds = new ArrayList<String>();
376                                        for (SockJsSession session : sessions.values()) {
377                                                try {
378                                                        if (session.getTimeSinceLastActive() > getDisconnectDelay()) {
379                                                                sessions.remove(session.getId());
380                                                                removedIds.add(session.getId());
381                                                                session.close();
382                                                        }
383                                                }
384                                                catch (Throwable ex) {
385                                                        // Could be part of normal workflow (e.g. browser tab closed)
386                                                        logger.debug("Failed to close " + session, ex);
387                                                }
388                                        }
389                                        if (logger.isDebugEnabled() && !removedIds.isEmpty()) {
390                                                logger.debug("Closed " + removedIds.size() + " sessions: " + removedIds);
391                                        }
392                                }
393                        }, getDisconnectDelay());
394                }
395        }
396
397}