/*
 * Copyright 2002-2018 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.web.socket.config.annotation;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.springframework.scheduling.TaskScheduler;
import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.ObjectUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;
import org.springframework.web.socket.server.support.DefaultHandshakeHandler;
import org.springframework.web.socket.server.support.OriginHandshakeInterceptor;
import org.springframework.web.socket.sockjs.SockJsService;
import org.springframework.web.socket.sockjs.transport.handler.WebSocketTransportHandler;

/**
 * Base class for {@link WebSocketHandlerRegistration}s that gathers all the configuration
 * options but allows sub-classes to put together the actual HTTP request mappings.
 *
 * @author Rossen Stoyanchev
 * @author Sebastien Deleuze
 * @since 4.0
 * @param <M> the type of the mappings object produced by {@link #createMappings()}
 * and returned from {@link #getMappings()}
 */
public abstract class AbstractWebSocketHandlerRegistration<M> implements WebSocketHandlerRegistration {

	/** Scheduler handed to every {@link SockJsServiceRegistration} created by {@link #withSockJS()}. */
	private final TaskScheduler sockJsTaskScheduler;

	/** Registered handlers, each associated with the paths it was last registered under. */
	private final MultiValueMap<WebSocketHandler, String> handlerMap =
			new LinkedMultiValueMap<WebSocketHandler, String>();

	/** Explicitly configured handshake handler, or {@code null} if none was set. */
	private HandshakeHandler handshakeHandler;

	/** Interceptors added through {@link #addInterceptors}, in registration order. */
	private final List<HandshakeInterceptor> interceptors = new ArrayList<HandshakeInterceptor>();

	/** Origins set through {@link #setAllowedOrigins}; empty until configured. */
	private final List<String> allowedOrigins = new ArrayList<String>();

	/** Non-null once {@link #withSockJS()} has been called. */
	private SockJsServiceRegistration sockJsServiceRegistration;


	/**
	 * Create a registration.
	 * @param defaultTaskScheduler the scheduler passed to the SockJS service
	 * registration when {@link #withSockJS()} is invoked
	 */
	public AbstractWebSocketHandlerRegistration(TaskScheduler defaultTaskScheduler) {
		this.sockJsTaskScheduler = defaultTaskScheduler;
	}


	/**
	 * Register a handler under the given paths.
	 * <p>The paths are stored with {@code Map.put} semantics: registering the
	 * same handler again replaces the paths recorded for it earlier.
	 * @param handler the handler to register (must not be {@code null})
	 * @param paths the paths to map the handler to (must not be empty)
	 * @return this registration
	 */
	@Override
	public WebSocketHandlerRegistration addHandler(WebSocketHandler handler, String... paths) {
		Assert.notNull(handler, "WebSocketHandler must not be null");
		Assert.notEmpty(paths, "Paths must not be empty");
		List<String> pathList = Arrays.asList(paths);
		this.handlerMap.put(handler, pathList);
		return this;
	}

	/**
	 * Set the handshake handler to use; no validation is performed on the argument.
	 * @param handshakeHandler the handshake handler
	 * @return this registration
	 */
	@Override
	public WebSocketHandlerRegistration setHandshakeHandler(HandshakeHandler handshakeHandler) {
		this.handshakeHandler = handshakeHandler;
		return this;
	}

	/**
	 * Return the handshake handler set via {@link #setHandshakeHandler},
	 * or {@code null} if none has been set.
	 */
	protected HandshakeHandler getHandshakeHandler() {
		return this.handshakeHandler;
	}

	/**
	 * Append the given interceptors to those already registered.
	 * A {@code null} or empty array is ignored.
	 * @param interceptors the interceptors to add
	 * @return this registration
	 */
	@Override
	public WebSocketHandlerRegistration addInterceptors(HandshakeInterceptor... interceptors) {
		if (ObjectUtils.isEmpty(interceptors)) {
			return this;
		}
		this.interceptors.addAll(Arrays.asList(interceptors));
		return this;
	}

	/**
	 * Replace the configured allowed origins with the given values.
	 * Previously configured origins are always discarded, so a {@code null}
	 * or empty array leaves the list empty.
	 * @param allowedOrigins the new allowed origins
	 * @return this registration
	 */
	@Override
	public WebSocketHandlerRegistration setAllowedOrigins(String... allowedOrigins) {
		this.allowedOrigins.clear();
		if (ObjectUtils.isEmpty(allowedOrigins)) {
			return this;
		}
		this.allowedOrigins.addAll(Arrays.asList(allowedOrigins));
		return this;
	}

