001/* 002 * Copyright 2002-2019 the original author or authors. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * https://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.springframework.web.reactive.socket.client; 018 019import java.io.IOException; 020import java.net.URI; 021import java.util.Collections; 022import java.util.List; 023import java.util.Map; 024import java.util.function.Consumer; 025 026import io.undertow.connector.ByteBufferPool; 027import io.undertow.server.DefaultByteBufferPool; 028import io.undertow.websockets.client.WebSocketClient.ConnectionBuilder; 029import io.undertow.websockets.client.WebSocketClientNegotiation; 030import io.undertow.websockets.core.WebSocketChannel; 031import org.apache.commons.logging.Log; 032import org.apache.commons.logging.LogFactory; 033import org.xnio.IoFuture; 034import org.xnio.XnioWorker; 035import reactor.core.publisher.Mono; 036import reactor.core.publisher.MonoProcessor; 037 038import org.springframework.core.io.buffer.DataBufferFactory; 039import org.springframework.core.io.buffer.DefaultDataBufferFactory; 040import org.springframework.http.HttpHeaders; 041import org.springframework.lang.Nullable; 042import org.springframework.util.Assert; 043import org.springframework.web.reactive.socket.HandshakeInfo; 044import org.springframework.web.reactive.socket.WebSocketHandler; 045import org.springframework.web.reactive.socket.adapter.UndertowWebSocketHandlerAdapter; 046import org.springframework.web.reactive.socket.adapter.UndertowWebSocketSession; 047 048/** 049 * Undertow based implementation of {@link WebSocketClient}. 050 * 051 * @author Violeta Georgieva 052 * @author Rossen Stoyanchev 053 * @since 5.0 054 */ 055public class UndertowWebSocketClient implements WebSocketClient { 056 057 private static final Log logger = LogFactory.getLog(UndertowWebSocketClient.class); 058 059 private static final int DEFAULT_POOL_BUFFER_SIZE = 8192; 060 061 062 private final XnioWorker worker; 063 064 private ByteBufferPool byteBufferPool; 065 066 private final Consumer<ConnectionBuilder> builderConsumer; 067 068 private final DataBufferFactory bufferFactory = new DefaultDataBufferFactory(); 069 070 071 /** 072 * Constructor with the {@link XnioWorker} to pass to 073 * {@link io.undertow.websockets.client.WebSocketClient#connectionBuilder}. 074 * @param worker the Xnio worker 075 */ 076 public UndertowWebSocketClient(XnioWorker worker) { 077 this(worker, builder -> {}); 078 } 079 080 /** 081 * Alternate constructor providing additional control over the 082 * {@link ConnectionBuilder} for each WebSocket connection. 083 * @param worker the Xnio worker to use to create {@code ConnectionBuilder}'s 084 * @param builderConsumer a consumer to configure {@code ConnectionBuilder}'s 085 */ 086 public UndertowWebSocketClient(XnioWorker worker, Consumer<ConnectionBuilder> builderConsumer) { 087 this(worker, new DefaultByteBufferPool(false, DEFAULT_POOL_BUFFER_SIZE), builderConsumer); 088 } 089 090 /** 091 * Alternate constructor providing additional control over the 092 * {@link ConnectionBuilder} for each WebSocket connection. 093 * @param worker the Xnio worker to use to create {@code ConnectionBuilder}'s 094 * @param byteBufferPool the ByteBufferPool to use to create {@code ConnectionBuilder}'s 095 * @param builderConsumer a consumer to configure {@code ConnectionBuilder}'s 096 * @since 5.0.8 097 */ 098 public UndertowWebSocketClient(XnioWorker worker, ByteBufferPool byteBufferPool, 099 Consumer<ConnectionBuilder> builderConsumer) { 100 101 Assert.notNull(worker, "XnioWorker must not be null"); 102 Assert.notNull(byteBufferPool, "ByteBufferPool must not be null"); 103 this.worker = worker; 104 this.byteBufferPool = byteBufferPool; 105 this.builderConsumer = builderConsumer; 106 } 107 108 109 /** 110 * Return the configured {@link XnioWorker}. 111 */ 112 public XnioWorker getXnioWorker() { 113 return this.worker; 114 } 115 116 /** 117 * Set the {@link io.undertow.connector.ByteBufferPool ByteBufferPool} to pass to 118 * {@link io.undertow.websockets.client.WebSocketClient#connectionBuilder}. 119 * <p>By default an indirect {@link io.undertow.server.DefaultByteBufferPool} 120 * with a buffer size of 8192 is used. 121 * @since 5.0.8 122 * @see #DEFAULT_POOL_BUFFER_SIZE 123 */ 124 public void setByteBufferPool(ByteBufferPool byteBufferPool) { 125 Assert.notNull(byteBufferPool, "ByteBufferPool must not be null"); 126 this.byteBufferPool = byteBufferPool; 127 } 128 129 /** 130 * Return the {@link io.undertow.connector.ByteBufferPool} currently used 131 * for newly created WebSocket sessions by this client. 132 * @return the byte buffer pool 133 * @since 5.0.8 134 */ 135 public ByteBufferPool getByteBufferPool() { 136 return this.byteBufferPool; 137 } 138 139 /** 140 * Return the configured <code>Consumer<ConnectionBuilder></code>. 141 */ 142 public Consumer<ConnectionBuilder> getConnectionBuilderConsumer() { 143 return this.builderConsumer; 144 } 145 146 147 @Override 148 public Mono<Void> execute(URI url, WebSocketHandler handler) { 149 return execute(url, new HttpHeaders(), handler); 150 } 151 152 @Override 153 public Mono<Void> execute(URI url, HttpHeaders headers, WebSocketHandler handler) { 154 return executeInternal(url, headers, handler); 155 } 156 157 private Mono<Void> executeInternal(URI url, HttpHeaders headers, WebSocketHandler handler) { 158 MonoProcessor<Void> completion = MonoProcessor.create(); 159 return Mono.fromCallable( 160 () -> { 161 if (logger.isDebugEnabled()) { 162 logger.debug("Connecting to " + url); 163 } 164 List<String> protocols = handler.getSubProtocols(); 165 ConnectionBuilder builder = createConnectionBuilder(url); 166 DefaultNegotiation negotiation = new DefaultNegotiation(protocols, headers, builder); 167 builder.setClientNegotiation(negotiation); 168 return builder.connect().addNotifier( 169 new IoFuture.HandlingNotifier<WebSocketChannel, Object>() { 170 @Override 171 public void handleDone(WebSocketChannel channel, Object attachment) { 172 handleChannel(url, handler, completion, negotiation, channel); 173 } 174 @Override 175 public void handleFailed(IOException ex, Object attachment) { 176 completion.onError(new IllegalStateException("Failed to connect to " + url, ex)); 177 } 178 }, null); 179 }) 180 .then(completion); 181 } 182 183 /** 184 * Create a {@link ConnectionBuilder} for the given URI. 185 * <p>The default implementation creates a builder with the configured 186 * {@link #getXnioWorker() XnioWorker} and {@link #getByteBufferPool() ByteBufferPool} and 187 * then passes it to the {@link #getConnectionBuilderConsumer() consumer} 188 * provided at construction time. 189 */ 190 protected ConnectionBuilder createConnectionBuilder(URI url) { 191 ConnectionBuilder builder = io.undertow.websockets.client.WebSocketClient 192 .connectionBuilder(getXnioWorker(), getByteBufferPool(), url); 193 this.builderConsumer.accept(builder); 194 return builder; 195 } 196 197 private void handleChannel(URI url, WebSocketHandler handler, MonoProcessor<Void> completion, 198 DefaultNegotiation negotiation, WebSocketChannel channel) { 199 200 HandshakeInfo info = createHandshakeInfo(url, negotiation); 201 UndertowWebSocketSession session = new UndertowWebSocketSession(channel, info, this.bufferFactory, completion); 202 UndertowWebSocketHandlerAdapter adapter = new UndertowWebSocketHandlerAdapter(session); 203 204 channel.getReceiveSetter().set(adapter); 205 channel.resumeReceives(); 206 207 handler.handle(session) 208 .checkpoint(url + " [UndertowWebSocketClient]") 209 .subscribe(session); 210 } 211 212 private HandshakeInfo createHandshakeInfo(URI url, DefaultNegotiation negotiation) { 213 HttpHeaders responseHeaders = negotiation.getResponseHeaders(); 214 String protocol = responseHeaders.getFirst("Sec-WebSocket-Protocol"); 215 return new HandshakeInfo(url, responseHeaders, Mono.empty(), protocol); 216 } 217 218 219 private static final class DefaultNegotiation extends WebSocketClientNegotiation { 220 221 private final HttpHeaders requestHeaders; 222 223 private final HttpHeaders responseHeaders = new HttpHeaders(); 224 225 @Nullable 226 private final WebSocketClientNegotiation delegate; 227 228 public DefaultNegotiation(List<String> protocols, HttpHeaders requestHeaders, 229 ConnectionBuilder connectionBuilder) { 230 231 super(protocols, Collections.emptyList()); 232 this.requestHeaders = requestHeaders; 233 this.delegate = connectionBuilder.getClientNegotiation(); 234 } 235 236 public HttpHeaders getResponseHeaders() { 237 return this.responseHeaders; 238 } 239 240 @Override 241 public void beforeRequest(Map<String, List<String>> headers) { 242 this.requestHeaders.forEach(headers::put); 243 if (this.delegate != null) { 244 this.delegate.beforeRequest(headers); 245 } 246 } 247 248 @Override 249 public void afterRequest(Map<String, List<String>> headers) { 250 headers.forEach(this.responseHeaders::put); 251 if (this.delegate != null) { 252 this.delegate.afterRequest(headers); 253 } 254 } 255 } 256 257}