/*
 * Copyright 2002-2018 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.web.socket.client.standard;

import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.URI;
import java.net.UnknownHostException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.Callable;

import javax.websocket.ClientEndpointConfig;
import javax.websocket.ClientEndpointConfig.Configurator;
import javax.websocket.ContainerProvider;
import javax.websocket.Endpoint;
import javax.websocket.Extension;
import javax.websocket.HandshakeResponse;
import javax.websocket.WebSocketContainer;

import org.springframework.core.task.AsyncListenableTaskExecutor;
import org.springframework.core.task.SimpleAsyncTaskExecutor;
import org.springframework.core.task.TaskExecutor;
import org.springframework.http.HttpHeaders;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.concurrent.ListenableFuture;
import org.springframework.util.concurrent.ListenableFutureTask;
import org.springframework.web.socket.WebSocketExtension;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.adapter.standard.StandardWebSocketHandlerAdapter;
import org.springframework.web.socket.adapter.standard.StandardWebSocketSession;
import org.springframework.web.socket.adapter.standard.WebSocketToStandardExtensionAdapter;
import org.springframework.web.socket.client.AbstractWebSocketClient;

/**
 * A WebSocket client built on the standard Java WebSocket API
 * ({@code javax.websocket}): every handshake is carried out through a
 * {@link WebSocketContainer}.
 *
 * @author Rossen Stoyanchev
 * @since 4.0
 */
public class StandardWebSocketClient extends AbstractWebSocketClient {

	private final WebSocketContainer webSocketContainer;

	private final Map<String,Object> userProperties = new HashMap<>();

	@Nullable
	private AsyncListenableTaskExecutor taskExecutor = new SimpleAsyncTaskExecutor();


	/**
	 * Create a client whose {@link WebSocketContainer} is obtained from
	 * {@code ContainerProvider.getWebSocketContainer()}. To supply an existing
	 * container instead, use {@link #StandardWebSocketClient(WebSocketContainer)}.
	 */
	public StandardWebSocketClient() {
		this.webSocketContainer = ContainerProvider.getWebSocketContainer();
	}

	/**
	 * Create a client that uses the given, already existing
	 * {@link WebSocketContainer}.
	 * <p>For XML configuration, see {@link WebSocketContainerFactoryBean}. For Java
	 * configuration, {@code ContainerProvider.getWebSocketContainer()} can be used
	 * to obtain the {@code WebSocketContainer} instance.
	 * @param webSocketContainer the container to connect with (must not be {@code null})
	 */
	public StandardWebSocketClient(WebSocketContainer webSocketContainer) {
		Assert.notNull(webSocketContainer, "WebSocketContainer must not be null");
		this.webSocketContainer = webSocketContainer;
	}


	/**
	 * Configure "user properties" that are copied into the
	 * {@link ClientEndpointConfig#getUserProperties() userProperties} of the
	 * endpoint configuration on every handshake.
	 * <p>The given entries are added to the properties already configured on
	 * this client; a {@code null} argument is ignored.
	 * @param userProperties the properties to add, or {@code null} for none
	 */
	public void setUserProperties(@Nullable Map<String, Object> userProperties) {
		if (userProperties == null) {
			return;
		}
		this.userProperties.putAll(userProperties);
	}

	/**
	 * Return the user properties configured on this client. This is the
	 * client's own map, not a copy.
	 */
	public Map<String, Object> getUserProperties() {
		return this.userProperties;
	}

	/**
	 * Set the {@link AsyncListenableTaskExecutor} on which connections are opened.
	 * When set to {@code null}, the connection is opened on the calling thread,
	 * so the {@code doHandshake} methods block until it is established.
	 * <p>The default is a {@code SimpleAsyncTaskExecutor}.
	 * @param taskExecutor the executor to use, or {@code null} to connect inline
	 */
	public void setTaskExecutor(@Nullable AsyncListenableTaskExecutor taskExecutor) {
		this.taskExecutor = taskExecutor;
	}

	/**
	 * Return the configured {@link TaskExecutor}, or {@code null} if connections
	 * are opened on the calling thread.
	 */
	@Nullable
	public AsyncListenableTaskExecutor getTaskExecutor() {
		return this.taskExecutor;
	}


