001/*
002 * Copyright 2012-2017 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.boot.devtools.tunnel.server;
018
019import java.io.IOException;
020import java.net.InetSocketAddress;
021import java.net.SocketAddress;
022import java.nio.ByteBuffer;
023import java.nio.channels.ByteChannel;
024import java.nio.channels.Channels;
025import java.nio.channels.ReadableByteChannel;
026import java.nio.channels.SocketChannel;
027
028import org.apache.commons.logging.Log;
029import org.apache.commons.logging.LogFactory;
030
031import org.springframework.util.Assert;
032
033/**
034 * Socket based {@link TargetServerConnection}.
035 *
036 * @author Phillip Webb
037 * @since 1.3.0
038 */
039public class SocketTargetServerConnection implements TargetServerConnection {
040
041        private static final Log logger = LogFactory
042                        .getLog(SocketTargetServerConnection.class);
043
044        private final PortProvider portProvider;
045
046        /**
047         * Create a new {@link SocketTargetServerConnection}.
048         * @param portProvider the port provider
049         */
050        public SocketTargetServerConnection(PortProvider portProvider) {
051                Assert.notNull(portProvider, "PortProvider must not be null");
052                this.portProvider = portProvider;
053        }
054
055        @Override
056        public ByteChannel open(int socketTimeout) throws IOException {
057                SocketAddress address = new InetSocketAddress(this.portProvider.getPort());
058                logger.trace("Opening tunnel connection to target server on " + address);
059                SocketChannel channel = SocketChannel.open(address);
060                channel.socket().setSoTimeout(socketTimeout);
061                return new TimeoutAwareChannel(channel);
062        }
063
064        /**
065         * Wrapper to expose the {@link SocketChannel} in such a way that
066         * {@code SocketTimeoutExceptions} are still thrown from read methods.
067         */
068        private static class TimeoutAwareChannel implements ByteChannel {
069
070                private final SocketChannel socketChannel;
071
072                private final ReadableByteChannel readChannel;
073
074                TimeoutAwareChannel(SocketChannel socketChannel) throws IOException {
075                        this.socketChannel = socketChannel;
076                        this.readChannel = Channels
077                                        .newChannel(socketChannel.socket().getInputStream());
078                }
079
080                @Override
081                public int read(ByteBuffer dst) throws IOException {
082                        return this.readChannel.read(dst);
083                }
084
085                @Override
086                public int write(ByteBuffer src) throws IOException {
087                        return this.socketChannel.write(src);
088                }
089
090                @Override
091                public boolean isOpen() {
092                        return this.socketChannel.isOpen();
093                }
094
095                @Override
096                public void close() throws IOException {
097                        this.socketChannel.close();
098                }
099
100        }
101
102}