001/* 002 * Copyright 2002-2018 the original author or authors. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * https://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.springframework.web.socket.config.annotation; 018 019import java.util.ArrayList; 020import java.util.Arrays; 021import java.util.List; 022 023import org.springframework.scheduling.TaskScheduler; 024import org.springframework.util.Assert; 025import org.springframework.util.LinkedMultiValueMap; 026import org.springframework.util.MultiValueMap; 027import org.springframework.util.ObjectUtils; 028import org.springframework.util.StringUtils; 029import org.springframework.web.HttpRequestHandler; 030import org.springframework.web.socket.WebSocketHandler; 031import org.springframework.web.socket.server.HandshakeHandler; 032import org.springframework.web.socket.server.HandshakeInterceptor; 033import org.springframework.web.socket.server.support.OriginHandshakeInterceptor; 034import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler; 035import org.springframework.web.socket.sockjs.SockJsService; 036import org.springframework.web.socket.sockjs.support.SockJsHttpRequestHandler; 037import org.springframework.web.socket.sockjs.transport.handler.WebSocketTransportHandler; 038 039/** 040 * An abstract base class for configuring STOMP over WebSocket/SockJS endpoints. 041 * 042 * @author Rossen Stoyanchev 043 * @since 4.0 044 */ 045public class WebMvcStompWebSocketEndpointRegistration implements StompWebSocketEndpointRegistration { 046 047 private final String[] paths; 048 049 private final WebSocketHandler webSocketHandler; 050 051 private final TaskScheduler sockJsTaskScheduler; 052 053 private HandshakeHandler handshakeHandler; 054 055 private final List<HandshakeInterceptor> interceptors = new ArrayList<HandshakeInterceptor>(); 056 057 private final List<String> allowedOrigins = new ArrayList<String>(); 058 059 private SockJsServiceRegistration registration; 060 061 062 public WebMvcStompWebSocketEndpointRegistration(String[] paths, WebSocketHandler webSocketHandler, 063 TaskScheduler sockJsTaskScheduler) { 064 065 Assert.notEmpty(paths, "No paths specified"); 066 Assert.notNull(webSocketHandler, "WebSocketHandler must not be null"); 067 068 this.paths = paths; 069 this.webSocketHandler = webSocketHandler; 070 this.sockJsTaskScheduler = sockJsTaskScheduler; 071 } 072 073 074 @Override 075 public StompWebSocketEndpointRegistration setHandshakeHandler(HandshakeHandler handshakeHandler) { 076 this.handshakeHandler = handshakeHandler; 077 return this; 078 } 079 080 @Override 081 public StompWebSocketEndpointRegistration addInterceptors(HandshakeInterceptor... interceptors) { 082 if (!ObjectUtils.isEmpty(interceptors)) { 083 this.interceptors.addAll(Arrays.asList(interceptors)); 084 } 085 return this; 086 } 087 088 @Override 089 public StompWebSocketEndpointRegistration setAllowedOrigins(String... allowedOrigins) { 090 this.allowedOrigins.clear(); 091 if (!ObjectUtils.isEmpty(allowedOrigins)) { 092 this.allowedOrigins.addAll(Arrays.asList(allowedOrigins)); 093 } 094 return this; 095 } 096 097 @Override 098 public SockJsServiceRegistration withSockJS() { 099 this.registration = new SockJsServiceRegistration(this.sockJsTaskScheduler); 100 HandshakeInterceptor[] interceptors = getInterceptors(); 101 if (interceptors.length > 0) { 102 this.registration.setInterceptors(interceptors); 103 } 104 if (this.handshakeHandler != null) { 105 WebSocketTransportHandler handler = new WebSocketTransportHandler(this.handshakeHandler); 106 this.registration.setTransportHandlerOverrides(handler); 107 } 108 if (!this.allowedOrigins.isEmpty()) { 109 this.registration.setAllowedOrigins(StringUtils.toStringArray(this.allowedOrigins)); 110 } 111 return this.registration; 112 } 113 114 protected HandshakeInterceptor[] getInterceptors() { 115 List<HandshakeInterceptor> interceptors = new ArrayList<HandshakeInterceptor>(this.interceptors.size() + 1); 116 interceptors.addAll(this.interceptors); 117 interceptors.add(new OriginHandshakeInterceptor(this.allowedOrigins)); 118 return interceptors.toArray(new HandshakeInterceptor[interceptors.size()]); 119 } 120 121 public final MultiValueMap<HttpRequestHandler, String> getMappings() { 122 MultiValueMap<HttpRequestHandler, String> mappings = new LinkedMultiValueMap<HttpRequestHandler, String>(); 123 if (this.registration != null) { 124 SockJsService sockJsService = this.registration.getSockJsService(); 125 for (String path : this.paths) { 126 String pattern = (path.endsWith("/") ? path + "**" : path + "/**"); 127 SockJsHttpRequestHandler handler = new SockJsHttpRequestHandler(sockJsService, this.webSocketHandler); 128 mappings.add(handler, pattern); 129 } 130 } 131 else { 132 for (String path : this.paths) { 133 WebSocketHttpRequestHandler handler; 134 if (this.handshakeHandler != null) { 135 handler = new WebSocketHttpRequestHandler(this.webSocketHandler, this.handshakeHandler); 136 } 137 else { 138 handler = new WebSocketHttpRequestHandler(this.webSocketHandler); 139 } 140 HandshakeInterceptor[] interceptors = getInterceptors(); 141 if (interceptors.length > 0) { 142 handler.setHandshakeInterceptors(Arrays.asList(interceptors)); 143 } 144 mappings.add(handler, path); 145 } 146 } 147 return mappings; 148 } 149 150}