001/*
002 * Copyright 2002-2018 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.web.socket.client.standard;
018
019import java.net.InetAddress;
020import java.net.InetSocketAddress;
021import java.net.URI;
022import java.net.UnknownHostException;
023import java.util.ArrayList;
024import java.util.HashMap;
025import java.util.List;
026import java.util.Locale;
027import java.util.Map;
028import java.util.concurrent.Callable;
029
030import javax.websocket.ClientEndpointConfig;
031import javax.websocket.ClientEndpointConfig.Configurator;
032import javax.websocket.ContainerProvider;
033import javax.websocket.Endpoint;
034import javax.websocket.Extension;
035import javax.websocket.HandshakeResponse;
036import javax.websocket.WebSocketContainer;
037
038import org.springframework.core.task.AsyncListenableTaskExecutor;
039import org.springframework.core.task.SimpleAsyncTaskExecutor;
040import org.springframework.core.task.TaskExecutor;
041import org.springframework.http.HttpHeaders;
042import org.springframework.lang.Nullable;
043import org.springframework.util.Assert;
044import org.springframework.util.concurrent.ListenableFuture;
045import org.springframework.util.concurrent.ListenableFutureTask;
046import org.springframework.web.socket.WebSocketExtension;
047import org.springframework.web.socket.WebSocketHandler;
048import org.springframework.web.socket.WebSocketSession;
049import org.springframework.web.socket.adapter.standard.StandardWebSocketHandlerAdapter;
050import org.springframework.web.socket.adapter.standard.StandardWebSocketSession;
051import org.springframework.web.socket.adapter.standard.WebSocketToStandardExtensionAdapter;
052import org.springframework.web.socket.client.AbstractWebSocketClient;
053
054/**
055 * A WebSocketClient based on standard Java WebSocket API.
056 *
057 * @author Rossen Stoyanchev
058 * @since 4.0
059 */
060public class StandardWebSocketClient extends AbstractWebSocketClient {
061
062        private final WebSocketContainer webSocketContainer;
063
064        private final Map<String,Object> userProperties = new HashMap<>();
065
066        @Nullable
067        private AsyncListenableTaskExecutor taskExecutor = new SimpleAsyncTaskExecutor();
068
069
070        /**
071         * Default constructor that calls {@code ContainerProvider.getWebSocketContainer()}
072         * to obtain a (new) {@link WebSocketContainer} instance. Also see constructor
073         * accepting existing {@code WebSocketContainer} instance.
074         */
075        public StandardWebSocketClient() {
076                this.webSocketContainer = ContainerProvider.getWebSocketContainer();
077        }
078
079        /**
080         * Constructor accepting an existing {@link WebSocketContainer} instance.
081         * <p>For XML configuration, see {@link WebSocketContainerFactoryBean}. For Java
082         * configuration, use {@code ContainerProvider.getWebSocketContainer()} to obtain
083         * the {@code WebSocketContainer} instance.
084         */
085        public StandardWebSocketClient(WebSocketContainer webSocketContainer) {
086                Assert.notNull(webSocketContainer, "WebSocketContainer must not be null");
087                this.webSocketContainer = webSocketContainer;
088        }
089
090
091        /**
092         * The standard Java WebSocket API allows passing "user properties" to the
093         * server via {@link ClientEndpointConfig#getUserProperties() userProperties}.
094         * Use this property to configure one or more properties to be passed on
095         * every handshake.
096         */
097        public void setUserProperties(@Nullable Map<String, Object> userProperties) {
098                if (userProperties != null) {
099                        this.userProperties.putAll(userProperties);
100                }
101        }
102
103        /**
104         * The configured user properties.
105         */
106        public Map<String, Object> getUserProperties() {
107                return this.userProperties;
108        }
109
110        /**
111         * Set an {@link AsyncListenableTaskExecutor} to use when opening connections.
112         * If this property is set to {@code null}, calls to any of the
113         * {@code doHandshake} methods will block until the connection is established.
114         * <p>By default, an instance of {@code SimpleAsyncTaskExecutor} is used.
115         */
116        public void setTaskExecutor(@Nullable AsyncListenableTaskExecutor taskExecutor) {
117                this.taskExecutor = taskExecutor;
118        }
119
120        /**
121         * Return the configured {@link TaskExecutor}.
122         */
123        @Nullable
124        public AsyncListenableTaskExecutor getTaskExecutor() {
125                return this.taskExecutor;
126        }
127
128
129        @Override
130        protected ListenableFuture<WebSocketSession> doHandshakeInternal(WebSocketHandler webSocketHandler,
131                        HttpHeaders headers, final URI uri, List<String> protocols,
132                        List<WebSocketExtension> extensions, Map<String, Object> attributes) {
133
134                int port = getPort(uri);
135                InetSocketAddress localAddress = new InetSocketAddress(getLocalHost(), port);
136                InetSocketAddress remoteAddress = new InetSocketAddress(uri.getHost(), port);
137
138                final StandardWebSocketSession session = new StandardWebSocketSession(headers,
139                                attributes, localAddress, remoteAddress);
140
141                final ClientEndpointConfig endpointConfig = ClientEndpointConfig.Builder.create()
142                                .configurator(new StandardWebSocketClientConfigurator(headers))
143                                .preferredSubprotocols(protocols)
144                                .extensions(adaptExtensions(extensions)).build();
145
146                endpointConfig.getUserProperties().putAll(getUserProperties());
147
148                final Endpoint endpoint = new StandardWebSocketHandlerAdapter(webSocketHandler, session);
149
150                Callable<WebSocketSession> connectTask = () -> {
151                        this.webSocketContainer.connectToServer(endpoint, endpointConfig, uri);
152                        return session;
153                };
154
155                if (this.taskExecutor != null) {
156                        return this.taskExecutor.submitListenable(connectTask);
157                }
158                else {
159                        ListenableFutureTask<WebSocketSession> task = new ListenableFutureTask<>(connectTask);
160                        task.run();
161                        return task;
162                }
163        }
164
165        private static List<Extension> adaptExtensions(List<WebSocketExtension> extensions) {
166                List<Extension> result = new ArrayList<>();
167                for (WebSocketExtension extension : extensions) {
168                        result.add(new WebSocketToStandardExtensionAdapter(extension));
169                }
170                return result;
171        }
172
173        private InetAddress getLocalHost() {
174                try {
175                        return InetAddress.getLocalHost();
176                }
177                catch (UnknownHostException ex) {
178                        return InetAddress.getLoopbackAddress();
179                }
180        }
181
182        private int getPort(URI uri) {
183                if (uri.getPort() == -1) {
184                        String scheme = uri.getScheme().toLowerCase(Locale.ENGLISH);
185                        return ("wss".equals(scheme) ? 443 : 80);
186                }
187                return uri.getPort();
188        }
189
190
191        private class StandardWebSocketClientConfigurator extends Configurator {
192
193                private final HttpHeaders headers;
194
195                public StandardWebSocketClientConfigurator(HttpHeaders headers) {
196                        this.headers = headers;
197                }
198
199                @Override
200                public void beforeRequest(Map<String, List<String>> requestHeaders) {
201                        requestHeaders.putAll(this.headers);
202                        if (logger.isTraceEnabled()) {
203                                logger.trace("Handshake request headers: " + requestHeaders);
204                        }
205                }
206                @Override
207                public void afterResponse(HandshakeResponse response) {
208                        if (logger.isTraceEnabled()) {
209                                logger.trace("Handshake response headers: " + response.getHeaders());
210                        }
211                }
212        }
213
214}