001/*
002 * Copyright 2012-2018 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.boot.devtools.tunnel.client;
018
019import java.io.Closeable;
020import java.io.IOException;
021import java.net.InetSocketAddress;
022import java.net.ServerSocket;
023import java.nio.ByteBuffer;
024import java.nio.channels.AsynchronousCloseException;
025import java.nio.channels.ServerSocketChannel;
026import java.nio.channels.SocketChannel;
027import java.nio.channels.WritableByteChannel;
028
029import org.apache.commons.logging.Log;
030import org.apache.commons.logging.LogFactory;
031
032import org.springframework.beans.factory.SmartInitializingSingleton;
033import org.springframework.util.Assert;
034
/**
 * The client side component of a socket tunnel. Starts a {@link ServerSocket} on the
 * specified port for local clients to connect to.
 *
 * @author Phillip Webb
 * @author Andy Wilkinson
 * @since 1.3.0
 */
043public class TunnelClient implements SmartInitializingSingleton {
044
045        private static final int BUFFER_SIZE = 1024 * 100;
046
047        private static final Log logger = LogFactory.getLog(TunnelClient.class);
048
049        private final TunnelClientListeners listeners = new TunnelClientListeners();
050
051        private final Object monitor = new Object();
052
053        private final int listenPort;
054
055        private final TunnelConnection tunnelConnection;
056
057        private ServerThread serverThread;
058
059        public TunnelClient(int listenPort, TunnelConnection tunnelConnection) {
060                Assert.isTrue(listenPort >= 0, "ListenPort must be greater than or equal to 0");
061                Assert.notNull(tunnelConnection, "TunnelConnection must not be null");
062                this.listenPort = listenPort;
063                this.tunnelConnection = tunnelConnection;
064        }
065
066        @Override
067        public void afterSingletonsInstantiated() {
068                synchronized (this.monitor) {
069                        if (this.serverThread == null) {
070                                try {
071                                        start();
072                                }
073                                catch (IOException ex) {
074                                        throw new IllegalStateException(ex);
075                                }
076                        }
077                }
078        }
079
080        /**
081         * Start the client and accept incoming connections.
082         * @return the port on which the client is listening
083         * @throws IOException in case of I/O errors
084         */
085        public int start() throws IOException {
086                synchronized (this.monitor) {
087                        Assert.state(this.serverThread == null, "Server already started");
088                        ServerSocketChannel serverSocketChannel = ServerSocketChannel.open();
089                        serverSocketChannel.socket().bind(new InetSocketAddress(this.listenPort));
090                        int port = serverSocketChannel.socket().getLocalPort();
091                        logger.trace("Listening for TCP traffic to tunnel on port " + port);
092                        this.serverThread = new ServerThread(serverSocketChannel);
093                        this.serverThread.start();
094                        return port;
095                }
096        }
097
098        /**
099         * Stop the client, disconnecting any servers.
100         * @throws IOException in case of I/O errors
101         */
102        public void stop() throws IOException {
103                synchronized (this.monitor) {
104                        if (this.serverThread != null) {
105                                this.serverThread.close();
106                                try {
107                                        this.serverThread.join(2000);
108                                }
109                                catch (InterruptedException ex) {
110                                        Thread.currentThread().interrupt();
111                                }
112                                this.serverThread = null;
113                        }
114                }
115        }
116
117        protected final ServerThread getServerThread() {
118                synchronized (this.monitor) {
119                        return this.serverThread;
120                }
121        }
122
123        public void addListener(TunnelClientListener listener) {
124                this.listeners.addListener(listener);
125        }
126
127        public void removeListener(TunnelClientListener listener) {
128                this.listeners.removeListener(listener);
129        }
130
131        /**
132         * The main server thread.
133         */
134        protected class ServerThread extends Thread {
135
136                private final ServerSocketChannel serverSocketChannel;
137
138                private boolean acceptConnections = true;
139
140                public ServerThread(ServerSocketChannel serverSocketChannel) {
141                        this.serverSocketChannel = serverSocketChannel;
142                        setName("Tunnel Server");
143                        setDaemon(true);
144                }
145
146                public void close() throws IOException {
147                        logger.trace("Closing tunnel client on port "
148                                        + this.serverSocketChannel.socket().getLocalPort());
149                        this.serverSocketChannel.close();
150                        this.acceptConnections = false;
151                        interrupt();
152                }
153
154                @Override
155                public void run() {
156                        try {
157                                while (this.acceptConnections) {
158                                        try (SocketChannel socket = this.serverSocketChannel.accept()) {
159                                                handleConnection(socket);
160                                        }
161                                        catch (AsynchronousCloseException ex) {
162                                                // Connection has been closed. Keep the server running
163                                        }
164                                }
165                        }
166                        catch (Exception ex) {
167                                logger.trace("Unexpected exception from tunnel client", ex);
168                        }
169                }
170
171                private void handleConnection(SocketChannel socketChannel) throws Exception {
172                        Closeable closeable = new SocketCloseable(socketChannel);
173                        TunnelClient.this.listeners.fireOpenEvent(socketChannel);
174                        try (WritableByteChannel outputChannel = TunnelClient.this.tunnelConnection
175                                        .open(socketChannel, closeable)) {
176                                logger.trace("Accepted connection to tunnel client from "
177                                                + socketChannel.socket().getRemoteSocketAddress());
178                                while (true) {
179                                        ByteBuffer buffer = ByteBuffer.allocate(BUFFER_SIZE);
180                                        int amountRead = socketChannel.read(buffer);
181                                        if (amountRead == -1) {
182                                                return;
183                                        }
184                                        if (amountRead > 0) {
185                                                buffer.flip();
186                                                outputChannel.write(buffer);
187                                        }
188                                }
189                        }
190                }
191
192                protected void stopAcceptingConnections() {
193                        this.acceptConnections = false;
194                }
195
196        }
197
198        /**
199         * {@link Closeable} used to close a {@link SocketChannel} and fire an event.
200         */
201        private class SocketCloseable implements Closeable {
202
203                private final SocketChannel socketChannel;
204
205                private boolean closed = false;
206
207                SocketCloseable(SocketChannel socketChannel) {
208                        this.socketChannel = socketChannel;
209                }
210
211                @Override
212                public void close() throws IOException {
213                        if (!this.closed) {
214                                this.socketChannel.close();
215                                TunnelClient.this.listeners.fireCloseEvent(this.socketChannel);
216                                this.closed = true;
217                        }
218                }
219
220        }
221
222}