/*
 * Copyright 2002-2018 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
016
017package org.springframework.web.socket.config.annotation;
018
019import java.util.ArrayList;
020import java.util.Arrays;
021import java.util.List;
022
023import org.springframework.scheduling.TaskScheduler;
024import org.springframework.util.Assert;
025import org.springframework.util.LinkedMultiValueMap;
026import org.springframework.util.MultiValueMap;
027import org.springframework.util.ObjectUtils;
028import org.springframework.util.StringUtils;
029import org.springframework.web.socket.WebSocketHandler;
030import org.springframework.web.socket.server.HandshakeHandler;
031import org.springframework.web.socket.server.HandshakeInterceptor;
032import org.springframework.web.socket.server.support.DefaultHandshakeHandler;
033import org.springframework.web.socket.server.support.OriginHandshakeInterceptor;
034import org.springframework.web.socket.sockjs.SockJsService;
035import org.springframework.web.socket.sockjs.transport.handler.WebSocketTransportHandler;
036
037/**
038 * Base class for {@link WebSocketHandlerRegistration}s that gathers all the configuration
039 * options but allows sub-classes to put together the actual HTTP request mappings.
040 *
041 * @author Rossen Stoyanchev
042 * @author Sebastien Deleuze
043 * @since 4.0
044 */
045public abstract class AbstractWebSocketHandlerRegistration<M> implements WebSocketHandlerRegistration {
046
047        private final TaskScheduler sockJsTaskScheduler;
048
049        private final MultiValueMap<WebSocketHandler, String> handlerMap =
050                        new LinkedMultiValueMap<WebSocketHandler, String>();
051
052        private HandshakeHandler handshakeHandler;
053
054        private final List<HandshakeInterceptor> interceptors = new ArrayList<HandshakeInterceptor>();
055
056        private final List<String> allowedOrigins = new ArrayList<String>();
057
058        private SockJsServiceRegistration sockJsServiceRegistration;
059
060
061        public AbstractWebSocketHandlerRegistration(TaskScheduler defaultTaskScheduler) {
062                this.sockJsTaskScheduler = defaultTaskScheduler;
063        }
064
065
066        @Override
067        public WebSocketHandlerRegistration addHandler(WebSocketHandler handler, String... paths) {
068                Assert.notNull(handler, "WebSocketHandler must not be null");
069                Assert.notEmpty(paths, "Paths must not be empty");
070                this.handlerMap.put(handler, Arrays.asList(paths));
071                return this;
072        }
073
074        @Override
075        public WebSocketHandlerRegistration setHandshakeHandler(HandshakeHandler handshakeHandler) {
076                this.handshakeHandler = handshakeHandler;
077                return this;
078        }
079
080        protected HandshakeHandler getHandshakeHandler() {
081                return this.handshakeHandler;
082        }
083
084        @Override
085        public WebSocketHandlerRegistration addInterceptors(HandshakeInterceptor... interceptors) {
086                if (!ObjectUtils.isEmpty(interceptors)) {
087                        this.interceptors.addAll(Arrays.asList(interceptors));
088                }
089                return this;
090        }
091
092        @Override
093        public WebSocketHandlerRegistration setAllowedOrigins(String... allowedOrigins) {
094                this.allowedOrigins.clear();
095                if (!ObjectUtils.isEmpty(allowedOrigins)) {
096                        this.allowedOrigins.addAll(Arrays.asList(allowedOrigins));
097                }
098                return this;
099        }
100
101        @Override
102        public SockJsServiceRegistration withSockJS() {
103                this.sockJsServiceRegistration = new SockJsServiceRegistration(this.sockJsTaskScheduler);
104                HandshakeInterceptor[] interceptors = getInterceptors();
105                if (interceptors.length > 0) {
106                        this.sockJsServiceRegistration.setInterceptors(interceptors);
107                }
108                if (this.handshakeHandler != null) {
109                        WebSocketTransportHandler transportHandler = new WebSocketTransportHandler(this.handshakeHandler);
110                        this.sockJsServiceRegistration.setTransportHandlerOverrides(transportHandler);
111                }
112                if (!this.allowedOrigins.isEmpty()) {
113                        this.sockJsServiceRegistration.setAllowedOrigins(StringUtils.toStringArray(this.allowedOrigins));
114                }
115                return this.sockJsServiceRegistration;
116        }
117
118        protected HandshakeInterceptor[] getInterceptors() {
119                List<HandshakeInterceptor> interceptors =
120                                new ArrayList<HandshakeInterceptor>(this.interceptors.size() + 1);
121                interceptors.addAll(this.interceptors);
122                interceptors.add(new OriginHandshakeInterceptor(this.allowedOrigins));
123                return interceptors.toArray(new HandshakeInterceptor[interceptors.size()]);
124        }
125
126        protected final M getMappings() {
127                M mappings = createMappings();
128                if (this.sockJsServiceRegistration != null) {
129                        SockJsService sockJsService = this.sockJsServiceRegistration.getSockJsService();
130                        for (WebSocketHandler wsHandler : this.handlerMap.keySet()) {
131                                for (String path : this.handlerMap.get(wsHandler)) {
132                                        String pathPattern = (path.endsWith("/") ? path + "**" : path + "/**");
133                                        addSockJsServiceMapping(mappings, sockJsService, wsHandler, pathPattern);
134                                }
135                        }
136                }
137                else {
138                        HandshakeHandler handshakeHandler = getOrCreateHandshakeHandler();
139                        HandshakeInterceptor[] interceptors = getInterceptors();
140                        for (WebSocketHandler wsHandler : this.handlerMap.keySet()) {
141                                for (String path : this.handlerMap.get(wsHandler)) {
142                                        addWebSocketHandlerMapping(mappings, wsHandler, handshakeHandler, interceptors, path);
143                                }
144                        }
145                }
146
147                return mappings;
148        }
149
150        private HandshakeHandler getOrCreateHandshakeHandler() {
151                return (this.handshakeHandler != null ? this.handshakeHandler : new DefaultHandshakeHandler());
152        }
153
154
155        protected abstract M createMappings();
156
157        protected abstract void addSockJsServiceMapping(M mappings, SockJsService sockJsService,
158                        WebSocketHandler handler, String pathPattern);
159
160        protected abstract void addWebSocketHandlerMapping(M mappings, WebSocketHandler wsHandler,
161                        HandshakeHandler handshakeHandler, HandshakeInterceptor[] interceptors, String path);
162
163}