	/**
	 * Enable SockJS for this registration.
	 * <p>Each call creates a new {@link SockJsServiceRegistration}, replacing any
	 * earlier one, and copies into it the interceptors, handshake handler and
	 * allowed origins configured on this registration so far. Nothing in this
	 * class re-applies those settings afterwards, so configure them before
	 * calling this method.
	 * @return the SockJS registration for further customization
	 */
	@Override
	public SockJsServiceRegistration withSockJS() {
		SockJsServiceRegistration registration = new SockJsServiceRegistration(this.sockJsTaskScheduler);
		// Publish to the field before configuring, as callers of getMappings() key off it.
		this.sockJsServiceRegistration = registration;

		HandshakeInterceptor[] handshakeInterceptors = getInterceptors();
		if (handshakeInterceptors.length > 0) {
			registration.setInterceptors(handshakeInterceptors);
		}

		HandshakeHandler configuredHandler = this.handshakeHandler;
		if (configuredHandler != null) {
			// Wrap the custom handshake handler in a WebSocket transport handler override.
			registration.setTransportHandlerOverrides(new WebSocketTransportHandler(configuredHandler));
		}

		if (!this.allowedOrigins.isEmpty()) {
			String[] origins = StringUtils.toStringArray(this.allowedOrigins);
			registration.setAllowedOrigins(origins);
		}
		return registration;
	}

	/**
	 * Return a new array holding the registered interceptors, in registration
	 * order, followed by an {@link OriginHandshakeInterceptor} created from the
	 * currently configured allowed origins.
	 */
	protected HandshakeInterceptor[] getInterceptors() {
		int registered = this.interceptors.size();
		// toArray fills the first 'registered' slots; the extra trailing slot is then
		// assigned the origin interceptor.
		HandshakeInterceptor[] result = this.interceptors.toArray(new HandshakeInterceptor[registered + 1]);
		result[registered] = new OriginHandshakeInterceptor(this.allowedOrigins);
		return result;
	}

	/**
	 * Build the mappings for every registered handler and path.
	 * <p>If {@link #withSockJS()} has been called, each path is turned into a
	 * {@code "/**"} pattern and mapped to the SockJS service; otherwise each path
	 * is mapped directly to its handler, together with the handshake handler and
	 * interceptors.
	 * @return the mappings object obtained from {@link #createMappings()}
	 */
	protected final M getMappings() {
		M mappings = createMappings();
		if (this.sockJsServiceRegistration == null) {
			mapWebSocketHandlers(mappings);
		}
		else {
			mapSockJsHandlers(mappings);
		}
		return mappings;
	}

	/**
	 * Map every registered path, widened to a {@code "/**"} pattern, to the SockJS service.
	 */
	private void mapSockJsHandlers(M mappings) {
		SockJsService service = this.sockJsServiceRegistration.getSockJsService();
		for (WebSocketHandler handler : this.handlerMap.keySet()) {
			for (String path : this.handlerMap.get(handler)) {
				// Avoid a double slash when the path already ends with one.
				String suffix = (path.endsWith("/") ? "**" : "/**");
				addSockJsServiceMapping(mappings, service, handler, path + suffix);
			}
		}
	}

	/**
	 * Map every registered path as-is to its handler for plain WebSocket handshakes.
	 */
	private void mapWebSocketHandlers(M mappings) {
		HandshakeHandler effectiveHandler = getOrCreateHandshakeHandler();
		HandshakeInterceptor[] handshakeInterceptors = getInterceptors();
		for (WebSocketHandler handler : this.handlerMap.keySet()) {
			for (String path : this.handlerMap.get(handler)) {
				addWebSocketHandlerMapping(mappings, handler, effectiveHandler, handshakeInterceptors, path);
			}
		}
	}

	/**
	 * Return the configured handshake handler, or a new
	 * {@link DefaultHandshakeHandler} if none has been set.
	 */
	private HandshakeHandler getOrCreateHandshakeHandler() {
		if (this.handshakeHandler != null) {
			return this.handshakeHandler;
		}
		return new DefaultHandshakeHandler();
	}


	/**
	 * Create the empty mappings object that {@link #getMappings()} populates.
	 */
	protected abstract M createMappings();

	/**
	 * Add a mapping from the given path pattern to the SockJS service.
	 * @param mappings the mappings object being populated
	 * @param sockJsService the SockJS service obtained from the SockJS registration
	 * @param handler the handler registered for the path
	 * @param pathPattern the registered path with a {@code "/**"} suffix
	 */
	protected abstract void addSockJsServiceMapping(M mappings, SockJsService sockJsService,
			WebSocketHandler handler, String pathPattern);

	/**
	 * Add a mapping from the given path to the handler for plain WebSocket handshakes.
	 * @param mappings the mappings object being populated
	 * @param wsHandler the handler registered for the path
	 * @param handshakeHandler the configured handshake handler or a default one
	 * @param interceptors the interceptors returned by {@link #getInterceptors()}
	 * @param path the path exactly as registered
	 */
	protected abstract void addWebSocketHandlerMapping(M mappings, WebSocketHandler wsHandler,
			HandshakeHandler handshakeHandler, HandshakeInterceptor[] interceptors, String path);

}