001/*
002 * Copyright 2012-2017 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.boot.devtools.tunnel.payload;
018
019import java.io.IOException;
020import java.nio.channels.WritableByteChannel;
021import java.util.HashMap;
022import java.util.Map;
023
024import org.springframework.util.Assert;
025
026/**
027 * Utility class that forwards {@link HttpTunnelPayload} instances to a destination
028 * channel, respecting sequence order.
029 *
030 * @author Phillip Webb
031 * @since 1.3.0
032 */
033public class HttpTunnelPayloadForwarder {
034
035        private static final int MAXIMUM_QUEUE_SIZE = 100;
036
037        private final Map<Long, HttpTunnelPayload> queue = new HashMap<>();
038
039        private final Object monitor = new Object();
040
041        private final WritableByteChannel targetChannel;
042
043        private long lastRequestSeq = 0;
044
045        /**
046         * Create a new {@link HttpTunnelPayloadForwarder} instance.
047         * @param targetChannel the target channel
048         */
049        public HttpTunnelPayloadForwarder(WritableByteChannel targetChannel) {
050                Assert.notNull(targetChannel, "TargetChannel must not be null");
051                this.targetChannel = targetChannel;
052        }
053
054        public void forward(HttpTunnelPayload payload) throws IOException {
055                synchronized (this.monitor) {
056                        long seq = payload.getSequence();
057                        if (this.lastRequestSeq != seq - 1) {
058                                Assert.state(this.queue.size() < MAXIMUM_QUEUE_SIZE,
059                                                "Too many messages queued");
060                                this.queue.put(seq, payload);
061                                return;
062                        }
063                        payload.logOutgoing();
064                        payload.writeTo(this.targetChannel);
065                        this.lastRequestSeq = seq;
066                        HttpTunnelPayload queuedItem = this.queue.get(seq + 1);
067                        if (queuedItem != null) {
068                                forward(queuedItem);
069                        }
070                }
071        }
072
073}