001/*
002 * Copyright 2002-2018 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.web.socket.handler;
018
019import java.io.IOException;
020import java.util.Queue;
021import java.util.concurrent.LinkedBlockingQueue;
022import java.util.concurrent.atomic.AtomicInteger;
023import java.util.concurrent.locks.Lock;
024import java.util.concurrent.locks.ReentrantLock;
025
026import org.apache.commons.logging.Log;
027import org.apache.commons.logging.LogFactory;
028
029import org.springframework.web.socket.CloseStatus;
030import org.springframework.web.socket.WebSocketMessage;
031import org.springframework.web.socket.WebSocketSession;
032
033/**
034 * Wrap a {@link org.springframework.web.socket.WebSocketSession WebSocketSession}
035 * to guarantee only one thread can send messages at a time.
036 *
037 * <p>If a send is slow, subsequent attempts to send more messages from other threads
038 * will not be able to acquire the flush lock and messages will be buffered instead.
039 * At that time, the specified buffer-size limit and send-time limit will be checked
040 * and the session will be closed if the limits are exceeded.
041 *
042 * @author Rossen Stoyanchev
043 * @author Juergen Hoeller
044 * @since 4.0.3
045 */
046public class ConcurrentWebSocketSessionDecorator extends WebSocketSessionDecorator {
047
048        private static final Log logger = LogFactory.getLog(ConcurrentWebSocketSessionDecorator.class);
049
050
051        private final int sendTimeLimit;
052
053        private final int bufferSizeLimit;
054
055        private final OverflowStrategy overflowStrategy;
056
057        private final Queue<WebSocketMessage<?>> buffer = new LinkedBlockingQueue<>();
058
059        private final AtomicInteger bufferSize = new AtomicInteger();
060
061        private volatile long sendStartTime;
062
063        private volatile boolean limitExceeded;
064
065        private volatile boolean closeInProgress;
066
067        private final Lock flushLock = new ReentrantLock();
068
069        private final Lock closeLock = new ReentrantLock();
070
071
072        /**
073         * Basic constructor.
074         * @param delegate the {@code WebSocketSession} to delegate to
075         * @param sendTimeLimit the send-time limit (milliseconds)
076         * @param bufferSizeLimit the buffer-size limit (number of bytes)
077         */
078        public ConcurrentWebSocketSessionDecorator(WebSocketSession delegate, int sendTimeLimit, int bufferSizeLimit) {
079                this(delegate, sendTimeLimit, bufferSizeLimit, OverflowStrategy.TERMINATE);
080        }
081
082        /**
083         * Constructor that also specifies the overflow strategy to use.
084         * @param delegate the {@code WebSocketSession} to delegate to
085         * @param sendTimeLimit the send-time limit (milliseconds)
086         * @param bufferSizeLimit the buffer-size limit (number of bytes)
087         * @param overflowStrategy the overflow strategy to use; by default the
088         * session is terminated.
089         * @since 5.1
090         */
091        public ConcurrentWebSocketSessionDecorator(
092                        WebSocketSession delegate, int sendTimeLimit, int bufferSizeLimit, OverflowStrategy overflowStrategy) {
093
094                super(delegate);
095                this.sendTimeLimit = sendTimeLimit;
096                this.bufferSizeLimit = bufferSizeLimit;
097                this.overflowStrategy = overflowStrategy;
098        }
099
100
101        /**
102         * Return the configured send-time limit (milliseconds).
103         * @since 4.3.13
104         */
105        public int getSendTimeLimit() {
106                return this.sendTimeLimit;
107        }
108
109        /**
110         * Return the configured buffer-size limit (number of bytes).
111         * @since 4.3.13
112         */
113        public int getBufferSizeLimit() {
114                return this.bufferSizeLimit;
115        }
116
117        /**
118         * Return the current buffer size (number of bytes).
119         */
120        public int getBufferSize() {
121                return this.bufferSize.get();
122        }
123
124        /**
125         * Return the time (milliseconds) since the current send started,
126         * or 0 if no send is currently in progress.
127         */
128        public long getTimeSinceSendStarted() {
129                long start = this.sendStartTime;
130                return (start > 0 ? (System.currentTimeMillis() - start) : 0);
131        }
132
133
134        @Override
135        public void sendMessage(WebSocketMessage<?> message) throws IOException {
136                if (shouldNotSend()) {
137                        return;
138                }
139
140                this.buffer.add(message);
141                this.bufferSize.addAndGet(message.getPayloadLength());
142
143                do {
144                        if (!tryFlushMessageBuffer()) {
145                                if (logger.isTraceEnabled()) {
146                                        logger.trace(String.format("Another send already in progress: " +
147                                                        "session id '%s':, \"in-progress\" send time %d (ms), buffer size %d bytes",
148                                                        getId(), getTimeSinceSendStarted(), getBufferSize()));
149                                }
150                                checkSessionLimits();
151                                break;
152                        }
153                }
154                while (!this.buffer.isEmpty() && !shouldNotSend());
155        }
156
157        private boolean shouldNotSend() {
158                return (this.limitExceeded || this.closeInProgress);
159        }
160
161        private boolean tryFlushMessageBuffer() throws IOException {
162                if (this.flushLock.tryLock()) {
163                        try {
164                                while (true) {
165                                        WebSocketMessage<?> message = this.buffer.poll();
166                                        if (message == null || shouldNotSend()) {
167                                                break;
168                                        }
169                                        this.bufferSize.addAndGet(-message.getPayloadLength());
170                                        this.sendStartTime = System.currentTimeMillis();
171                                        getDelegate().sendMessage(message);
172                                        this.sendStartTime = 0;
173                                }
174                        }
175                        finally {
176                                this.sendStartTime = 0;
177                                this.flushLock.unlock();
178                        }
179                        return true;
180                }
181                return false;
182        }
183
184        private void checkSessionLimits() {
185                if (!shouldNotSend() && this.closeLock.tryLock()) {
186                        try {
187                                if (getTimeSinceSendStarted() > getSendTimeLimit()) {
188                                        String format = "Send time %d (ms) for session '%s' exceeded the allowed limit %d";
189                                        String reason = String.format(format, getTimeSinceSendStarted(), getId(), getSendTimeLimit());
190                                        limitExceeded(reason);
191                                }
192                                else if (getBufferSize() > getBufferSizeLimit()) {
193                                        switch (this.overflowStrategy) {
194                                                case TERMINATE:
195                                                        String format = "Buffer size %d bytes for session '%s' exceeds the allowed limit %d";
196                                                        String reason = String.format(format, getBufferSize(), getId(), getBufferSizeLimit());
197                                                        limitExceeded(reason);
198                                                        break;
199                                                case DROP:
200                                                        int i = 0;
201                                                        while (getBufferSize() > getBufferSizeLimit()) {
202                                                                WebSocketMessage<?> message = this.buffer.poll();
203                                                                if (message == null) {
204                                                                        break;
205                                                                }
206                                                                this.bufferSize.addAndGet(-message.getPayloadLength());
207                                                                i++;
208                                                        }
209                                                        if (logger.isDebugEnabled()) {
210                                                                logger.debug("Dropped " + i + " messages, buffer size: " + getBufferSize());
211                                                        }
212                                                        break;
213                                                default:
214                                                        // Should never happen..
215                                                        throw new IllegalStateException("Unexpected OverflowStrategy: " + this.overflowStrategy);
216                                        }
217                                }
218                        }
219                        finally {
220                                this.closeLock.unlock();
221                        }
222                }
223        }
224
225        private void limitExceeded(String reason) {
226                this.limitExceeded = true;
227                throw new SessionLimitExceededException(reason, CloseStatus.SESSION_NOT_RELIABLE);
228        }
229
230        @Override
231        public void close(CloseStatus status) throws IOException {
232                this.closeLock.lock();
233                try {
234                        if (this.closeInProgress) {
235                                return;
236                        }
237                        if (!CloseStatus.SESSION_NOT_RELIABLE.equals(status)) {
238                                try {
239                                        checkSessionLimits();
240                                }
241                                catch (SessionLimitExceededException ex) {
242                                        // Ignore
243                                }
244                                if (this.limitExceeded) {
245                                        if (logger.isDebugEnabled()) {
246                                                logger.debug("Changing close status " + status + " to SESSION_NOT_RELIABLE.");
247                                        }
248                                        status = CloseStatus.SESSION_NOT_RELIABLE;
249                                }
250                        }
251                        this.closeInProgress = true;
252                        super.close(status);
253                }
254                finally {
255                        this.closeLock.unlock();
256                }
257        }
258
259
260        @Override
261        public String toString() {
262                return getDelegate().toString();
263        }
264
265
266        /**
267         * Enum for options of what to do when the buffer fills up.
268         * @since 5.1
269         */
270        public enum OverflowStrategy {
271
272                /**
273                 * Throw {@link SessionLimitExceededException} that would will result
274                 * in the session being terminated.
275                 */
276                TERMINATE,
277
278                /**
279                 * Drop the oldest messages from the buffer.
280                 */
281                DROP
282        }
283
284}