001/*
002 * Copyright 2002-2016 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.web.socket.client.standard;
018
019import java.net.InetAddress;
020import java.net.InetSocketAddress;
021import java.net.URI;
022import java.net.UnknownHostException;
023import java.util.ArrayList;
024import java.util.HashMap;
025import java.util.List;
026import java.util.Locale;
027import java.util.Map;
028import java.util.concurrent.Callable;
029import javax.websocket.ClientEndpointConfig;
030import javax.websocket.ClientEndpointConfig.Configurator;
031import javax.websocket.ContainerProvider;
032import javax.websocket.Endpoint;
033import javax.websocket.Extension;
034import javax.websocket.HandshakeResponse;
035import javax.websocket.WebSocketContainer;
036
037import org.springframework.core.task.AsyncListenableTaskExecutor;
038import org.springframework.core.task.SimpleAsyncTaskExecutor;
039import org.springframework.core.task.TaskExecutor;
040import org.springframework.http.HttpHeaders;
041import org.springframework.lang.UsesJava7;
042import org.springframework.util.Assert;
043import org.springframework.util.concurrent.ListenableFuture;
044import org.springframework.util.concurrent.ListenableFutureTask;
045import org.springframework.web.socket.WebSocketExtension;
046import org.springframework.web.socket.WebSocketHandler;
047import org.springframework.web.socket.WebSocketSession;
048import org.springframework.web.socket.adapter.standard.StandardWebSocketHandlerAdapter;
049import org.springframework.web.socket.adapter.standard.StandardWebSocketSession;
050import org.springframework.web.socket.adapter.standard.WebSocketToStandardExtensionAdapter;
051import org.springframework.web.socket.client.AbstractWebSocketClient;
052
053/**
054 * A WebSocketClient based on standard Java WebSocket API.
055 *
056 * @author Rossen Stoyanchev
057 * @since 4.0
058 */
059public class StandardWebSocketClient extends AbstractWebSocketClient {
060
061        private final WebSocketContainer webSocketContainer;
062
063        private final Map<String,Object> userProperties = new HashMap<String, Object>();
064
065        private AsyncListenableTaskExecutor taskExecutor = new SimpleAsyncTaskExecutor();
066
067
068        /**
069         * Default constructor that calls {@code ContainerProvider.getWebSocketContainer()}
070         * to obtain a (new) {@link WebSocketContainer} instance. Also see constructor
071         * accepting existing {@code WebSocketContainer} instance.
072         */
073        public StandardWebSocketClient() {
074                this.webSocketContainer = ContainerProvider.getWebSocketContainer();
075        }
076
077        /**
078         * Constructor accepting an existing {@link WebSocketContainer} instance.
079         * <p>For XML configuration, see {@link WebSocketContainerFactoryBean}. For Java
080         * configuration, use {@code ContainerProvider.getWebSocketContainer()} to obtain
081         * the {@code WebSocketContainer} instance.
082         */
083        public StandardWebSocketClient(WebSocketContainer webSocketContainer) {
084                Assert.notNull(webSocketContainer, "WebSocketContainer must not be null");
085                this.webSocketContainer = webSocketContainer;
086        }
087
088
089        /**
090         * The standard Java WebSocket API allows passing "user properties" to the
091         * server via {@link ClientEndpointConfig#getUserProperties() userProperties}.
092         * Use this property to configure one or more properties to be passed on
093         * every handshake.
094         */
095        public void setUserProperties(Map<String, Object> userProperties) {
096                if (userProperties != null) {
097                        this.userProperties.putAll(userProperties);
098                }
099        }
100
101        /**
102         * The configured user properties, or {@code null}.
103         */
104        public Map<String, Object> getUserProperties() {
105                return this.userProperties;
106        }
107
108        /**
109         * Set an {@link AsyncListenableTaskExecutor} to use when opening connections.
110         * If this property is set to {@code null}, calls to  any of the
111         * {@code doHandshake} methods will block until the connection is established.
112         * <p>By default, an instance of {@code SimpleAsyncTaskExecutor} is used.
113         */
114        public void setTaskExecutor(AsyncListenableTaskExecutor taskExecutor) {
115                this.taskExecutor = taskExecutor;
116        }
117
118        /**
119         * Return the configured {@link TaskExecutor}.
120         */
121        public AsyncListenableTaskExecutor getTaskExecutor() {
122                return this.taskExecutor;
123        }
124
125
126        @Override
127        protected ListenableFuture<WebSocketSession> doHandshakeInternal(WebSocketHandler webSocketHandler,
128                        HttpHeaders headers, final URI uri, List<String> protocols,
129                        List<WebSocketExtension> extensions, Map<String, Object> attributes) {
130
131                int port = getPort(uri);
132                InetSocketAddress localAddress = new InetSocketAddress(getLocalHost(), port);
133                InetSocketAddress remoteAddress = new InetSocketAddress(uri.getHost(), port);
134
135                final StandardWebSocketSession session = new StandardWebSocketSession(headers,
136                                attributes, localAddress, remoteAddress);
137
138                final ClientEndpointConfig endpointConfig = ClientEndpointConfig.Builder.create()
139                                .configurator(new StandardWebSocketClientConfigurator(headers))
140                                .preferredSubprotocols(protocols)
141                                .extensions(adaptExtensions(extensions)).build();
142
143                endpointConfig.getUserProperties().putAll(getUserProperties());
144
145                final Endpoint endpoint = new StandardWebSocketHandlerAdapter(webSocketHandler, session);
146
147                Callable<WebSocketSession> connectTask = new Callable<WebSocketSession>() {
148                        @Override
149                        public WebSocketSession call() throws Exception {
150                                webSocketContainer.connectToServer(endpoint, endpointConfig, uri);
151                                return session;
152                        }
153                };
154
155                if (this.taskExecutor != null) {
156                        return this.taskExecutor.submitListenable(connectTask);
157                }
158                else {
159                        ListenableFutureTask<WebSocketSession> task = new ListenableFutureTask<WebSocketSession>(connectTask);
160                        task.run();
161                        return task;
162                }
163        }
164
165        private static List<Extension> adaptExtensions(List<WebSocketExtension> extensions) {
166                List<Extension> result = new ArrayList<Extension>();
167                for (WebSocketExtension extension : extensions) {
168                        result.add(new WebSocketToStandardExtensionAdapter(extension));
169                }
170                return result;
171        }
172
173        @UsesJava7  // fallback to InetAddress.getLoopbackAddress()
174        private InetAddress getLocalHost() {
175                try {
176                        return InetAddress.getLocalHost();
177                }
178                catch (UnknownHostException ex) {
179                        return InetAddress.getLoopbackAddress();
180                }
181        }
182
183        private int getPort(URI uri) {
184                if (uri.getPort() == -1) {
185                String scheme = uri.getScheme().toLowerCase(Locale.ENGLISH);
186                return ("wss".equals(scheme) ? 443 : 80);
187                }
188                return uri.getPort();
189        }
190
191
192        private class StandardWebSocketClientConfigurator extends Configurator {
193
194                private final HttpHeaders headers;
195
196                public StandardWebSocketClientConfigurator(HttpHeaders headers) {
197                        this.headers = headers;
198                }
199
200                @Override
201                public void beforeRequest(Map<String, List<String>> requestHeaders) {
202                        requestHeaders.putAll(this.headers);
203                        if (logger.isTraceEnabled()) {
204                                logger.trace("Handshake request headers: " + requestHeaders);
205                        }
206                }
207                @Override
208                public void afterResponse(HandshakeResponse response) {
209                        if (logger.isTraceEnabled()) {
210                                logger.trace("Handshake response headers: " + response.getHeaders());
211                        }
212                }
213        }
214
215}