001/*
002 * Copyright 2012-2018 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.boot.devtools.tunnel.server;
018
019import java.io.IOException;
020import java.net.ConnectException;
021import java.nio.ByteBuffer;
022import java.nio.channels.ByteChannel;
023import java.util.ArrayDeque;
024import java.util.Deque;
025import java.util.Iterator;
026import java.util.concurrent.TimeUnit;
027import java.util.concurrent.atomic.AtomicLong;
028
029import org.apache.commons.logging.Log;
030import org.apache.commons.logging.LogFactory;
031
032import org.springframework.boot.devtools.tunnel.payload.HttpTunnelPayload;
033import org.springframework.boot.devtools.tunnel.payload.HttpTunnelPayloadForwarder;
034import org.springframework.http.HttpStatus;
035import org.springframework.http.MediaType;
036import org.springframework.http.server.ServerHttpAsyncRequestControl;
037import org.springframework.http.server.ServerHttpRequest;
038import org.springframework.http.server.ServerHttpResponse;
039import org.springframework.util.Assert;
040
041/**
042 * A server that can be used to tunnel TCP traffic over HTTP. Similar in design to the
043 * <a href="http://xmpp.org/extensions/xep-0124.html">Bidirectional-streams Over
044 * Synchronous HTTP (BOSH)</a> XMPP extension protocol, the server uses long polling with
045 * HTTP requests held open until a response is available. A typical traffic pattern would
046 * be as follows:
047 *
048 * <pre>
049 * [ CLIENT ]                      [ SERVER ]
050 *     | (a) Initial empty request     |
051 *     |------------------------------&gt;|
052 *     | (b) Data I                    |
053 *  --&gt;|------------------------------&gt;|---&gt;
054 *     | Response I (a)                |
055 *  &lt;--|&lt;------------------------------|&lt;---
056 *     |                               |
057 *     | (c) Data II                   |
058 *  --&gt;|------------------------------&gt;|---&gt;
059 *     | Response II (b)               |
060 *  &lt;--|&lt;------------------------------|&lt;---
061 *     .                               .
062 *     .                               .
063 * </pre>
064 *
065 * Each incoming request is held open to be used to carry the next available response. The
066 * server will hold at most two connections open at any given time.
067 * <p>
068 * Requests should be made using HTTP GET or POST (depending on whether there is a payload),
069 * with any payload contained in the body. The following response codes can be returned
070 * from the server:
071 * <table>
072 * <caption>Response Codes</caption>
073 * <tr>
074 * <th>Status</th>
075 * <th>Meaning</th>
076 * </tr>
077 * <tr>
078 * <td>200 (OK)</td>
079 * <td>Data payload response.</td>
080 * </tr>
081 * <tr>
082 * <td>204 (No Content)</td>
083 * <td>The long poll has timed out and the client should start a new request.</td>
084 * </tr>
085 * <tr>
086 * <td>429 (Too many requests)</td>
087 * <td>There are already enough connections open; this one can be dropped.</td>
088 * </tr>
089 * <tr>
090 * <td>410 (Gone)</td>
091 * <td>The target server has disconnected.</td>
092 * </tr>
093 * <tr>
094 * <td>503 (Service Unavailable)</td>
095 * <td>The target server is unavailable.</td>
096 * </tr>
097 * </table>
098 * <p>
099 * Requests and responses that contain payloads include an {@code x-seq} header that
100 * contains a running sequence number (used to ensure data is applied in the correct
101 * order). The first request containing a payload should have an {@code x-seq} value of
102 * {@code 1}.
103 *
104 * @author Phillip Webb
105 * @author Andy Wilkinson
106 * @since 1.3.0
107 * @see org.springframework.boot.devtools.tunnel.client.HttpTunnelConnection
108 */
109public class HttpTunnelServer {
110
111        private static final long DEFAULT_LONG_POLL_TIMEOUT = TimeUnit.SECONDS.toMillis(10);
112
113        private static final long DEFAULT_DISCONNECT_TIMEOUT = TimeUnit.SECONDS.toMillis(30);
114
115        private static final MediaType DISCONNECT_MEDIA_TYPE = new MediaType("application",
116                        "x-disconnect");
117
118        private static final Log logger = LogFactory.getLog(HttpTunnelServer.class);
119
120        private final TargetServerConnection serverConnection;
121
122        private int longPollTimeout = (int) DEFAULT_LONG_POLL_TIMEOUT;
123
124        private long disconnectTimeout = DEFAULT_DISCONNECT_TIMEOUT;
125
126        private volatile ServerThread serverThread;
127
128        /**
129         * Creates a new {@link HttpTunnelServer} instance.
130         * @param serverConnection the connection to the target server
131         */
132        public HttpTunnelServer(TargetServerConnection serverConnection) {
133                Assert.notNull(serverConnection, "ServerConnection must not be null");
134                this.serverConnection = serverConnection;
135        }
136
137        /**
138         * Handle an incoming HTTP connection.
139         * @param request the HTTP request
140         * @param response the HTTP response
141         * @throws IOException in case of I/O errors
142         */
143        public void handle(ServerHttpRequest request, ServerHttpResponse response)
144                        throws IOException {
145                handle(new HttpConnection(request, response));
146        }
147
148        /**
149         * Handle an incoming HTTP connection.
150         * @param httpConnection the HTTP connection
151         * @throws IOException in case of I/O errors
152         */
153        protected void handle(HttpConnection httpConnection) throws IOException {
154                try {
155                        getServerThread().handleIncomingHttp(httpConnection);
156                        httpConnection.waitForResponse();
157                }
158                catch (ConnectException ex) {
159                        httpConnection.respond(HttpStatus.GONE);
160                }
161        }
162
163        /**
164         * Returns the active server thread, creating and starting it if necessary.
165         * @return the {@code ServerThread} (never {@code null})
166         * @throws IOException in case of I/O errors
167         */
168        protected ServerThread getServerThread() throws IOException {
169                synchronized (this) {
170                        if (this.serverThread == null) {
171                                ByteChannel channel = this.serverConnection.open(this.longPollTimeout);
172                                this.serverThread = new ServerThread(channel);
173                                this.serverThread.start();
174                        }
175                        return this.serverThread;
176                }
177        }
178
179        /**
180         * Called when the server thread exits.
181         */
182        void clearServerThread() {
183                synchronized (this) {
184                        this.serverThread = null;
185                }
186        }
187
188        /**
189         * Set the long poll timeout for the server.
190         * @param longPollTimeout the long poll timeout in milliseconds
191         */
192        public void setLongPollTimeout(int longPollTimeout) {
193                Assert.isTrue(longPollTimeout > 0, "LongPollTimeout must be a positive value");
194                this.longPollTimeout = longPollTimeout;
195        }
196
197        /**
198         * Set the maximum amount of time to wait for a client before closing the connection.
199         * @param disconnectTimeout the disconnect timeout in milliseconds
200         */
201        public void setDisconnectTimeout(long disconnectTimeout) {
202                Assert.isTrue(disconnectTimeout > 0,
203                                "DisconnectTimeout must be a positive value");
204                this.disconnectTimeout = disconnectTimeout;
205        }
206
207        /**
208         * The main server thread used to transfer tunnel traffic.
209         */
210        protected class ServerThread extends Thread {
211
212                private final ByteChannel targetServer;
213
214                private final Deque<HttpConnection> httpConnections;
215
216                private final HttpTunnelPayloadForwarder payloadForwarder;
217
218                private boolean closed;
219
220                private AtomicLong responseSeq = new AtomicLong();
221
222                private long lastHttpRequestTime;
223
224                public ServerThread(ByteChannel targetServer) {
225                        Assert.notNull(targetServer, "TargetServer must not be null");
226                        this.targetServer = targetServer;
227                        this.httpConnections = new ArrayDeque<>(2);
228                        this.payloadForwarder = new HttpTunnelPayloadForwarder(targetServer);
229                }
230
231                @Override
232                public void run() {
233                        try {
234                                try {
235                                        readAndForwardTargetServerData();
236                                }
237                                catch (Exception ex) {
238                                        logger.trace("Unexpected exception from tunnel server", ex);
239                                }
240                        }
241                        finally {
242                                this.closed = true;
243                                closeHttpConnections();
244                                closeTargetServer();
245                                HttpTunnelServer.this.clearServerThread();
246                        }
247                }
248
249                private void readAndForwardTargetServerData() throws IOException {
250                        while (this.targetServer.isOpen()) {
251                                closeStaleHttpConnections();
252                                ByteBuffer data = HttpTunnelPayload.getPayloadData(this.targetServer);
253                                synchronized (this.httpConnections) {
254                                        if (data != null) {
255                                                HttpTunnelPayload payload = new HttpTunnelPayload(
256                                                                this.responseSeq.incrementAndGet(), data);
257                                                payload.logIncoming();
258                                                HttpConnection connection = getOrWaitForHttpConnection();
259                                                connection.respond(payload);
260                                        }
261                                }
262                        }
263                }
264
265                private HttpConnection getOrWaitForHttpConnection() {
266                        synchronized (this.httpConnections) {
267                                HttpConnection httpConnection = this.httpConnections.pollFirst();
268                                while (httpConnection == null) {
269                                        try {
270                                                this.httpConnections.wait(HttpTunnelServer.this.longPollTimeout);
271                                        }
272                                        catch (InterruptedException ex) {
273                                                Thread.currentThread().interrupt();
274                                                closeHttpConnections();
275                                        }
276                                        httpConnection = this.httpConnections.pollFirst();
277                                }
278                                return httpConnection;
279                        }
280                }
281
282                private void closeStaleHttpConnections() throws IOException {
283                        synchronized (this.httpConnections) {
284                                checkNotDisconnected();
285                                Iterator<HttpConnection> iterator = this.httpConnections.iterator();
286                                while (iterator.hasNext()) {
287                                        HttpConnection httpConnection = iterator.next();
288                                        if (httpConnection
289                                                        .isOlderThan(HttpTunnelServer.this.longPollTimeout)) {
290                                                httpConnection.respond(HttpStatus.NO_CONTENT);
291                                                iterator.remove();
292                                        }
293                                }
294                        }
295                }
296
297                private void checkNotDisconnected() {
298                        if (this.lastHttpRequestTime > 0) {
299                                long timeout = HttpTunnelServer.this.disconnectTimeout;
300                                long duration = System.currentTimeMillis() - this.lastHttpRequestTime;
301                                Assert.state(duration < timeout,
302                                                () -> "Disconnect timeout: " + timeout + " " + duration);
303                        }
304                }
305
306                private void closeHttpConnections() {
307                        synchronized (this.httpConnections) {
308                                while (!this.httpConnections.isEmpty()) {
309                                        try {
310                                                this.httpConnections.removeFirst().respond(HttpStatus.GONE);
311                                        }
312                                        catch (Exception ex) {
313                                                logger.trace("Unable to close remote HTTP connection");
314                                        }
315                                }
316                        }
317                }
318
319                private void closeTargetServer() {
320                        try {
321                                this.targetServer.close();
322                        }
323                        catch (IOException ex) {
324                                logger.trace("Unable to target server connection");
325                        }
326                }
327
328                /**
329                 * Handle an incoming {@link HttpConnection}.
330                 * @param httpConnection the connection to handle.
331                 * @throws IOException in case of I/O errors
332                 */
333                public void handleIncomingHttp(HttpConnection httpConnection) throws IOException {
334                        if (this.closed) {
335                                httpConnection.respond(HttpStatus.GONE);
336                        }
337                        synchronized (this.httpConnections) {
338                                while (this.httpConnections.size() > 1) {
339                                        this.httpConnections.removeFirst()
340                                                        .respond(HttpStatus.TOO_MANY_REQUESTS);
341                                }
342                                this.lastHttpRequestTime = System.currentTimeMillis();
343                                this.httpConnections.addLast(httpConnection);
344                                this.httpConnections.notify();
345                        }
346                        forwardToTargetServer(httpConnection);
347                }
348
349                private void forwardToTargetServer(HttpConnection httpConnection)
350                                throws IOException {
351                        if (httpConnection.isDisconnectRequest()) {
352                                this.targetServer.close();
353                                interrupt();
354                        }
355                        ServerHttpRequest request = httpConnection.getRequest();
356                        HttpTunnelPayload payload = HttpTunnelPayload.get(request);
357                        if (payload != null) {
358                                this.payloadForwarder.forward(payload);
359                        }
360                }
361
362        }
363
364        /**
365         * Encapsulates a HTTP request/response pair.
366         */
367        protected static class HttpConnection {
368
369                private final long createTime;
370
371                private final ServerHttpRequest request;
372
373                private final ServerHttpResponse response;
374
375                private ServerHttpAsyncRequestControl async;
376
377                private volatile boolean complete = false;
378
379                public HttpConnection(ServerHttpRequest request, ServerHttpResponse response) {
380                        this.createTime = System.currentTimeMillis();
381                        this.request = request;
382                        this.response = response;
383                        this.async = startAsync();
384                }
385
386                /**
387                 * Start asynchronous support or if unavailable return {@code null} to cause
388                 * {@link #waitForResponse()} to block.
389                 * @return the async request control
390                 */
391                protected ServerHttpAsyncRequestControl startAsync() {
392                        try {
393                                // Try to use async to save blocking
394                                ServerHttpAsyncRequestControl async = this.request
395                                                .getAsyncRequestControl(this.response);
396                                async.start();
397                                return async;
398                        }
399                        catch (Exception ex) {
400                                return null;
401                        }
402                }
403
404                /**
405                 * Return the underlying request.
406                 * @return the request
407                 */
408                public final ServerHttpRequest getRequest() {
409                        return this.request;
410                }
411
412                /**
413                 * Return the underlying response.
414                 * @return the response
415                 */
416                protected final ServerHttpResponse getResponse() {
417                        return this.response;
418                }
419
420                /**
421                 * Determine if a connection is older than the specified time.
422                 * @param time the time to check
423                 * @return {@code true} if the request is older than the time
424                 */
425                public boolean isOlderThan(int time) {
426                        long runningTime = System.currentTimeMillis() - this.createTime;
427                        return (runningTime > time);
428                }
429
430                /**
431                 * Cause the request to block or use asynchronous methods to wait until a response
432                 * is available.
433                 */
434                public void waitForResponse() {
435                        if (this.async == null) {
436                                while (!this.complete) {
437                                        try {
438                                                synchronized (this) {
439                                                        wait(1000);
440                                                }
441                                        }
442                                        catch (InterruptedException ex) {
443                                                Thread.currentThread().interrupt();
444                                        }
445                                }
446                        }
447                }
448
449                /**
450                 * Detect if the request is actually a signal to disconnect.
451                 * @return if the request is a signal to disconnect
452                 */
453                public boolean isDisconnectRequest() {
454                        return DISCONNECT_MEDIA_TYPE
455                                        .equals(this.request.getHeaders().getContentType());
456                }
457
458                /**
459                 * Send a HTTP status response.
460                 * @param status the status to send
461                 * @throws IOException in case of I/O errors
462                 */
463                public void respond(HttpStatus status) throws IOException {
464                        Assert.notNull(status, "Status must not be null");
465                        this.response.setStatusCode(status);
466                        complete();
467                }
468
469                /**
470                 * Send a payload response.
471                 * @param payload the payload to send
472                 * @throws IOException in case of I/O errors
473                 */
474                public void respond(HttpTunnelPayload payload) throws IOException {
475                        Assert.notNull(payload, "Payload must not be null");
476                        this.response.setStatusCode(HttpStatus.OK);
477                        payload.assignTo(this.response);
478                        complete();
479                }
480
481                /**
482                 * Called when a request is complete.
483                 */
484                protected void complete() {
485                        if (this.async != null) {
486                                this.async.complete();
487                        }
488                        else {
489                                synchronized (this) {
490                                        this.complete = true;
491                                        notifyAll();
492                                }
493                        }
494                }
495
496        }
497
498}