001/* 002 * Copyright 2002-2020 the original author or authors. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * https://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.springframework.web.socket.sockjs.client; 018 019import java.io.ByteArrayOutputStream; 020import java.io.IOException; 021import java.net.URI; 022import java.nio.ByteBuffer; 023import java.util.List; 024import java.util.concurrent.CopyOnWriteArrayList; 025import java.util.concurrent.CountDownLatch; 026 027import io.undertow.client.ClientCallback; 028import io.undertow.client.ClientConnection; 029import io.undertow.client.ClientExchange; 030import io.undertow.client.ClientRequest; 031import io.undertow.client.ClientResponse; 032import io.undertow.client.UndertowClient; 033import io.undertow.connector.ByteBufferPool; 034import io.undertow.connector.PooledByteBuffer; 035import io.undertow.server.DefaultByteBufferPool; 036import io.undertow.util.AttachmentKey; 037import io.undertow.util.HeaderMap; 038import io.undertow.util.HttpString; 039import io.undertow.util.Methods; 040import io.undertow.util.StringReadChannelListener; 041import org.xnio.ChannelListener; 042import org.xnio.ChannelListeners; 043import org.xnio.IoUtils; 044import org.xnio.OptionMap; 045import org.xnio.Options; 046import org.xnio.Xnio; 047import org.xnio.XnioWorker; 048import org.xnio.channels.StreamSinkChannel; 049import org.xnio.channels.StreamSourceChannel; 050 051import org.springframework.http.HttpHeaders; 052import org.springframework.http.HttpStatus; 053import org.springframework.http.ResponseEntity; 054import org.springframework.lang.Nullable; 055import org.springframework.util.Assert; 056import org.springframework.util.StreamUtils; 057import org.springframework.util.StringUtils; 058import org.springframework.util.concurrent.SettableListenableFuture; 059import org.springframework.web.client.HttpServerErrorException; 060import org.springframework.web.socket.CloseStatus; 061import org.springframework.web.socket.TextMessage; 062import org.springframework.web.socket.WebSocketHandler; 063import org.springframework.web.socket.WebSocketSession; 064import org.springframework.web.socket.sockjs.SockJsException; 065import org.springframework.web.socket.sockjs.SockJsTransportFailureException; 066import org.springframework.web.socket.sockjs.frame.SockJsFrame; 067 068/** 069 * An XHR transport based on Undertow's {@link io.undertow.client.UndertowClient}. 070 * Requires Undertow 1.3 or 1.4, including XNIO, as of Spring Framework 5.0. 071 * 072 * <p>When used for testing purposes (e.g. load testing) or for specific use cases 073 * (like HTTPS configuration), a custom OptionMap should be provided: 074 * 075 * <pre class="code"> 076 * OptionMap optionMap = OptionMap.builder() 077 * .set(Options.WORKER_IO_THREADS, 8) 078 * .set(Options.TCP_NODELAY, true) 079 * .set(Options.KEEP_ALIVE, true) 080 * .set(Options.WORKER_NAME, "SockJSClient") 081 * .getMap(); 082 * 083 * UndertowXhrTransport transport = new UndertowXhrTransport(optionMap); 084 * </pre> 085 * 086 * @author Brian Clozel 087 * @author Rossen Stoyanchev 088 * @since 4.1.2 089 * @see org.xnio.Options 090 */ 091public class UndertowXhrTransport extends AbstractXhrTransport { 092 093 private static final AttachmentKey<String> RESPONSE_BODY = AttachmentKey.create(String.class); 094 095 096 private final OptionMap optionMap; 097 098 private final UndertowClient httpClient; 099 100 private final XnioWorker worker; 101 102 private final ByteBufferPool bufferPool; 103 104 105 public UndertowXhrTransport() throws IOException { 106 this(OptionMap.builder().parse(Options.WORKER_NAME, "SockJSClient").getMap()); 107 } 108 109 public UndertowXhrTransport(OptionMap optionMap) throws IOException { 110 Assert.notNull(optionMap, "OptionMap is required"); 111 this.optionMap = optionMap; 112 this.httpClient = UndertowClient.getInstance(); 113 this.worker = Xnio.getInstance().createWorker(optionMap); 114 this.bufferPool = new DefaultByteBufferPool(false, 1024, -1, 2); 115 } 116 117 118 /** 119 * Return Undertow's native HTTP client. 120 */ 121 public UndertowClient getHttpClient() { 122 return this.httpClient; 123 } 124 125 /** 126 * Return the {@link org.xnio.XnioWorker} backing the I/O operations 127 * for Undertow's HTTP client. 128 * @see org.xnio.Xnio 129 */ 130 public XnioWorker getWorker() { 131 return this.worker; 132 } 133 134 135 @Override 136 protected void connectInternal(TransportRequest request, WebSocketHandler handler, URI receiveUrl, 137 HttpHeaders handshakeHeaders, XhrClientSockJsSession session, 138 SettableListenableFuture<WebSocketSession> connectFuture) { 139 140 executeReceiveRequest(request, receiveUrl, handshakeHeaders, session, connectFuture); 141 } 142 143 private void executeReceiveRequest(final TransportRequest transportRequest, 144 final URI url, final HttpHeaders headers, final XhrClientSockJsSession session, 145 final SettableListenableFuture<WebSocketSession> connectFuture) { 146 147 if (logger.isTraceEnabled()) { 148 logger.trace("Starting XHR receive request for " + url); 149 } 150 151 ClientCallback<ClientConnection> clientCallback = new ClientCallback<ClientConnection>() { 152 @Override 153 public void completed(ClientConnection connection) { 154 ClientRequest request = new ClientRequest().setMethod(Methods.POST).setPath(url.getPath()); 155 HttpString headerName = HttpString.tryFromString(HttpHeaders.HOST); 156 request.getRequestHeaders().add(headerName, url.getHost()); 157 addHttpHeaders(request, headers); 158 HttpHeaders httpHeaders = transportRequest.getHttpRequestHeaders(); 159 connection.sendRequest(request, createReceiveCallback(transportRequest, 160 url, httpHeaders, session, connectFuture)); 161 } 162 163 @Override 164 public void failed(IOException ex) { 165 throw new SockJsTransportFailureException("Failed to execute request to " + url, ex); 166 } 167 }; 168 169 this.httpClient.connect(clientCallback, url, this.worker, this.bufferPool, this.optionMap); 170 } 171 172 private static void addHttpHeaders(ClientRequest request, HttpHeaders headers) { 173 HeaderMap headerMap = request.getRequestHeaders(); 174 headers.forEach((key, values) -> { 175 for (String value : values) { 176 headerMap.add(HttpString.tryFromString(key), value); 177 } 178 }); 179 } 180 181 private ClientCallback<ClientExchange> createReceiveCallback(final TransportRequest transportRequest, 182 final URI url, final HttpHeaders headers, final XhrClientSockJsSession sockJsSession, 183 final SettableListenableFuture<WebSocketSession> connectFuture) { 184 185 return new ClientCallback<ClientExchange>() { 186 @Override 187 public void completed(final ClientExchange exchange) { 188 exchange.setResponseListener(new ClientCallback<ClientExchange>() { 189 @Override 190 public void completed(ClientExchange result) { 191 ClientResponse response = result.getResponse(); 192 if (response.getResponseCode() != 200) { 193 HttpStatus status = HttpStatus.valueOf(response.getResponseCode()); 194 IoUtils.safeClose(result.getConnection()); 195 onFailure(new HttpServerErrorException(status, "Unexpected XHR receive status")); 196 } 197 else { 198 SockJsResponseListener listener = new SockJsResponseListener( 199 transportRequest, result.getConnection(), url, headers, 200 sockJsSession, connectFuture); 201 listener.setup(result.getResponseChannel()); 202 } 203 if (logger.isTraceEnabled()) { 204 logger.trace("XHR receive headers: " + toHttpHeaders(response.getResponseHeaders())); 205 } 206 try { 207 StreamSinkChannel channel = result.getRequestChannel(); 208 channel.shutdownWrites(); 209 if (!channel.flush()) { 210 channel.getWriteSetter().set(ChannelListeners.<StreamSinkChannel>flushingChannelListener(null, null)); 211 channel.resumeWrites(); 212 } 213 } 214 catch (IOException exc) { 215 IoUtils.safeClose(result.getConnection()); 216 onFailure(exc); 217 } 218 } 219 220 @Override 221 public void failed(IOException exc) { 222 IoUtils.safeClose(exchange.getConnection()); 223 onFailure(exc); 224 } 225 }); 226 } 227 228 @Override 229 public void failed(IOException exc) { 230 onFailure(exc); 231 } 232 233 private void onFailure(Throwable failure) { 234 if (connectFuture.setException(failure)) { 235 return; 236 } 237 if (sockJsSession.isDisconnected()) { 238 sockJsSession.afterTransportClosed(null); 239 } 240 else { 241 sockJsSession.handleTransportError(failure); 242 sockJsSession.afterTransportClosed(new CloseStatus(1006, failure.getMessage())); 243 } 244 } 245 }; 246 } 247 248 private static HttpHeaders toHttpHeaders(HeaderMap headerMap) { 249 HttpHeaders httpHeaders = new HttpHeaders(); 250 for (HttpString name : headerMap.getHeaderNames()) { 251 for (String value : headerMap.get(name)) { 252 httpHeaders.add(name.toString(), value); 253 } 254 } 255 return httpHeaders; 256 } 257 258 @Override 259 protected ResponseEntity<String> executeInfoRequestInternal(URI infoUrl, HttpHeaders headers) { 260 return executeRequest(infoUrl, Methods.GET, headers, null); 261 } 262 263 @Override 264 protected ResponseEntity<String> executeSendRequestInternal(URI url, HttpHeaders headers, TextMessage message) { 265 return executeRequest(url, Methods.POST, headers, message.getPayload()); 266 } 267 268 protected ResponseEntity<String> executeRequest( 269 URI url, HttpString method, HttpHeaders headers, @Nullable String body) { 270 271 CountDownLatch latch = new CountDownLatch(1); 272 List<ClientResponse> responses = new CopyOnWriteArrayList<>(); 273 274 try { 275 ClientConnection connection = 276 this.httpClient.connect(url, this.worker, this.bufferPool, this.optionMap).get(); 277 try { 278 ClientRequest request = new ClientRequest().setMethod(method).setPath(url.getPath()); 279 request.getRequestHeaders().add(HttpString.tryFromString(HttpHeaders.HOST), url.getHost()); 280 if (StringUtils.hasLength(body)) { 281 HttpString headerName = HttpString.tryFromString(HttpHeaders.CONTENT_LENGTH); 282 request.getRequestHeaders().add(headerName, body.length()); 283 } 284 addHttpHeaders(request, headers); 285 connection.sendRequest(request, createRequestCallback(body, responses, latch)); 286 287 latch.await(); 288 ClientResponse response = responses.iterator().next(); 289 HttpStatus status = HttpStatus.valueOf(response.getResponseCode()); 290 HttpHeaders responseHeaders = toHttpHeaders(response.getResponseHeaders()); 291 String responseBody = response.getAttachment(RESPONSE_BODY); 292 return (responseBody != null ? 293 new ResponseEntity<>(responseBody, responseHeaders, status) : 294 new ResponseEntity<>(responseHeaders, status)); 295 } 296 finally { 297 IoUtils.safeClose(connection); 298 } 299 } 300 catch (IOException ex) { 301 throw new SockJsTransportFailureException("Failed to execute request to " + url, ex); 302 } 303 catch (InterruptedException ex) { 304 Thread.currentThread().interrupt(); 305 throw new SockJsTransportFailureException("Interrupted while processing request to " + url, ex); 306 } 307 } 308 309 private ClientCallback<ClientExchange> createRequestCallback(final @Nullable String body, 310 final List<ClientResponse> responses, final CountDownLatch latch) { 311 312 return new ClientCallback<ClientExchange>() { 313 @Override 314 public void completed(ClientExchange result) { 315 result.setResponseListener(new ClientCallback<ClientExchange>() { 316 @Override 317 public void completed(final ClientExchange result) { 318 responses.add(result.getResponse()); 319 new StringReadChannelListener(result.getConnection().getBufferPool()) { 320 @Override 321 protected void stringDone(String string) { 322 result.getResponse().putAttachment(RESPONSE_BODY, string); 323 latch.countDown(); 324 } 325 @Override 326 protected void error(IOException ex) { 327 onFailure(latch, ex); 328 } 329 }.setup(result.getResponseChannel()); 330 } 331 @Override 332 public void failed(IOException ex) { 333 onFailure(latch, ex); 334 } 335 }); 336 try { 337 if (body != null) { 338 result.getRequestChannel().write(ByteBuffer.wrap(body.getBytes())); 339 } 340 result.getRequestChannel().shutdownWrites(); 341 if (!result.getRequestChannel().flush()) { 342 result.getRequestChannel().getWriteSetter() 343 .set(ChannelListeners.<StreamSinkChannel>flushingChannelListener(null, null)); 344 result.getRequestChannel().resumeWrites(); 345 } 346 } 347 catch (IOException ex) { 348 onFailure(latch, ex); 349 } 350 } 351 352 @Override 353 public void failed(IOException ex) { 354 onFailure(latch, ex); 355 } 356 357 private void onFailure(CountDownLatch latch, IOException ex) { 358 latch.countDown(); 359 throw new SockJsTransportFailureException("Failed to execute request", ex); 360 } 361 }; 362 } 363 364 365 private class SockJsResponseListener implements ChannelListener<StreamSourceChannel> { 366 367 private final TransportRequest request; 368 369 private final ClientConnection connection; 370 371 private final URI url; 372 373 private final HttpHeaders headers; 374 375 private final XhrClientSockJsSession session; 376 377 private final SettableListenableFuture<WebSocketSession> connectFuture; 378 379 private final ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); 380 381 public SockJsResponseListener(TransportRequest request, ClientConnection connection, URI url, 382 HttpHeaders headers, XhrClientSockJsSession sockJsSession, 383 SettableListenableFuture<WebSocketSession> connectFuture) { 384 385 this.request = request; 386 this.connection = connection; 387 this.url = url; 388 this.headers = headers; 389 this.session = sockJsSession; 390 this.connectFuture = connectFuture; 391 } 392 393 public void setup(StreamSourceChannel channel) { 394 channel.suspendReads(); 395 channel.getReadSetter().set(this); 396 channel.resumeReads(); 397 } 398 399 @Override 400 public void handleEvent(StreamSourceChannel channel) { 401 if (this.session.isDisconnected()) { 402 if (logger.isDebugEnabled()) { 403 logger.debug("SockJS sockJsSession closed, closing response."); 404 } 405 IoUtils.safeClose(this.connection); 406 throw new SockJsException("Session closed.", this.session.getId(), null); 407 } 408 409 try (PooledByteBuffer pooled = bufferPool.allocate()) { 410 int r; 411 do { 412 ByteBuffer buffer = pooled.getBuffer(); 413 buffer.clear(); 414 r = channel.read(buffer); 415 buffer.flip(); 416 if (r == 0) { 417 return; 418 } 419 else if (r == -1) { 420 onSuccess(); 421 } 422 else { 423 while (buffer.hasRemaining()) { 424 int b = buffer.get(); 425 if (b == '\n') { 426 handleFrame(); 427 } 428 else { 429 this.outputStream.write(b); 430 } 431 } 432 } 433 } 434 while (r > 0); 435 } 436 catch (IOException exc) { 437 onFailure(exc); 438 } 439 } 440 441 private void handleFrame() { 442 String content = StreamUtils.copyToString(this.outputStream, SockJsFrame.CHARSET); 443 this.outputStream.reset(); 444 if (logger.isTraceEnabled()) { 445 logger.trace("XHR content received: " + content); 446 } 447 if (!PRELUDE.equals(content)) { 448 this.session.handleFrame(content); 449 } 450 } 451 452 public void onSuccess() { 453 if (this.outputStream.size() > 0) { 454 handleFrame(); 455 } 456 if (logger.isTraceEnabled()) { 457 logger.trace("XHR receive request completed."); 458 } 459 IoUtils.safeClose(this.connection); 460 executeReceiveRequest(this.request, this.url, this.headers, this.session, this.connectFuture); 461 } 462 463 public void onFailure(Throwable failure) { 464 IoUtils.safeClose(this.connection); 465 if (this.connectFuture.setException(failure)) { 466 return; 467 } 468 if (this.session.isDisconnected()) { 469 this.session.afterTransportClosed(null); 470 } 471 else { 472 this.session.handleTransportError(failure); 473 this.session.afterTransportClosed(new CloseStatus(1006, failure.getMessage())); 474 } 475 } 476 } 477 478}