	/**
	 * Prepare a {@link StandardWebSocketSession} and a {@link ClientEndpointConfig}
	 * for the given handshake and connect to the server through the
	 * {@link WebSocketContainer}.
	 * <p>The connect call runs on the configured task executor if there is one;
	 * otherwise it runs on the calling thread before this method returns.
	 * @return a future that yields the session once the connect call has completed
	 */
	@Override
	protected ListenableFuture<WebSocketSession> doHandshakeInternal(WebSocketHandler webSocketHandler,
			HttpHeaders headers, final URI uri, List<String> protocols,
			List<WebSocketExtension> extensions, Map<String, Object> attributes) {

		// The same (possibly defaulted) port is used for both session addresses.
		int targetPort = getPort(uri);
		InetSocketAddress local = new InetSocketAddress(getLocalHost(), targetPort);
		InetSocketAddress remote = new InetSocketAddress(uri.getHost(), targetPort);
		final StandardWebSocketSession session =
				new StandardWebSocketSession(headers, attributes, local, remote);

		Configurator headerConfigurator = new StandardWebSocketClientConfigurator(headers);
		List<Extension> standardExtensions = adaptExtensions(extensions);
		final ClientEndpointConfig endpointConfig = ClientEndpointConfig.Builder.create()
				.configurator(headerConfigurator)
				.preferredSubprotocols(protocols)
				.extensions(standardExtensions)
				.build();
		endpointConfig.getUserProperties().putAll(getUserProperties());

		final Endpoint endpoint = new StandardWebSocketHandlerAdapter(webSocketHandler, session);

		Callable<WebSocketSession> connectTask = () -> {
			this.webSocketContainer.connectToServer(endpoint, endpointConfig, uri);
			return session;
		};

		if (this.taskExecutor == null) {
			// No executor configured: connect on the calling thread.
			ListenableFutureTask<WebSocketSession> inlineTask = new ListenableFutureTask<>(connectTask);
			inlineTask.run();
			return inlineTask;
		}
		return this.taskExecutor.submitListenable(connectTask);
	}

	/**
	 * Wrap each Spring {@link WebSocketExtension} in an adapter implementing the
	 * standard {@link Extension} contract, preserving the order of the input.
	 */
	private static List<Extension> adaptExtensions(List<WebSocketExtension> extensions) {
		List<Extension> adapted = new ArrayList<>();
		extensions.forEach(candidate -> adapted.add(new WebSocketToStandardExtensionAdapter(candidate)));
		return adapted;
	}

	/**
	 * Return the local host address, falling back to the loopback address
	 * when the local host cannot be resolved.
	 */
	private InetAddress getLocalHost() {
		InetAddress address;
		try {
			address = InetAddress.getLocalHost();
		}
		catch (UnknownHostException ex) {
			address = InetAddress.getLoopbackAddress();
		}
		return address;
	}

	/**
	 * Return the port of the given URI or, if the URI does not declare one,
	 * 443 for the "wss" scheme and 80 for any other scheme.
	 */
	private int getPort(URI uri) {
		int declaredPort = uri.getPort();
		if (declaredPort != -1) {
			return declaredPort;
		}
		String scheme = uri.getScheme().toLowerCase(Locale.ENGLISH);
		if ("wss".equals(scheme)) {
			return 443;
		}
		return 80;
	}


	/**
	 * {@link Configurator} that adds the configured HTTP headers to the
	 * handshake request and trace-logs the request and response headers.
	 */
	private class StandardWebSocketClientConfigurator extends Configurator {

		private final HttpHeaders handshakeHeaders;

		public StandardWebSocketClientConfigurator(HttpHeaders headers) {
			this.handshakeHeaders = headers;
		}

		@Override
		public void beforeRequest(Map<String, List<String>> requestHeaders) {
			requestHeaders.putAll(this.handshakeHeaders);
			if (logger.isTraceEnabled()) {
				logger.trace("Handshake request headers: " + requestHeaders);
			}
		}

		@Override
		public void afterResponse(HandshakeResponse response) {
			if (logger.isTraceEnabled()) {
				logger.trace("Handshake response headers: " + response.getHeaders());
			}
		}
	}

}