001/*
002 * Copyright 2002-2018 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.web.socket.sockjs.client;
018
019import java.io.ByteArrayOutputStream;
020import java.io.IOException;
021import java.lang.reflect.Method;
022import java.net.URI;
023import java.nio.ByteBuffer;
024import java.util.List;
025import java.util.concurrent.CopyOnWriteArrayList;
026import java.util.concurrent.CountDownLatch;
027
028import io.undertow.client.ClientCallback;
029import io.undertow.client.ClientConnection;
030import io.undertow.client.ClientExchange;
031import io.undertow.client.ClientRequest;
032import io.undertow.client.ClientResponse;
033import io.undertow.client.UndertowClient;
034import io.undertow.connector.ByteBufferPool;
035import io.undertow.connector.PooledByteBuffer;
036import io.undertow.server.DefaultByteBufferPool;
037import io.undertow.util.AttachmentKey;
038import io.undertow.util.HeaderMap;
039import io.undertow.util.HttpString;
040import io.undertow.util.Methods;
041import io.undertow.util.StringReadChannelListener;
042import org.xnio.ChannelListener;
043import org.xnio.ChannelListeners;
044import org.xnio.IoFuture;
045import org.xnio.IoUtils;
046import org.xnio.OptionMap;
047import org.xnio.Options;
048import org.xnio.Pool;
049import org.xnio.Xnio;
050import org.xnio.XnioWorker;
051import org.xnio.channels.StreamSinkChannel;
052import org.xnio.channels.StreamSourceChannel;
053
054import org.springframework.http.HttpHeaders;
055import org.springframework.http.HttpStatus;
056import org.springframework.http.ResponseEntity;
057import org.springframework.util.Assert;
058import org.springframework.util.ClassUtils;
059import org.springframework.util.ReflectionUtils;
060import org.springframework.util.StringUtils;
061import org.springframework.util.concurrent.SettableListenableFuture;
062import org.springframework.web.client.HttpServerErrorException;
063import org.springframework.web.socket.CloseStatus;
064import org.springframework.web.socket.TextMessage;
065import org.springframework.web.socket.WebSocketHandler;
066import org.springframework.web.socket.WebSocketSession;
067import org.springframework.web.socket.sockjs.SockJsException;
068import org.springframework.web.socket.sockjs.SockJsTransportFailureException;
069import org.springframework.web.socket.sockjs.frame.SockJsFrame;
070
071/**
072 * An XHR transport based on Undertow's {@link io.undertow.client.UndertowClient}.
073 * Compatible with Undertow 1.0 to 1.3, as of Spring Framework 4.2.2.
074 *
075 * <p>When used for testing purposes (e.g. load testing) or for specific use cases
076 * (like HTTPS configuration), a custom OptionMap should be provided:
077 *
078 * <pre class="code">
079 * OptionMap optionMap = OptionMap.builder()
080 *   .set(Options.WORKER_IO_THREADS, 8)
081 *   .set(Options.TCP_NODELAY, true)
082 *   .set(Options.KEEP_ALIVE, true)
083 *   .set(Options.WORKER_NAME, "SockJSClient")
084 *   .getMap();
085 *
086 * UndertowXhrTransport transport = new UndertowXhrTransport(optionMap);
087 * </pre>
088 *
089 * @author Brian Clozel
090 * @author Rossen Stoyanchev
091 * @since 4.1.2
092 * @see org.xnio.Options
093 */
094public class UndertowXhrTransport extends AbstractXhrTransport {
095
096        private static final AttachmentKey<String> RESPONSE_BODY = AttachmentKey.create(String.class);
097
098        private static final boolean undertow13Present = ClassUtils.isPresent(
099                        "io.undertow.connector.ByteBufferPool", UndertowXhrTransport.class.getClassLoader());
100
101
102        private final OptionMap optionMap;
103
104        private final UndertowClient httpClient;
105
106        private final XnioWorker worker;
107
108        private final UndertowBufferSupport undertowBufferSupport;
109
110
111        public UndertowXhrTransport() throws IOException {
112                this(OptionMap.builder().parse(Options.WORKER_NAME, "SockJSClient").getMap());
113        }
114
115        public UndertowXhrTransport(OptionMap optionMap) throws IOException {
116                Assert.notNull(optionMap, "OptionMap is required");
117                this.optionMap = optionMap;
118                this.httpClient = UndertowClient.getInstance();
119                this.worker = Xnio.getInstance().createWorker(optionMap);
120                this.undertowBufferSupport =
121                                (undertow13Present ? new Undertow13BufferSupport() : new UndertowXnioBufferSupport());
122        }
123
124
125        /**
126         * Return Undertow's native HTTP client
127         */
128        public UndertowClient getHttpClient() {
129                return this.httpClient;
130        }
131
132        /**
133         * Return the {@link org.xnio.XnioWorker} backing the I/O operations
134         * for Undertow's HTTP client.
135         * @see org.xnio.Xnio
136         */
137        public XnioWorker getWorker() {
138                return this.worker;
139        }
140
141
142        @Override
143        protected void connectInternal(TransportRequest request, WebSocketHandler handler, URI receiveUrl,
144                        HttpHeaders handshakeHeaders, XhrClientSockJsSession session,
145                        SettableListenableFuture<WebSocketSession> connectFuture) {
146
147                executeReceiveRequest(request, receiveUrl, handshakeHeaders, session, connectFuture);
148        }
149
150        private void executeReceiveRequest(final TransportRequest transportRequest,
151                        final URI url, final HttpHeaders headers, final XhrClientSockJsSession session,
152                        final SettableListenableFuture<WebSocketSession> connectFuture) {
153
154                if (logger.isTraceEnabled()) {
155                        logger.trace("Starting XHR receive request for " + url);
156                }
157
158                ClientCallback<ClientConnection> clientCallback = new ClientCallback<ClientConnection>() {
159                        @Override
160                        public void completed(ClientConnection connection) {
161                                ClientRequest request = new ClientRequest().setMethod(Methods.POST).setPath(url.getPath());
162                                HttpString headerName = HttpString.tryFromString(HttpHeaders.HOST);
163                                request.getRequestHeaders().add(headerName, url.getHost());
164                                addHttpHeaders(request, headers);
165                                HttpHeaders httpHeaders = transportRequest.getHttpRequestHeaders();
166                                connection.sendRequest(request, createReceiveCallback(transportRequest,
167                                                url, httpHeaders, session, connectFuture));
168                        }
169
170                        @Override
171                        public void failed(IOException ex) {
172                                throw new SockJsTransportFailureException("Failed to execute request to " + url, ex);
173                        }
174                };
175
176                this.undertowBufferSupport.httpClientConnect(this.httpClient, clientCallback, url, worker, this.optionMap);
177        }
178
179        private static void addHttpHeaders(ClientRequest request, HttpHeaders headers) {
180                HeaderMap headerMap = request.getRequestHeaders();
181                for (String name : headers.keySet()) {
182                        for (String value : headers.get(name)) {
183                                headerMap.add(HttpString.tryFromString(name), value);
184                        }
185                }
186        }
187
188        private ClientCallback<ClientExchange> createReceiveCallback(final TransportRequest transportRequest,
189                        final URI url, final HttpHeaders headers, final XhrClientSockJsSession sockJsSession,
190                        final SettableListenableFuture<WebSocketSession> connectFuture) {
191
192                return new ClientCallback<ClientExchange>() {
193                        @Override
194                        public void completed(final ClientExchange exchange) {
195                                exchange.setResponseListener(new ClientCallback<ClientExchange>() {
196                                        @Override
197                                        public void completed(ClientExchange result) {
198                                                ClientResponse response = result.getResponse();
199                                                if (response.getResponseCode() != 200) {
200                                                        HttpStatus status = HttpStatus.valueOf(response.getResponseCode());
201                                                        IoUtils.safeClose(result.getConnection());
202                                                        onFailure(new HttpServerErrorException(status, "Unexpected XHR receive status"));
203                                                }
204                                                else {
205                                                        SockJsResponseListener listener = new SockJsResponseListener(
206                                                                        transportRequest, result.getConnection(), url, headers,
207                                                                        sockJsSession, connectFuture);
208                                                        listener.setup(result.getResponseChannel());
209                                                }
210                                                if (logger.isTraceEnabled()) {
211                                                        logger.trace("XHR receive headers: " + toHttpHeaders(response.getResponseHeaders()));
212                                                }
213                                                try {
214                                                        StreamSinkChannel channel = result.getRequestChannel();
215                                                        channel.shutdownWrites();
216                                                        if (!channel.flush()) {
217                                                                channel.getWriteSetter().set(ChannelListeners.<StreamSinkChannel>flushingChannelListener(null, null));
218                                                                channel.resumeWrites();
219                                                        }
220                                                }
221                                                catch (IOException exc) {
222                                                        IoUtils.safeClose(result.getConnection());
223                                                        onFailure(exc);
224                                                }
225                                        }
226
227                                        @Override
228                                        public void failed(IOException exc) {
229                                                IoUtils.safeClose(exchange.getConnection());
230                                                onFailure(exc);
231                                        }
232                                });
233                        }
234
235                        @Override
236                        public void failed(IOException exc) {
237                                onFailure(exc);
238                        }
239
240                        private void onFailure(Throwable failure) {
241                                if (connectFuture.setException(failure)) {
242                                        return;
243                                }
244                                if (sockJsSession.isDisconnected()) {
245                                        sockJsSession.afterTransportClosed(null);
246                                }
247                                else {
248                                        sockJsSession.handleTransportError(failure);
249                                        sockJsSession.afterTransportClosed(new CloseStatus(1006, failure.getMessage()));
250                                }
251                        }
252                };
253        }
254
255        private static HttpHeaders toHttpHeaders(HeaderMap headerMap) {
256                HttpHeaders httpHeaders = new HttpHeaders();
257                for (HttpString name : headerMap.getHeaderNames()) {
258                        for (String value : headerMap.get(name)) {
259                                httpHeaders.add(name.toString(), value);
260                        }
261                }
262                return httpHeaders;
263        }
264
265        @Override
266        protected ResponseEntity<String> executeInfoRequestInternal(URI infoUrl, HttpHeaders headers) {
267                return executeRequest(infoUrl, Methods.GET, headers, null);
268        }
269
270        @Override
271        protected ResponseEntity<String> executeSendRequestInternal(URI url, HttpHeaders headers, TextMessage message) {
272                return executeRequest(url, Methods.POST, headers, message.getPayload());
273        }
274
275        protected ResponseEntity<String> executeRequest(URI url, HttpString method, HttpHeaders headers, String body) {
276                CountDownLatch latch = new CountDownLatch(1);
277                List<ClientResponse> responses = new CopyOnWriteArrayList<ClientResponse>();
278
279                try {
280                        ClientConnection connection = this.undertowBufferSupport
281                                        .httpClientConnect(this.httpClient, url, this.worker, this.optionMap).get();
282                        try {
283                                ClientRequest request = new ClientRequest().setMethod(method).setPath(url.getPath());
284                                request.getRequestHeaders().add(HttpString.tryFromString(HttpHeaders.HOST), url.getHost());
285                                if (StringUtils.hasLength(body)) {
286                                        HttpString headerName = HttpString.tryFromString(HttpHeaders.CONTENT_LENGTH);
287                                        request.getRequestHeaders().add(headerName, body.length());
288                                }
289                                addHttpHeaders(request, headers);
290                                connection.sendRequest(request, createRequestCallback(body, responses, latch));
291
292                                latch.await();
293                                ClientResponse response = responses.iterator().next();
294                                HttpStatus status = HttpStatus.valueOf(response.getResponseCode());
295                                HttpHeaders responseHeaders = toHttpHeaders(response.getResponseHeaders());
296                                String responseBody = response.getAttachment(RESPONSE_BODY);
297                                return (responseBody != null ?
298                                                new ResponseEntity<String>(responseBody, responseHeaders, status) :
299                                                new ResponseEntity<String>(responseHeaders, status));
300                        }
301                        finally {
302                                IoUtils.safeClose(connection);
303                        }
304                }
305                catch (IOException ex) {
306                        throw new SockJsTransportFailureException("Failed to execute request to " + url, ex);
307                }
308                catch (InterruptedException ex) {
309                        Thread.currentThread().interrupt();
310                        throw new SockJsTransportFailureException("Interrupted while processing request to " + url, ex);
311                }
312        }
313
314        private ClientCallback<ClientExchange> createRequestCallback(final String body,
315                        final List<ClientResponse> responses, final CountDownLatch latch) {
316
317                return new ClientCallback<ClientExchange>() {
318                        @Override
319                        public void completed(ClientExchange result) {
320                                result.setResponseListener(new ClientCallback<ClientExchange>() {
321                                        @Override
322                                        @SuppressWarnings("deprecation")
323                                        public void completed(final ClientExchange result) {
324                                                responses.add(result.getResponse());
325                                                new StringReadChannelListener(result.getConnection().getBufferPool()) {
326                                                        @Override
327                                                        protected void stringDone(String string) {
328                                                                result.getResponse().putAttachment(RESPONSE_BODY, string);
329                                                                latch.countDown();
330                                                        }
331                                                        @Override
332                                                        protected void error(IOException ex) {
333                                                                onFailure(latch, ex);
334                                                        }
335                                                }.setup(result.getResponseChannel());
336                                        }
337                                        @Override
338                                        public void failed(IOException ex) {
339                                                onFailure(latch, ex);
340                                        }
341                                });
342                                try {
343                                        if (body != null) {
344                                                result.getRequestChannel().write(ByteBuffer.wrap(body.getBytes()));
345                                        }
346                                        result.getRequestChannel().shutdownWrites();
347                                        if (!result.getRequestChannel().flush()) {
348                                                result.getRequestChannel().getWriteSetter()
349                                                                .set(ChannelListeners.<StreamSinkChannel>flushingChannelListener(null, null));
350                                                result.getRequestChannel().resumeWrites();
351                                        }
352                                }
353                                catch (IOException ex) {
354                                        onFailure(latch, ex);
355                                }
356                        }
357
358                        @Override
359                        public void failed(IOException ex) {
360                                onFailure(latch, ex);
361                        }
362
363                        private void onFailure(CountDownLatch latch, IOException ex) {
364                                latch.countDown();
365                                throw new SockJsTransportFailureException("Failed to execute request", ex);
366                        }
367                };
368        }
369
370
371        private class SockJsResponseListener implements ChannelListener<StreamSourceChannel> {
372
373                private final TransportRequest request;
374
375                private final ClientConnection connection;
376
377                private final URI url;
378
379                private final HttpHeaders headers;
380
381                private final XhrClientSockJsSession session;
382
383                private final SettableListenableFuture<WebSocketSession> connectFuture;
384
385                private final ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
386
387                public SockJsResponseListener(TransportRequest request, ClientConnection connection, URI url,
388                                HttpHeaders headers, XhrClientSockJsSession sockJsSession,
389                                SettableListenableFuture<WebSocketSession> connectFuture) {
390
391                        this.request = request;
392                        this.connection = connection;
393                        this.url = url;
394                        this.headers = headers;
395                        this.session = sockJsSession;
396                        this.connectFuture = connectFuture;
397                }
398
399                public void setup(StreamSourceChannel channel) {
400                        channel.suspendReads();
401                        channel.getReadSetter().set(this);
402                        channel.resumeReads();
403                }
404
405                @Override
406                public void handleEvent(StreamSourceChannel channel) {
407                        if (this.session.isDisconnected()) {
408                                if (logger.isDebugEnabled()) {
409                                        logger.debug("SockJS sockJsSession closed, closing response.");
410                                }
411                                IoUtils.safeClose(this.connection);
412                                throw new SockJsException("Session closed.", this.session.getId(), null);
413                        }
414
415                        Object pooled = undertowBufferSupport.allocatePooledResource();
416                        try {
417                                int r;
418                                do {
419                                        ByteBuffer buffer = undertowBufferSupport.getByteBuffer(pooled);
420                                        buffer.clear();
421                                        r = channel.read(buffer);
422                                        buffer.flip();
423                                        if (r == 0) {
424                                                return;
425                                        }
426                                        else if (r == -1) {
427                                                onSuccess();
428                                        }
429                                        else {
430                                                while (buffer.hasRemaining()) {
431                                                        int b = buffer.get();
432                                                        if (b == '\n') {
433                                                                handleFrame();
434                                                        }
435                                                        else {
436                                                                this.outputStream.write(b);
437                                                        }
438                                                }
439                                        }
440                                }
441                                while (r > 0);
442                        }
443                        catch (IOException exc) {
444                                onFailure(exc);
445                        }
446                        finally {
447                                undertowBufferSupport.closePooledResource(pooled);
448                        }
449                }
450
451                private void handleFrame() {
452                        byte[] bytes = this.outputStream.toByteArray();
453                        this.outputStream.reset();
454                        String content = new String(bytes, SockJsFrame.CHARSET);
455                        if (logger.isTraceEnabled()) {
456                                logger.trace("XHR content received: " + content);
457                        }
458                        if (!PRELUDE.equals(content)) {
459                                this.session.handleFrame(new String(bytes, SockJsFrame.CHARSET));
460                        }
461                }
462
463                public void onSuccess() {
464                        if (this.outputStream.size() > 0) {
465                                handleFrame();
466                        }
467                        if (logger.isTraceEnabled()) {
468                                logger.trace("XHR receive request completed.");
469                        }
470                        IoUtils.safeClose(this.connection);
471                        executeReceiveRequest(this.request, this.url, this.headers, this.session, this.connectFuture);
472                }
473
474                public void onFailure(Throwable failure) {
475                        IoUtils.safeClose(this.connection);
476                        if (this.connectFuture.setException(failure)) {
477                                return;
478                        }
479                        if (this.session.isDisconnected()) {
480                                this.session.afterTransportClosed(null);
481                        }
482                        else {
483                                this.session.handleTransportError(failure);
484                                this.session.afterTransportClosed(new CloseStatus(1006, failure.getMessage()));
485                        }
486                }
487        }
488
489
490        private interface UndertowBufferSupport {
491
492                Object allocatePooledResource();
493
494                ByteBuffer getByteBuffer(Object pooled);
495
496                void closePooledResource(Object pooled);
497
498                void httpClientConnect(UndertowClient httpClient, final ClientCallback<ClientConnection> listener,
499                                final URI uri, final XnioWorker worker, OptionMap options);
500
501                IoFuture<ClientConnection> httpClientConnect(UndertowClient httpClient, final URI uri,
502                                final XnioWorker worker, OptionMap options);
503        }
504
505
506        private class UndertowXnioBufferSupport implements UndertowBufferSupport {
507
508                private final org.xnio.Pool<ByteBuffer> xnioBufferPool;
509
510                private final Method httpClientConnectCallbackMethod;
511
512                private final Method httpClientConnectMethod;
513
514                public UndertowXnioBufferSupport() {
515                        this.xnioBufferPool = new org.xnio.ByteBufferSlicePool(1048, 1048);
516                        this.httpClientConnectCallbackMethod = ReflectionUtils.findMethod(UndertowClient.class, "connect",
517                                        ClientCallback.class, URI.class, XnioWorker.class, Pool.class, OptionMap.class);
518                        this.httpClientConnectMethod = ReflectionUtils.findMethod(UndertowClient.class, "connect",
519                                        URI.class, XnioWorker.class, Pool.class, OptionMap.class);
520                }
521
522                @Override
523                public Object allocatePooledResource() {
524                        return this.xnioBufferPool.allocate();
525                }
526
527                @Override
528                @SuppressWarnings("unchecked")
529                public ByteBuffer getByteBuffer(Object pooled) {
530                        return ((org.xnio.Pooled<ByteBuffer>) pooled).getResource();
531                }
532
533                @Override
534                @SuppressWarnings("unchecked")
535                public void closePooledResource(Object pooled) {
536                        ((org.xnio.Pooled<ByteBuffer>) pooled).close();
537                }
538
539                @Override
540                public void httpClientConnect(UndertowClient httpClient, ClientCallback<ClientConnection> listener, URI uri,
541                                XnioWorker worker, OptionMap options) {
542                        ReflectionUtils.invokeMethod(httpClientConnectCallbackMethod, httpClient, listener, uri, worker,
543                                        this.xnioBufferPool, options);
544                }
545
546                @Override
547                @SuppressWarnings("unchecked")
548                public IoFuture<ClientConnection> httpClientConnect(UndertowClient httpClient, URI uri,
549                                XnioWorker worker, OptionMap options) {
550                        return (IoFuture<ClientConnection>) ReflectionUtils.invokeMethod(httpClientConnectMethod, httpClient, uri,
551                                        worker, this.xnioBufferPool, options);
552                }
553        }
554
555
556        private class Undertow13BufferSupport implements UndertowBufferSupport {
557
558                private final ByteBufferPool undertowBufferPool;
559
560                private final Method httpClientConnectCallbackMethod;
561
562                private final Method httpClientConnectMethod;
563
564                public Undertow13BufferSupport() {
565                        this.undertowBufferPool = new DefaultByteBufferPool(false, 1024, -1, 2);
566                        this.httpClientConnectCallbackMethod = ReflectionUtils.findMethod(UndertowClient.class, "connect",
567                                        ClientCallback.class, URI.class, XnioWorker.class, ByteBufferPool.class, OptionMap.class);
568                        this.httpClientConnectMethod = ReflectionUtils.findMethod(UndertowClient.class, "connect",
569                                        URI.class, XnioWorker.class, ByteBufferPool.class, OptionMap.class);
570                }
571
572                @Override
573                public Object allocatePooledResource() {
574                        return this.undertowBufferPool.allocate();
575                }
576
577                @Override
578                public ByteBuffer getByteBuffer(Object pooled) {
579                        return ((PooledByteBuffer) pooled).getBuffer();
580                }
581
582                @Override
583                public void closePooledResource(Object pooled) {
584                        ((PooledByteBuffer) pooled).close();
585                }
586
587                @Override
588                public void httpClientConnect(UndertowClient httpClient, ClientCallback<ClientConnection> listener, URI uri,
589                                XnioWorker worker, OptionMap options) {
590                        ReflectionUtils.invokeMethod(httpClientConnectCallbackMethod, httpClient, listener, uri,
591                                        worker, this.undertowBufferPool, options);
592                }
593
594                @Override
595                @SuppressWarnings("unchecked")
596                public IoFuture<ClientConnection> httpClientConnect(UndertowClient httpClient, URI uri,
597                                XnioWorker worker, OptionMap options) {
598                        return (IoFuture<ClientConnection>) ReflectionUtils.invokeMethod(httpClientConnectMethod, httpClient, uri,
599                                        worker, this.undertowBufferPool, options);
600                }
601        }
602
603}