/*
 * Copyright 2002-2018 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.web.socket.config.annotation;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.springframework.lang.Nullable;
import org.springframework.scheduling.TaskScheduler;
import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.ObjectUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.HttpRequestHandler;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;
import org.springframework.web.socket.server.support.OriginHandshakeInterceptor;
import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler;
import org.springframework.web.socket.sockjs.SockJsService;
import org.springframework.web.socket.sockjs.support.SockJsHttpRequestHandler;
import org.springframework.web.socket.sockjs.transport.handler.WebSocketTransportHandler;

/**
 * A {@link StompWebSocketEndpointRegistration} that collects the settings for one
 * or more STOMP over WebSocket endpoint paths and turns them into
 * {@link HttpRequestHandler} mappings: plain WebSocket handlers by default, or
 * SockJS handlers once {@link #withSockJS()} has been called.
 *
 * @author Rossen Stoyanchev
 * @since 4.0
 */
public class WebMvcStompWebSocketEndpointRegistration implements StompWebSocketEndpointRegistration {

    private final String[] paths;

    private final WebSocketHandler webSocketHandler;

    private final TaskScheduler sockJsTaskScheduler;

    @Nullable
    private HandshakeHandler handshakeHandler;

    private final List<HandshakeInterceptor> interceptors = new ArrayList<>();

    private final List<String> allowedOrigins = new ArrayList<>();

    @Nullable
    private SockJsServiceRegistration registration;


    /**
     * Create a registration for the given endpoint paths.
     * @param paths the endpoint paths; must not be empty
     * @param webSocketHandler the handler used for every path; must not be {@code null}
     * @param sockJsTaskScheduler the scheduler handed to the SockJS registration
     * created by {@link #withSockJS()}
     */
    public WebMvcStompWebSocketEndpointRegistration(
            String[] paths, WebSocketHandler webSocketHandler, TaskScheduler sockJsTaskScheduler) {

        Assert.notEmpty(paths, "No paths specified");
        Assert.notNull(webSocketHandler, "WebSocketHandler must not be null");

        this.sockJsTaskScheduler = sockJsTaskScheduler;
        this.webSocketHandler = webSocketHandler;
        this.paths = paths;
    }


    /**
     * Store the handshake handler to use for this endpoint.
     * @return this registration, for chaining
     */
    @Override
    public StompWebSocketEndpointRegistration setHandshakeHandler(HandshakeHandler handshakeHandler) {
        this.handshakeHandler = handshakeHandler;
        return this;
    }

    /**
     * Append the given handshake interceptors to those already registered.
     * A {@code null} or empty array is ignored.
     * @return this registration, for chaining
     */
    @Override
    public StompWebSocketEndpointRegistration addInterceptors(HandshakeInterceptor... interceptors) {
        if (ObjectUtils.isEmpty(interceptors)) {
            return this;
        }
        for (HandshakeInterceptor interceptor : interceptors) {
            this.interceptors.add(interceptor);
        }
        return this;
    }

    /**
     * Replace the current allowed origins with the given values.
     * A {@code null} or empty array leaves the list empty.
     * @return this registration, for chaining
     */
    @Override
    public StompWebSocketEndpointRegistration setAllowedOrigins(String... allowedOrigins) {
        // Always start from scratch: this is a setter, not an "add" operation.
        this.allowedOrigins.clear();
        if (ObjectUtils.isEmpty(allowedOrigins)) {
            return this;
        }
        for (String origin : allowedOrigins) {
            this.allowedOrigins.add(origin);
        }
        return this;
    }

    /**
     * Switch this endpoint to SockJS by creating a new
     * {@link SockJsServiceRegistration}. The new registration receives the task
     * scheduler plus the interceptors, handshake handler and allowed origins as
     * they are configured on this object at the time of the call.
     * @return the newly created SockJS registration
     */
    @Override
    public SockJsServiceRegistration withSockJS() {
        SockJsServiceRegistration sockJs = new SockJsServiceRegistration();
        this.registration = sockJs;
        sockJs.setTaskScheduler(this.sockJsTaskScheduler);

        // getInterceptors() is overridable, so an empty result is possible.
        HandshakeInterceptor[] handshakeInterceptors = getInterceptors();
        if (handshakeInterceptors.length > 0) {
            sockJs.setInterceptors(handshakeInterceptors);
        }

        if (this.handshakeHandler != null) {
            sockJs.setTransportHandlerOverrides(new WebSocketTransportHandler(this.handshakeHandler));
        }

        if (!this.allowedOrigins.isEmpty()) {
            sockJs.setAllowedOrigins(StringUtils.toStringArray(this.allowedOrigins));
        }

        return sockJs;
    }

    /**
     * Return the registered interceptors followed by a newly created
     * {@link OriginHandshakeInterceptor} built from the current allowed origins.
     */
    protected HandshakeInterceptor[] getInterceptors() {
        List<HandshakeInterceptor> all = new ArrayList<>(this.interceptors.size() + 1);
        all.addAll(this.interceptors);
        all.add(new OriginHandshakeInterceptor(this.allowedOrigins));
        return all.toArray(new HandshakeInterceptor[0]);
    }

    /**
     * Build the request handler mappings for all configured paths.
     * <p>Without SockJS, each path maps as-is to its own
     * {@link WebSocketHttpRequestHandler}. With SockJS, each path maps under a
     * pattern ending in {@code /**} to its own {@link SockJsHttpRequestHandler},
     * all sharing the one {@link SockJsService} obtained from the registration.
     * @return a new map from handler to mapped path or pattern
     */
    public final MultiValueMap<HttpRequestHandler, String> getMappings() {
        MultiValueMap<HttpRequestHandler, String> result = new LinkedMultiValueMap<>();

        if (this.registration == null) {
            for (String path : this.paths) {
                result.add(createWebSocketRequestHandler(), path);
            }
            return result;
        }

        SockJsService service = this.registration.getSockJsService();
        for (String path : this.paths) {
            // Avoid a double slash when the path already ends with one.
            String suffix = (path.endsWith("/") ? "**" : "/**");
            result.add(new SockJsHttpRequestHandler(service, this.webSocketHandler), path + suffix);
        }
        return result;
    }

    /**
     * Create one plain WebSocket request handler, using the custom handshake
     * handler when one is set and applying the current interceptors.
     */
    private WebSocketHttpRequestHandler createWebSocketRequestHandler() {
        WebSocketHttpRequestHandler handler = (this.handshakeHandler != null ?
                new WebSocketHttpRequestHandler(this.webSocketHandler, this.handshakeHandler) :
                new WebSocketHttpRequestHandler(this.webSocketHandler));

        HandshakeInterceptor[] handshakeInterceptors = getInterceptors();
        if (handshakeInterceptors.length > 0) {
            handler.setHandshakeInterceptors(Arrays.asList(handshakeInterceptors));
        }
        return handler;
    }

}