001/*
002 * Copyright 2002-2018 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.web.socket.config.annotation;
018
019import java.util.ArrayList;
020import java.util.Arrays;
021import java.util.List;
022
023import org.springframework.lang.Nullable;
024import org.springframework.util.Assert;
025import org.springframework.util.LinkedMultiValueMap;
026import org.springframework.util.MultiValueMap;
027import org.springframework.util.ObjectUtils;
028import org.springframework.util.StringUtils;
029import org.springframework.web.socket.WebSocketHandler;
030import org.springframework.web.socket.server.HandshakeHandler;
031import org.springframework.web.socket.server.HandshakeInterceptor;
032import org.springframework.web.socket.server.support.DefaultHandshakeHandler;
033import org.springframework.web.socket.server.support.OriginHandshakeInterceptor;
034import org.springframework.web.socket.sockjs.SockJsService;
035import org.springframework.web.socket.sockjs.transport.handler.WebSocketTransportHandler;
036
037/**
038 * Base class for {@link WebSocketHandlerRegistration WebSocketHandlerRegistrations} that gathers all the configuration
039 * options but allows sub-classes to put together the actual HTTP request mappings.
040 *
041 * @author Rossen Stoyanchev
042 * @author Sebastien Deleuze
043 * @since 4.0
044 * @param <M> the mappings type
045 */
046public abstract class AbstractWebSocketHandlerRegistration<M> implements WebSocketHandlerRegistration {
047
048        private final MultiValueMap<WebSocketHandler, String> handlerMap = new LinkedMultiValueMap<>();
049
050        @Nullable
051        private HandshakeHandler handshakeHandler;
052
053        private final List<HandshakeInterceptor> interceptors = new ArrayList<>();
054
055        private final List<String> allowedOrigins = new ArrayList<>();
056
057        @Nullable
058        private SockJsServiceRegistration sockJsServiceRegistration;
059
060
061        @Override
062        public WebSocketHandlerRegistration addHandler(WebSocketHandler handler, String... paths) {
063                Assert.notNull(handler, "WebSocketHandler must not be null");
064                Assert.notEmpty(paths, "Paths must not be empty");
065                this.handlerMap.put(handler, Arrays.asList(paths));
066                return this;
067        }
068
069        @Override
070        public WebSocketHandlerRegistration setHandshakeHandler(@Nullable HandshakeHandler handshakeHandler) {
071                this.handshakeHandler = handshakeHandler;
072                return this;
073        }
074
075        @Nullable
076        protected HandshakeHandler getHandshakeHandler() {
077                return this.handshakeHandler;
078        }
079
080        @Override
081        public WebSocketHandlerRegistration addInterceptors(HandshakeInterceptor... interceptors) {
082                if (!ObjectUtils.isEmpty(interceptors)) {
083                        this.interceptors.addAll(Arrays.asList(interceptors));
084                }
085                return this;
086        }
087
088        @Override
089        public WebSocketHandlerRegistration setAllowedOrigins(String... allowedOrigins) {
090                this.allowedOrigins.clear();
091                if (!ObjectUtils.isEmpty(allowedOrigins)) {
092                        this.allowedOrigins.addAll(Arrays.asList(allowedOrigins));
093                }
094                return this;
095        }
096
097        @Override
098        public SockJsServiceRegistration withSockJS() {
099                this.sockJsServiceRegistration = new SockJsServiceRegistration();
100                HandshakeInterceptor[] interceptors = getInterceptors();
101                if (interceptors.length > 0) {
102                        this.sockJsServiceRegistration.setInterceptors(interceptors);
103                }
104                if (this.handshakeHandler != null) {
105                        WebSocketTransportHandler transportHandler = new WebSocketTransportHandler(this.handshakeHandler);
106                        this.sockJsServiceRegistration.setTransportHandlerOverrides(transportHandler);
107                }
108                if (!this.allowedOrigins.isEmpty()) {
109                        this.sockJsServiceRegistration.setAllowedOrigins(StringUtils.toStringArray(this.allowedOrigins));
110                }
111                return this.sockJsServiceRegistration;
112        }
113
114        protected HandshakeInterceptor[] getInterceptors() {
115                List<HandshakeInterceptor> interceptors = new ArrayList<>(this.interceptors.size() + 1);
116                interceptors.addAll(this.interceptors);
117                interceptors.add(new OriginHandshakeInterceptor(this.allowedOrigins));
118                return interceptors.toArray(new HandshakeInterceptor[0]);
119        }
120
121        /**
122         * Expose the {@code SockJsServiceRegistration} -- if SockJS is enabled or
123         * {@code null} otherwise -- so that it can be configured with a TaskScheduler
124         * if the application did not provide one. This should be done prior to
125         * calling {@link #getMappings()}.
126         */
127        @Nullable
128        protected SockJsServiceRegistration getSockJsServiceRegistration() {
129                return this.sockJsServiceRegistration;
130        }
131
132        protected final M getMappings() {
133                M mappings = createMappings();
134                if (this.sockJsServiceRegistration != null) {
135                        SockJsService sockJsService = this.sockJsServiceRegistration.getSockJsService();
136                        this.handlerMap.forEach((wsHandler, paths) -> {
137                                for (String path : paths) {
138                                        String pathPattern = (path.endsWith("/") ? path + "**" : path + "/**");
139                                        addSockJsServiceMapping(mappings, sockJsService, wsHandler, pathPattern);
140                                }
141                        });
142                }
143                else {
144                        HandshakeHandler handshakeHandler = getOrCreateHandshakeHandler();
145                        HandshakeInterceptor[] interceptors = getInterceptors();
146                        this.handlerMap.forEach((wsHandler, paths) -> {
147                                for (String path : paths) {
148                                        addWebSocketHandlerMapping(mappings, wsHandler, handshakeHandler, interceptors, path);
149                                }
150                        });
151                }
152
153                return mappings;
154        }
155
156        private HandshakeHandler getOrCreateHandshakeHandler() {
157                return (this.handshakeHandler != null ? this.handshakeHandler : new DefaultHandshakeHandler());
158        }
159
160
161        protected abstract M createMappings();
162
163        protected abstract void addSockJsServiceMapping(M mappings, SockJsService sockJsService,
164                        WebSocketHandler handler, String pathPattern);
165
166        protected abstract void addWebSocketHandlerMapping(M mappings, WebSocketHandler wsHandler,
167                        HandshakeHandler handshakeHandler, HandshakeInterceptor[] interceptors, String path);
168
169}