001/* 002 * Copyright 2002-2016 the original author or authors. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * https://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.springframework.web.socket.client.standard; 018 019import java.net.InetAddress; 020import java.net.InetSocketAddress; 021import java.net.URI; 022import java.net.UnknownHostException; 023import java.util.ArrayList; 024import java.util.HashMap; 025import java.util.List; 026import java.util.Locale; 027import java.util.Map; 028import java.util.concurrent.Callable; 029import javax.websocket.ClientEndpointConfig; 030import javax.websocket.ClientEndpointConfig.Configurator; 031import javax.websocket.ContainerProvider; 032import javax.websocket.Endpoint; 033import javax.websocket.Extension; 034import javax.websocket.HandshakeResponse; 035import javax.websocket.WebSocketContainer; 036 037import org.springframework.core.task.AsyncListenableTaskExecutor; 038import org.springframework.core.task.SimpleAsyncTaskExecutor; 039import org.springframework.core.task.TaskExecutor; 040import org.springframework.http.HttpHeaders; 041import org.springframework.lang.UsesJava7; 042import org.springframework.util.Assert; 043import org.springframework.util.concurrent.ListenableFuture; 044import org.springframework.util.concurrent.ListenableFutureTask; 045import org.springframework.web.socket.WebSocketExtension; 046import org.springframework.web.socket.WebSocketHandler; 047import org.springframework.web.socket.WebSocketSession; 048import org.springframework.web.socket.adapter.standard.StandardWebSocketHandlerAdapter; 049import org.springframework.web.socket.adapter.standard.StandardWebSocketSession; 050import org.springframework.web.socket.adapter.standard.WebSocketToStandardExtensionAdapter; 051import org.springframework.web.socket.client.AbstractWebSocketClient; 052 053/** 054 * A WebSocketClient based on standard Java WebSocket API. 055 * 056 * @author Rossen Stoyanchev 057 * @since 4.0 058 */ 059public class StandardWebSocketClient extends AbstractWebSocketClient { 060 061 private final WebSocketContainer webSocketContainer; 062 063 private final Map<String,Object> userProperties = new HashMap<String, Object>(); 064 065 private AsyncListenableTaskExecutor taskExecutor = new SimpleAsyncTaskExecutor(); 066 067 068 /** 069 * Default constructor that calls {@code ContainerProvider.getWebSocketContainer()} 070 * to obtain a (new) {@link WebSocketContainer} instance. Also see constructor 071 * accepting existing {@code WebSocketContainer} instance. 072 */ 073 public StandardWebSocketClient() { 074 this.webSocketContainer = ContainerProvider.getWebSocketContainer(); 075 } 076 077 /** 078 * Constructor accepting an existing {@link WebSocketContainer} instance. 079 * <p>For XML configuration, see {@link WebSocketContainerFactoryBean}. For Java 080 * configuration, use {@code ContainerProvider.getWebSocketContainer()} to obtain 081 * the {@code WebSocketContainer} instance. 082 */ 083 public StandardWebSocketClient(WebSocketContainer webSocketContainer) { 084 Assert.notNull(webSocketContainer, "WebSocketContainer must not be null"); 085 this.webSocketContainer = webSocketContainer; 086 } 087 088 089 /** 090 * The standard Java WebSocket API allows passing "user properties" to the 091 * server via {@link ClientEndpointConfig#getUserProperties() userProperties}. 092 * Use this property to configure one or more properties to be passed on 093 * every handshake. 094 */ 095 public void setUserProperties(Map<String, Object> userProperties) { 096 if (userProperties != null) { 097 this.userProperties.putAll(userProperties); 098 } 099 } 100 101 /** 102 * The configured user properties, or {@code null}. 103 */ 104 public Map<String, Object> getUserProperties() { 105 return this.userProperties; 106 } 107 108 /** 109 * Set an {@link AsyncListenableTaskExecutor} to use when opening connections. 110 * If this property is set to {@code null}, calls to any of the 111 * {@code doHandshake} methods will block until the connection is established. 112 * <p>By default, an instance of {@code SimpleAsyncTaskExecutor} is used. 113 */ 114 public void setTaskExecutor(AsyncListenableTaskExecutor taskExecutor) { 115 this.taskExecutor = taskExecutor; 116 } 117 118 /** 119 * Return the configured {@link TaskExecutor}. 120 */ 121 public AsyncListenableTaskExecutor getTaskExecutor() { 122 return this.taskExecutor; 123 } 124 125 126 @Override 127 protected ListenableFuture<WebSocketSession> doHandshakeInternal(WebSocketHandler webSocketHandler, 128 HttpHeaders headers, final URI uri, List<String> protocols, 129 List<WebSocketExtension> extensions, Map<String, Object> attributes) { 130 131 int port = getPort(uri); 132 InetSocketAddress localAddress = new InetSocketAddress(getLocalHost(), port); 133 InetSocketAddress remoteAddress = new InetSocketAddress(uri.getHost(), port); 134 135 final StandardWebSocketSession session = new StandardWebSocketSession(headers, 136 attributes, localAddress, remoteAddress); 137 138 final ClientEndpointConfig endpointConfig = ClientEndpointConfig.Builder.create() 139 .configurator(new StandardWebSocketClientConfigurator(headers)) 140 .preferredSubprotocols(protocols) 141 .extensions(adaptExtensions(extensions)).build(); 142 143 endpointConfig.getUserProperties().putAll(getUserProperties()); 144 145 final Endpoint endpoint = new StandardWebSocketHandlerAdapter(webSocketHandler, session); 146 147 Callable<WebSocketSession> connectTask = new Callable<WebSocketSession>() { 148 @Override 149 public WebSocketSession call() throws Exception { 150 webSocketContainer.connectToServer(endpoint, endpointConfig, uri); 151 return session; 152 } 153 }; 154 155 if (this.taskExecutor != null) { 156 return this.taskExecutor.submitListenable(connectTask); 157 } 158 else { 159 ListenableFutureTask<WebSocketSession> task = new ListenableFutureTask<WebSocketSession>(connectTask); 160 task.run(); 161 return task; 162 } 163 } 164 165 private static List<Extension> adaptExtensions(List<WebSocketExtension> extensions) { 166 List<Extension> result = new ArrayList<Extension>(); 167 for (WebSocketExtension extension : extensions) { 168 result.add(new WebSocketToStandardExtensionAdapter(extension)); 169 } 170 return result; 171 } 172 173 @UsesJava7 // fallback to InetAddress.getLoopbackAddress() 174 private InetAddress getLocalHost() { 175 try { 176 return InetAddress.getLocalHost(); 177 } 178 catch (UnknownHostException ex) { 179 return InetAddress.getLoopbackAddress(); 180 } 181 } 182 183 private int getPort(URI uri) { 184 if (uri.getPort() == -1) { 185 String scheme = uri.getScheme().toLowerCase(Locale.ENGLISH); 186 return ("wss".equals(scheme) ? 443 : 80); 187 } 188 return uri.getPort(); 189 } 190 191 192 private class StandardWebSocketClientConfigurator extends Configurator { 193 194 private final HttpHeaders headers; 195 196 public StandardWebSocketClientConfigurator(HttpHeaders headers) { 197 this.headers = headers; 198 } 199 200 @Override 201 public void beforeRequest(Map<String, List<String>> requestHeaders) { 202 requestHeaders.putAll(this.headers); 203 if (logger.isTraceEnabled()) { 204 logger.trace("Handshake request headers: " + requestHeaders); 205 } 206 } 207 @Override 208 public void afterResponse(HandshakeResponse response) { 209 if (logger.isTraceEnabled()) { 210 logger.trace("Handshake response headers: " + response.getHeaders()); 211 } 212 } 213 } 214 215}