001/*
002 * Copyright 2002-2020 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.web.socket.sockjs.client;
018
019import java.io.ByteArrayOutputStream;
020import java.io.IOException;
021import java.net.URI;
022import java.nio.ByteBuffer;
023import java.util.List;
024import java.util.concurrent.CopyOnWriteArrayList;
025import java.util.concurrent.CountDownLatch;
026
027import io.undertow.client.ClientCallback;
028import io.undertow.client.ClientConnection;
029import io.undertow.client.ClientExchange;
030import io.undertow.client.ClientRequest;
031import io.undertow.client.ClientResponse;
032import io.undertow.client.UndertowClient;
033import io.undertow.connector.ByteBufferPool;
034import io.undertow.connector.PooledByteBuffer;
035import io.undertow.server.DefaultByteBufferPool;
036import io.undertow.util.AttachmentKey;
037import io.undertow.util.HeaderMap;
038import io.undertow.util.HttpString;
039import io.undertow.util.Methods;
040import io.undertow.util.StringReadChannelListener;
041import org.xnio.ChannelListener;
042import org.xnio.ChannelListeners;
043import org.xnio.IoUtils;
044import org.xnio.OptionMap;
045import org.xnio.Options;
046import org.xnio.Xnio;
047import org.xnio.XnioWorker;
048import org.xnio.channels.StreamSinkChannel;
049import org.xnio.channels.StreamSourceChannel;
050
051import org.springframework.http.HttpHeaders;
052import org.springframework.http.HttpStatus;
053import org.springframework.http.ResponseEntity;
054import org.springframework.lang.Nullable;
055import org.springframework.util.Assert;
056import org.springframework.util.StreamUtils;
057import org.springframework.util.StringUtils;
058import org.springframework.util.concurrent.SettableListenableFuture;
059import org.springframework.web.client.HttpServerErrorException;
060import org.springframework.web.socket.CloseStatus;
061import org.springframework.web.socket.TextMessage;
062import org.springframework.web.socket.WebSocketHandler;
063import org.springframework.web.socket.WebSocketSession;
064import org.springframework.web.socket.sockjs.SockJsException;
065import org.springframework.web.socket.sockjs.SockJsTransportFailureException;
066import org.springframework.web.socket.sockjs.frame.SockJsFrame;
067
068/**
069 * An XHR transport based on Undertow's {@link io.undertow.client.UndertowClient}.
070 * Requires Undertow 1.3 or 1.4, including XNIO, as of Spring Framework 5.0.
071 *
072 * <p>When used for testing purposes (e.g. load testing) or for specific use cases
073 * (like HTTPS configuration), a custom OptionMap should be provided:
074 *
075 * <pre class="code">
076 * OptionMap optionMap = OptionMap.builder()
077 *   .set(Options.WORKER_IO_THREADS, 8)
078 *   .set(Options.TCP_NODELAY, true)
079 *   .set(Options.KEEP_ALIVE, true)
080 *   .set(Options.WORKER_NAME, "SockJSClient")
081 *   .getMap();
082 *
083 * UndertowXhrTransport transport = new UndertowXhrTransport(optionMap);
084 * </pre>
085 *
086 * @author Brian Clozel
087 * @author Rossen Stoyanchev
088 * @since 4.1.2
089 * @see org.xnio.Options
090 */
091public class UndertowXhrTransport extends AbstractXhrTransport {
092
093        private static final AttachmentKey<String> RESPONSE_BODY = AttachmentKey.create(String.class);
094
095
096        private final OptionMap optionMap;
097
098        private final UndertowClient httpClient;
099
100        private final XnioWorker worker;
101
102        private final ByteBufferPool bufferPool;
103
104
105        public UndertowXhrTransport() throws IOException {
106                this(OptionMap.builder().parse(Options.WORKER_NAME, "SockJSClient").getMap());
107        }
108
109        public UndertowXhrTransport(OptionMap optionMap) throws IOException {
110                Assert.notNull(optionMap, "OptionMap is required");
111                this.optionMap = optionMap;
112                this.httpClient = UndertowClient.getInstance();
113                this.worker = Xnio.getInstance().createWorker(optionMap);
114                this.bufferPool = new DefaultByteBufferPool(false, 1024, -1, 2);
115        }
116
117
118        /**
119         * Return Undertow's native HTTP client.
120         */
121        public UndertowClient getHttpClient() {
122                return this.httpClient;
123        }
124
125        /**
126         * Return the {@link org.xnio.XnioWorker} backing the I/O operations
127         * for Undertow's HTTP client.
128         * @see org.xnio.Xnio
129         */
130        public XnioWorker getWorker() {
131                return this.worker;
132        }
133
134
135        @Override
136        protected void connectInternal(TransportRequest request, WebSocketHandler handler, URI receiveUrl,
137                        HttpHeaders handshakeHeaders, XhrClientSockJsSession session,
138                        SettableListenableFuture<WebSocketSession> connectFuture) {
139
140                executeReceiveRequest(request, receiveUrl, handshakeHeaders, session, connectFuture);
141        }
142
143        private void executeReceiveRequest(final TransportRequest transportRequest,
144                        final URI url, final HttpHeaders headers, final XhrClientSockJsSession session,
145                        final SettableListenableFuture<WebSocketSession> connectFuture) {
146
147                if (logger.isTraceEnabled()) {
148                        logger.trace("Starting XHR receive request for " + url);
149                }
150
151                ClientCallback<ClientConnection> clientCallback = new ClientCallback<ClientConnection>() {
152                        @Override
153                        public void completed(ClientConnection connection) {
154                                ClientRequest request = new ClientRequest().setMethod(Methods.POST).setPath(url.getPath());
155                                HttpString headerName = HttpString.tryFromString(HttpHeaders.HOST);
156                                request.getRequestHeaders().add(headerName, url.getHost());
157                                addHttpHeaders(request, headers);
158                                HttpHeaders httpHeaders = transportRequest.getHttpRequestHeaders();
159                                connection.sendRequest(request, createReceiveCallback(transportRequest,
160                                                url, httpHeaders, session, connectFuture));
161                        }
162
163                        @Override
164                        public void failed(IOException ex) {
165                                throw new SockJsTransportFailureException("Failed to execute request to " + url, ex);
166                        }
167                };
168
169                this.httpClient.connect(clientCallback, url, this.worker, this.bufferPool, this.optionMap);
170        }
171
172        private static void addHttpHeaders(ClientRequest request, HttpHeaders headers) {
173                HeaderMap headerMap = request.getRequestHeaders();
174                headers.forEach((key, values) -> {
175                        for (String value : values) {
176                                headerMap.add(HttpString.tryFromString(key), value);
177                        }
178                });
179        }
180
181        private ClientCallback<ClientExchange> createReceiveCallback(final TransportRequest transportRequest,
182                        final URI url, final HttpHeaders headers, final XhrClientSockJsSession sockJsSession,
183                        final SettableListenableFuture<WebSocketSession> connectFuture) {
184
185                return new ClientCallback<ClientExchange>() {
186                        @Override
187                        public void completed(final ClientExchange exchange) {
188                                exchange.setResponseListener(new ClientCallback<ClientExchange>() {
189                                        @Override
190                                        public void completed(ClientExchange result) {
191                                                ClientResponse response = result.getResponse();
192                                                if (response.getResponseCode() != 200) {
193                                                        HttpStatus status = HttpStatus.valueOf(response.getResponseCode());
194                                                        IoUtils.safeClose(result.getConnection());
195                                                        onFailure(new HttpServerErrorException(status, "Unexpected XHR receive status"));
196                                                }
197                                                else {
198                                                        SockJsResponseListener listener = new SockJsResponseListener(
199                                                                        transportRequest, result.getConnection(), url, headers,
200                                                                        sockJsSession, connectFuture);
201                                                        listener.setup(result.getResponseChannel());
202                                                }
203                                                if (logger.isTraceEnabled()) {
204                                                        logger.trace("XHR receive headers: " + toHttpHeaders(response.getResponseHeaders()));
205                                                }
206                                                try {
207                                                        StreamSinkChannel channel = result.getRequestChannel();
208                                                        channel.shutdownWrites();
209                                                        if (!channel.flush()) {
210                                                                channel.getWriteSetter().set(ChannelListeners.<StreamSinkChannel>flushingChannelListener(null, null));
211                                                                channel.resumeWrites();
212                                                        }
213                                                }
214                                                catch (IOException exc) {
215                                                        IoUtils.safeClose(result.getConnection());
216                                                        onFailure(exc);
217                                                }
218                                        }
219
220                                        @Override
221                                        public void failed(IOException exc) {
222                                                IoUtils.safeClose(exchange.getConnection());
223                                                onFailure(exc);
224                                        }
225                                });
226                        }
227
228                        @Override
229                        public void failed(IOException exc) {
230                                onFailure(exc);
231                        }
232
233                        private void onFailure(Throwable failure) {
234                                if (connectFuture.setException(failure)) {
235                                        return;
236                                }
237                                if (sockJsSession.isDisconnected()) {
238                                        sockJsSession.afterTransportClosed(null);
239                                }
240                                else {
241                                        sockJsSession.handleTransportError(failure);
242                                        sockJsSession.afterTransportClosed(new CloseStatus(1006, failure.getMessage()));
243                                }
244                        }
245                };
246        }
247
248        private static HttpHeaders toHttpHeaders(HeaderMap headerMap) {
249                HttpHeaders httpHeaders = new HttpHeaders();
250                for (HttpString name : headerMap.getHeaderNames()) {
251                        for (String value : headerMap.get(name)) {
252                                httpHeaders.add(name.toString(), value);
253                        }
254                }
255                return httpHeaders;
256        }
257
258        @Override
259        protected ResponseEntity<String> executeInfoRequestInternal(URI infoUrl, HttpHeaders headers) {
260                return executeRequest(infoUrl, Methods.GET, headers, null);
261        }
262
263        @Override
264        protected ResponseEntity<String> executeSendRequestInternal(URI url, HttpHeaders headers, TextMessage message) {
265                return executeRequest(url, Methods.POST, headers, message.getPayload());
266        }
267
268        protected ResponseEntity<String> executeRequest(
269                        URI url, HttpString method, HttpHeaders headers, @Nullable String body) {
270
271                CountDownLatch latch = new CountDownLatch(1);
272                List<ClientResponse> responses = new CopyOnWriteArrayList<>();
273
274                try {
275                        ClientConnection connection =
276                                        this.httpClient.connect(url, this.worker, this.bufferPool, this.optionMap).get();
277                        try {
278                                ClientRequest request = new ClientRequest().setMethod(method).setPath(url.getPath());
279                                request.getRequestHeaders().add(HttpString.tryFromString(HttpHeaders.HOST), url.getHost());
280                                if (StringUtils.hasLength(body)) {
281                                        HttpString headerName = HttpString.tryFromString(HttpHeaders.CONTENT_LENGTH);
282                                        request.getRequestHeaders().add(headerName, body.length());
283                                }
284                                addHttpHeaders(request, headers);
285                                connection.sendRequest(request, createRequestCallback(body, responses, latch));
286
287                                latch.await();
288                                ClientResponse response = responses.iterator().next();
289                                HttpStatus status = HttpStatus.valueOf(response.getResponseCode());
290                                HttpHeaders responseHeaders = toHttpHeaders(response.getResponseHeaders());
291                                String responseBody = response.getAttachment(RESPONSE_BODY);
292                                return (responseBody != null ?
293                                                new ResponseEntity<>(responseBody, responseHeaders, status) :
294                                                new ResponseEntity<>(responseHeaders, status));
295                        }
296                        finally {
297                                IoUtils.safeClose(connection);
298                        }
299                }
300                catch (IOException ex) {
301                        throw new SockJsTransportFailureException("Failed to execute request to " + url, ex);
302                }
303                catch (InterruptedException ex) {
304                        Thread.currentThread().interrupt();
305                        throw new SockJsTransportFailureException("Interrupted while processing request to " + url, ex);
306                }
307        }
308
309        private ClientCallback<ClientExchange> createRequestCallback(final @Nullable String body,
310                        final List<ClientResponse> responses, final CountDownLatch latch) {
311
312                return new ClientCallback<ClientExchange>() {
313                        @Override
314                        public void completed(ClientExchange result) {
315                                result.setResponseListener(new ClientCallback<ClientExchange>() {
316                                        @Override
317                                        public void completed(final ClientExchange result) {
318                                                responses.add(result.getResponse());
319                                                new StringReadChannelListener(result.getConnection().getBufferPool()) {
320                                                        @Override
321                                                        protected void stringDone(String string) {
322                                                                result.getResponse().putAttachment(RESPONSE_BODY, string);
323                                                                latch.countDown();
324                                                        }
325                                                        @Override
326                                                        protected void error(IOException ex) {
327                                                                onFailure(latch, ex);
328                                                        }
329                                                }.setup(result.getResponseChannel());
330                                        }
331                                        @Override
332                                        public void failed(IOException ex) {
333                                                onFailure(latch, ex);
334                                        }
335                                });
336                                try {
337                                        if (body != null) {
338                                                result.getRequestChannel().write(ByteBuffer.wrap(body.getBytes()));
339                                        }
340                                        result.getRequestChannel().shutdownWrites();
341                                        if (!result.getRequestChannel().flush()) {
342                                                result.getRequestChannel().getWriteSetter()
343                                                                .set(ChannelListeners.<StreamSinkChannel>flushingChannelListener(null, null));
344                                                result.getRequestChannel().resumeWrites();
345                                        }
346                                }
347                                catch (IOException ex) {
348                                        onFailure(latch, ex);
349                                }
350                        }
351
352                        @Override
353                        public void failed(IOException ex) {
354                                onFailure(latch, ex);
355                        }
356
357                        private void onFailure(CountDownLatch latch, IOException ex) {
358                                latch.countDown();
359                                throw new SockJsTransportFailureException("Failed to execute request", ex);
360                        }
361                };
362        }
363
364
365        private class SockJsResponseListener implements ChannelListener<StreamSourceChannel> {
366
367                private final TransportRequest request;
368
369                private final ClientConnection connection;
370
371                private final URI url;
372
373                private final HttpHeaders headers;
374
375                private final XhrClientSockJsSession session;
376
377                private final SettableListenableFuture<WebSocketSession> connectFuture;
378
379                private final ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
380
381                public SockJsResponseListener(TransportRequest request, ClientConnection connection, URI url,
382                                HttpHeaders headers, XhrClientSockJsSession sockJsSession,
383                                SettableListenableFuture<WebSocketSession> connectFuture) {
384
385                        this.request = request;
386                        this.connection = connection;
387                        this.url = url;
388                        this.headers = headers;
389                        this.session = sockJsSession;
390                        this.connectFuture = connectFuture;
391                }
392
393                public void setup(StreamSourceChannel channel) {
394                        channel.suspendReads();
395                        channel.getReadSetter().set(this);
396                        channel.resumeReads();
397                }
398
399                @Override
400                public void handleEvent(StreamSourceChannel channel) {
401                        if (this.session.isDisconnected()) {
402                                if (logger.isDebugEnabled()) {
403                                        logger.debug("SockJS sockJsSession closed, closing response.");
404                                }
405                                IoUtils.safeClose(this.connection);
406                                throw new SockJsException("Session closed.", this.session.getId(), null);
407                        }
408
409                        try (PooledByteBuffer pooled = bufferPool.allocate()) {
410                                int r;
411                                do {
412                                        ByteBuffer buffer = pooled.getBuffer();
413                                        buffer.clear();
414                                        r = channel.read(buffer);
415                                        buffer.flip();
416                                        if (r == 0) {
417                                                return;
418                                        }
419                                        else if (r == -1) {
420                                                onSuccess();
421                                        }
422                                        else {
423                                                while (buffer.hasRemaining()) {
424                                                        int b = buffer.get();
425                                                        if (b == '\n') {
426                                                                handleFrame();
427                                                        }
428                                                        else {
429                                                                this.outputStream.write(b);
430                                                        }
431                                                }
432                                        }
433                                }
434                                while (r > 0);
435                        }
436                        catch (IOException exc) {
437                                onFailure(exc);
438                        }
439                }
440
441                private void handleFrame() {
442                        String content = StreamUtils.copyToString(this.outputStream, SockJsFrame.CHARSET);
443                        this.outputStream.reset();
444                        if (logger.isTraceEnabled()) {
445                                logger.trace("XHR content received: " + content);
446                        }
447                        if (!PRELUDE.equals(content)) {
448                                this.session.handleFrame(content);
449                        }
450                }
451
452                public void onSuccess() {
453                        if (this.outputStream.size() > 0) {
454                                handleFrame();
455                        }
456                        if (logger.isTraceEnabled()) {
457                                logger.trace("XHR receive request completed.");
458                        }
459                        IoUtils.safeClose(this.connection);
460                        executeReceiveRequest(this.request, this.url, this.headers, this.session, this.connectFuture);
461                }
462
463                public void onFailure(Throwable failure) {
464                        IoUtils.safeClose(this.connection);
465                        if (this.connectFuture.setException(failure)) {
466                                return;
467                        }
468                        if (this.session.isDisconnected()) {
469                                this.session.afterTransportClosed(null);
470                        }
471                        else {
472                                this.session.handleTransportError(failure);
473                                this.session.afterTransportClosed(new CloseStatus(1006, failure.getMessage()));
474                        }
475                }
476        }
477
478}