001/*
002 * Copyright 2002-2019 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.web.reactive.socket.client;
018
019import java.io.IOException;
020import java.net.URI;
021import java.util.Collections;
022import java.util.List;
023import java.util.Map;
024import java.util.function.Consumer;
025
026import io.undertow.connector.ByteBufferPool;
027import io.undertow.server.DefaultByteBufferPool;
028import io.undertow.websockets.client.WebSocketClient.ConnectionBuilder;
029import io.undertow.websockets.client.WebSocketClientNegotiation;
030import io.undertow.websockets.core.WebSocketChannel;
031import org.apache.commons.logging.Log;
032import org.apache.commons.logging.LogFactory;
033import org.xnio.IoFuture;
034import org.xnio.XnioWorker;
035import reactor.core.publisher.Mono;
036import reactor.core.publisher.MonoProcessor;
037
038import org.springframework.core.io.buffer.DataBufferFactory;
039import org.springframework.core.io.buffer.DefaultDataBufferFactory;
040import org.springframework.http.HttpHeaders;
041import org.springframework.lang.Nullable;
042import org.springframework.util.Assert;
043import org.springframework.web.reactive.socket.HandshakeInfo;
044import org.springframework.web.reactive.socket.WebSocketHandler;
045import org.springframework.web.reactive.socket.adapter.UndertowWebSocketHandlerAdapter;
046import org.springframework.web.reactive.socket.adapter.UndertowWebSocketSession;
047
048/**
049 * Undertow based implementation of {@link WebSocketClient}.
050 *
051 * @author Violeta Georgieva
052 * @author Rossen Stoyanchev
053 * @since 5.0
054 */
055public class UndertowWebSocketClient implements WebSocketClient {
056
057        private static final Log logger = LogFactory.getLog(UndertowWebSocketClient.class);
058
059        private static final int DEFAULT_POOL_BUFFER_SIZE = 8192;
060
061
062        private final XnioWorker worker;
063
064        private ByteBufferPool byteBufferPool;
065
066        private final Consumer<ConnectionBuilder> builderConsumer;
067
068        private final DataBufferFactory bufferFactory = new DefaultDataBufferFactory();
069
070
071        /**
072         * Constructor with the {@link XnioWorker} to pass to
073         * {@link io.undertow.websockets.client.WebSocketClient#connectionBuilder}.
074         * @param worker the Xnio worker
075         */
076        public UndertowWebSocketClient(XnioWorker worker) {
077                this(worker, builder -> {});
078        }
079
080        /**
081         * Alternate constructor providing additional control over the
082         * {@link ConnectionBuilder} for each WebSocket connection.
083         * @param worker the Xnio worker to use to create {@code ConnectionBuilder}'s
084         * @param builderConsumer a consumer to configure {@code ConnectionBuilder}'s
085         */
086        public UndertowWebSocketClient(XnioWorker worker, Consumer<ConnectionBuilder> builderConsumer) {
087                this(worker, new DefaultByteBufferPool(false, DEFAULT_POOL_BUFFER_SIZE), builderConsumer);
088        }
089
090        /**
091         * Alternate constructor providing additional control over the
092         * {@link ConnectionBuilder} for each WebSocket connection.
093         * @param worker the Xnio worker to use to create {@code ConnectionBuilder}'s
094         * @param byteBufferPool the ByteBufferPool to use to create {@code ConnectionBuilder}'s
095         * @param builderConsumer a consumer to configure {@code ConnectionBuilder}'s
096         * @since 5.0.8
097         */
098        public UndertowWebSocketClient(XnioWorker worker, ByteBufferPool byteBufferPool,
099                        Consumer<ConnectionBuilder> builderConsumer) {
100
101                Assert.notNull(worker, "XnioWorker must not be null");
102                Assert.notNull(byteBufferPool, "ByteBufferPool must not be null");
103                this.worker = worker;
104                this.byteBufferPool = byteBufferPool;
105                this.builderConsumer = builderConsumer;
106        }
107
108
109        /**
110         * Return the configured {@link XnioWorker}.
111         */
112        public XnioWorker getXnioWorker() {
113                return this.worker;
114        }
115
116        /**
117         * Set the {@link io.undertow.connector.ByteBufferPool ByteBufferPool} to pass to
118         * {@link io.undertow.websockets.client.WebSocketClient#connectionBuilder}.
119         * <p>By default an indirect {@link io.undertow.server.DefaultByteBufferPool}
120         * with a buffer size of 8192 is used.
121         * @since 5.0.8
122         * @see #DEFAULT_POOL_BUFFER_SIZE
123         */
124        public void setByteBufferPool(ByteBufferPool byteBufferPool) {
125                Assert.notNull(byteBufferPool, "ByteBufferPool must not be null");
126                this.byteBufferPool = byteBufferPool;
127        }
128
129        /**
130         * Return the {@link io.undertow.connector.ByteBufferPool} currently used
131         * for newly created WebSocket sessions by this client.
132         * @return the byte buffer pool
133         * @since 5.0.8
134         */
135        public ByteBufferPool getByteBufferPool() {
136                return this.byteBufferPool;
137        }
138
139        /**
140         * Return the configured <code>Consumer&lt;ConnectionBuilder&gt;</code>.
141         */
142        public Consumer<ConnectionBuilder> getConnectionBuilderConsumer() {
143                return this.builderConsumer;
144        }
145
146
147        @Override
148        public Mono<Void> execute(URI url, WebSocketHandler handler) {
149                return execute(url, new HttpHeaders(), handler);
150        }
151
152        @Override
153        public Mono<Void> execute(URI url, HttpHeaders headers, WebSocketHandler handler) {
154                return executeInternal(url, headers, handler);
155        }
156
157        private Mono<Void> executeInternal(URI url, HttpHeaders headers, WebSocketHandler handler) {
158                MonoProcessor<Void> completion = MonoProcessor.create();
159                return Mono.fromCallable(
160                                () -> {
161                                        if (logger.isDebugEnabled()) {
162                                                logger.debug("Connecting to " + url);
163                                        }
164                                        List<String> protocols = handler.getSubProtocols();
165                                        ConnectionBuilder builder = createConnectionBuilder(url);
166                                        DefaultNegotiation negotiation = new DefaultNegotiation(protocols, headers, builder);
167                                        builder.setClientNegotiation(negotiation);
168                                        return builder.connect().addNotifier(
169                                                        new IoFuture.HandlingNotifier<WebSocketChannel, Object>() {
170                                                                @Override
171                                                                public void handleDone(WebSocketChannel channel, Object attachment) {
172                                                                        handleChannel(url, handler, completion, negotiation, channel);
173                                                                }
174                                                                @Override
175                                                                public void handleFailed(IOException ex, Object attachment) {
176                                                                        completion.onError(new IllegalStateException("Failed to connect to " + url, ex));
177                                                                }
178                                                        }, null);
179                                })
180                                .then(completion);
181        }
182
183        /**
184         * Create a {@link ConnectionBuilder} for the given URI.
185         * <p>The default implementation creates a builder with the configured
186         * {@link #getXnioWorker() XnioWorker} and {@link #getByteBufferPool() ByteBufferPool} and
187         * then passes it to the {@link #getConnectionBuilderConsumer() consumer}
188         * provided at construction time.
189         */
190        protected ConnectionBuilder createConnectionBuilder(URI url) {
191                ConnectionBuilder builder = io.undertow.websockets.client.WebSocketClient
192                                .connectionBuilder(getXnioWorker(), getByteBufferPool(), url);
193                this.builderConsumer.accept(builder);
194                return builder;
195        }
196
197        private void handleChannel(URI url, WebSocketHandler handler, MonoProcessor<Void> completion,
198                        DefaultNegotiation negotiation, WebSocketChannel channel) {
199
200                HandshakeInfo info = createHandshakeInfo(url, negotiation);
201                UndertowWebSocketSession session = new UndertowWebSocketSession(channel, info, this.bufferFactory, completion);
202                UndertowWebSocketHandlerAdapter adapter = new UndertowWebSocketHandlerAdapter(session);
203
204                channel.getReceiveSetter().set(adapter);
205                channel.resumeReceives();
206
207                handler.handle(session)
208                                .checkpoint(url + " [UndertowWebSocketClient]")
209                                .subscribe(session);
210        }
211
212        private HandshakeInfo createHandshakeInfo(URI url, DefaultNegotiation negotiation) {
213                HttpHeaders responseHeaders = negotiation.getResponseHeaders();
214                String protocol = responseHeaders.getFirst("Sec-WebSocket-Protocol");
215                return new HandshakeInfo(url, responseHeaders, Mono.empty(), protocol);
216        }
217
218
219        private static final class DefaultNegotiation extends WebSocketClientNegotiation {
220
221                private final HttpHeaders requestHeaders;
222
223                private final HttpHeaders responseHeaders = new HttpHeaders();
224
225                @Nullable
226                private final WebSocketClientNegotiation delegate;
227
228                public DefaultNegotiation(List<String> protocols, HttpHeaders requestHeaders,
229                                ConnectionBuilder connectionBuilder) {
230
231                        super(protocols, Collections.emptyList());
232                        this.requestHeaders = requestHeaders;
233                        this.delegate = connectionBuilder.getClientNegotiation();
234                }
235
236                public HttpHeaders getResponseHeaders() {
237                        return this.responseHeaders;
238                }
239
240                @Override
241                public void beforeRequest(Map<String, List<String>> headers) {
242                        this.requestHeaders.forEach(headers::put);
243                        if (this.delegate != null) {
244                                this.delegate.beforeRequest(headers);
245                        }
246                }
247
248                @Override
249                public void afterRequest(Map<String, List<String>> headers) {
250                        headers.forEach(this.responseHeaders::put);
251                        if (this.delegate != null) {
252                                this.delegate.afterRequest(headers);
253                        }
254                }
255        }
256
257}