001/* 002 * Copyright 2002-2018 the original author or authors. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * https://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.springframework.web.socket.config.annotation; 018 019import java.util.ArrayList; 020import java.util.Arrays; 021import java.util.List; 022 023import org.springframework.lang.Nullable; 024import org.springframework.util.Assert; 025import org.springframework.util.LinkedMultiValueMap; 026import org.springframework.util.MultiValueMap; 027import org.springframework.util.ObjectUtils; 028import org.springframework.util.StringUtils; 029import org.springframework.web.socket.WebSocketHandler; 030import org.springframework.web.socket.server.HandshakeHandler; 031import org.springframework.web.socket.server.HandshakeInterceptor; 032import org.springframework.web.socket.server.support.DefaultHandshakeHandler; 033import org.springframework.web.socket.server.support.OriginHandshakeInterceptor; 034import org.springframework.web.socket.sockjs.SockJsService; 035import org.springframework.web.socket.sockjs.transport.handler.WebSocketTransportHandler; 036 037/** 038 * Base class for {@link WebSocketHandlerRegistration WebSocketHandlerRegistrations} that gathers all the configuration 039 * options but allows sub-classes to put together the actual HTTP request mappings. 040 * 041 * @author Rossen Stoyanchev 042 * @author Sebastien Deleuze 043 * @since 4.0 044 * @param <M> the mappings type 045 */ 046public abstract class AbstractWebSocketHandlerRegistration<M> implements WebSocketHandlerRegistration { 047 048 private final MultiValueMap<WebSocketHandler, String> handlerMap = new LinkedMultiValueMap<>(); 049 050 @Nullable 051 private HandshakeHandler handshakeHandler; 052 053 private final List<HandshakeInterceptor> interceptors = new ArrayList<>(); 054 055 private final List<String> allowedOrigins = new ArrayList<>(); 056 057 @Nullable 058 private SockJsServiceRegistration sockJsServiceRegistration; 059 060 061 @Override 062 public WebSocketHandlerRegistration addHandler(WebSocketHandler handler, String... paths) { 063 Assert.notNull(handler, "WebSocketHandler must not be null"); 064 Assert.notEmpty(paths, "Paths must not be empty"); 065 this.handlerMap.put(handler, Arrays.asList(paths)); 066 return this; 067 } 068 069 @Override 070 public WebSocketHandlerRegistration setHandshakeHandler(@Nullable HandshakeHandler handshakeHandler) { 071 this.handshakeHandler = handshakeHandler; 072 return this; 073 } 074 075 @Nullable 076 protected HandshakeHandler getHandshakeHandler() { 077 return this.handshakeHandler; 078 } 079 080 @Override 081 public WebSocketHandlerRegistration addInterceptors(HandshakeInterceptor... interceptors) { 082 if (!ObjectUtils.isEmpty(interceptors)) { 083 this.interceptors.addAll(Arrays.asList(interceptors)); 084 } 085 return this; 086 } 087 088 @Override 089 public WebSocketHandlerRegistration setAllowedOrigins(String... allowedOrigins) { 090 this.allowedOrigins.clear(); 091 if (!ObjectUtils.isEmpty(allowedOrigins)) { 092 this.allowedOrigins.addAll(Arrays.asList(allowedOrigins)); 093 } 094 return this; 095 } 096 097 @Override 098 public SockJsServiceRegistration withSockJS() { 099 this.sockJsServiceRegistration = new SockJsServiceRegistration(); 100 HandshakeInterceptor[] interceptors = getInterceptors(); 101 if (interceptors.length > 0) { 102 this.sockJsServiceRegistration.setInterceptors(interceptors); 103 } 104 if (this.handshakeHandler != null) { 105 WebSocketTransportHandler transportHandler = new WebSocketTransportHandler(this.handshakeHandler); 106 this.sockJsServiceRegistration.setTransportHandlerOverrides(transportHandler); 107 } 108 if (!this.allowedOrigins.isEmpty()) { 109 this.sockJsServiceRegistration.setAllowedOrigins(StringUtils.toStringArray(this.allowedOrigins)); 110 } 111 return this.sockJsServiceRegistration; 112 } 113 114 protected HandshakeInterceptor[] getInterceptors() { 115 List<HandshakeInterceptor> interceptors = new ArrayList<>(this.interceptors.size() + 1); 116 interceptors.addAll(this.interceptors); 117 interceptors.add(new OriginHandshakeInterceptor(this.allowedOrigins)); 118 return interceptors.toArray(new HandshakeInterceptor[0]); 119 } 120 121 /** 122 * Expose the {@code SockJsServiceRegistration} -- if SockJS is enabled or 123 * {@code null} otherwise -- so that it can be configured with a TaskScheduler 124 * if the application did not provide one. This should be done prior to 125 * calling {@link #getMappings()}. 126 */ 127 @Nullable 128 protected SockJsServiceRegistration getSockJsServiceRegistration() { 129 return this.sockJsServiceRegistration; 130 } 131 132 protected final M getMappings() { 133 M mappings = createMappings(); 134 if (this.sockJsServiceRegistration != null) { 135 SockJsService sockJsService = this.sockJsServiceRegistration.getSockJsService(); 136 this.handlerMap.forEach((wsHandler, paths) -> { 137 for (String path : paths) { 138 String pathPattern = (path.endsWith("/") ? path + "**" : path + "/**"); 139 addSockJsServiceMapping(mappings, sockJsService, wsHandler, pathPattern); 140 } 141 }); 142 } 143 else { 144 HandshakeHandler handshakeHandler = getOrCreateHandshakeHandler(); 145 HandshakeInterceptor[] interceptors = getInterceptors(); 146 this.handlerMap.forEach((wsHandler, paths) -> { 147 for (String path : paths) { 148 addWebSocketHandlerMapping(mappings, wsHandler, handshakeHandler, interceptors, path); 149 } 150 }); 151 } 152 153 return mappings; 154 } 155 156 private HandshakeHandler getOrCreateHandshakeHandler() { 157 return (this.handshakeHandler != null ? this.handshakeHandler : new DefaultHandshakeHandler()); 158 } 159 160 161 protected abstract M createMappings(); 162 163 protected abstract void addSockJsServiceMapping(M mappings, SockJsService sockJsService, 164 WebSocketHandler handler, String pathPattern); 165 166 protected abstract void addWebSocketHandlerMapping(M mappings, WebSocketHandler wsHandler, 167 HandshakeHandler handshakeHandler, HandshakeInterceptor[] interceptors, String path); 168 169}