001/*
002 * Copyright 2002-2018 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.web.socket.handler;
018
019import java.io.IOException;
020import java.util.Queue;
021import java.util.concurrent.LinkedBlockingQueue;
022import java.util.concurrent.atomic.AtomicInteger;
023import java.util.concurrent.locks.Lock;
024import java.util.concurrent.locks.ReentrantLock;
025
026import org.apache.commons.logging.Log;
027import org.apache.commons.logging.LogFactory;
028
029import org.springframework.web.socket.CloseStatus;
030import org.springframework.web.socket.WebSocketMessage;
031import org.springframework.web.socket.WebSocketSession;
032
033/**
034 * Wrap a {@link org.springframework.web.socket.WebSocketSession WebSocketSession}
035 * to guarantee only one thread can send messages at a time.
036 *
037 * <p>If a send is slow, subsequent attempts to send more messages from other threads
038 * will not be able to acquire the flush lock and messages will be buffered instead.
039 * At that time, the specified buffer-size limit and send-time limit will be checked
040 * and the session will be closed if the limits are exceeded.
041 *
042 * @author Rossen Stoyanchev
043 * @author Juergen Hoeller
044 * @since 4.0.3
045 */
046public class ConcurrentWebSocketSessionDecorator extends WebSocketSessionDecorator {
047
048        private static final Log logger = LogFactory.getLog(ConcurrentWebSocketSessionDecorator.class);
049
050
051        private final int sendTimeLimit;
052
053        private final int bufferSizeLimit;
054
055        private final Queue<WebSocketMessage<?>> buffer = new LinkedBlockingQueue<WebSocketMessage<?>>();
056
057        private final AtomicInteger bufferSize = new AtomicInteger();
058
059        private volatile long sendStartTime;
060
061        private volatile boolean limitExceeded;
062
063        private volatile boolean closeInProgress;
064
065        private final Lock flushLock = new ReentrantLock();
066
067        private final Lock closeLock = new ReentrantLock();
068
069
070        /**
071         * Create a new {@code ConcurrentWebSocketSessionDecorator}.
072         * @param delegate the {@code WebSocketSession} to delegate to
073         * @param sendTimeLimit the send-time limit (milliseconds)
074         * @param bufferSizeLimit the buffer-size limit (number of bytes)
075         */
076        public ConcurrentWebSocketSessionDecorator(WebSocketSession delegate, int sendTimeLimit, int bufferSizeLimit) {
077                super(delegate);
078                this.sendTimeLimit = sendTimeLimit;
079                this.bufferSizeLimit = bufferSizeLimit;
080        }
081
082
083        /**
084         * Return the configured send-time limit (milliseconds).
085         * @since 4.3.13
086         */
087        public int getSendTimeLimit() {
088                return this.sendTimeLimit;
089        }
090
091        /**
092         * Return the configured buffer-size limit (number of bytes).
093         * @since 4.3.13
094         */
095        public int getBufferSizeLimit() {
096                return this.bufferSizeLimit;
097        }
098
099        /**
100         * Return the current buffer size (number of bytes).
101         */
102        public int getBufferSize() {
103                return this.bufferSize.get();
104        }
105
106        /**
107         * Return the time (milliseconds) since the current send started,
108         * or 0 if no send is currently in progress.
109         */
110        public long getTimeSinceSendStarted() {
111                long start = this.sendStartTime;
112                return (start > 0 ? (System.currentTimeMillis() - start) : 0);
113        }
114
115
116        @Override
117        public void sendMessage(WebSocketMessage<?> message) throws IOException {
118                if (shouldNotSend()) {
119                        return;
120                }
121
122                this.buffer.add(message);
123                this.bufferSize.addAndGet(message.getPayloadLength());
124
125                do {
126                        if (!tryFlushMessageBuffer()) {
127                                if (logger.isTraceEnabled()) {
128                                        logger.trace(String.format("Another send already in progress: " +
129                                                        "session id '%s':, \"in-progress\" send time %d (ms), buffer size %d bytes",
130                                                        getId(), getTimeSinceSendStarted(), getBufferSize()));
131                                }
132                                checkSessionLimits();
133                                break;
134                        }
135                }
136                while (!this.buffer.isEmpty() && !shouldNotSend());
137        }
138
139        private boolean shouldNotSend() {
140                return (this.limitExceeded || this.closeInProgress);
141        }
142
143        private boolean tryFlushMessageBuffer() throws IOException {
144                if (this.flushLock.tryLock()) {
145                        try {
146                                while (true) {
147                                        WebSocketMessage<?> message = this.buffer.poll();
148                                        if (message == null || shouldNotSend()) {
149                                                break;
150                                        }
151                                        this.bufferSize.addAndGet(message.getPayloadLength() * -1);
152                                        this.sendStartTime = System.currentTimeMillis();
153                                        getDelegate().sendMessage(message);
154                                        this.sendStartTime = 0;
155                                }
156                        }
157                        finally {
158                                this.sendStartTime = 0;
159                                this.flushLock.unlock();
160                        }
161                        return true;
162                }
163                return false;
164        }
165
166        private void checkSessionLimits() {
167                if (!shouldNotSend() && this.closeLock.tryLock()) {
168                        try {
169                                if (getTimeSinceSendStarted() > getSendTimeLimit()) {
170                                        String format = "Message send time %d (ms) for session '%s' exceeded the allowed limit %d";
171                                        String reason = String.format(format, getTimeSinceSendStarted(), getId(), getSendTimeLimit());
172                                        limitExceeded(reason);
173                                }
174                                else if (getBufferSize() > getBufferSizeLimit()) {
175                                        String format = "The send buffer size %d bytes for session '%s' exceeded the allowed limit %d";
176                                        String reason = String.format(format, getBufferSize(), getId(), getBufferSizeLimit());
177                                        limitExceeded(reason);
178                                }
179                        }
180                        finally {
181                                this.closeLock.unlock();
182                        }
183                }
184        }
185
186        private void limitExceeded(String reason) {
187                this.limitExceeded = true;
188                throw new SessionLimitExceededException(reason, CloseStatus.SESSION_NOT_RELIABLE);
189        }
190
191        @Override
192        public void close(CloseStatus status) throws IOException {
193                this.closeLock.lock();
194                try {
195                        if (this.closeInProgress) {
196                                return;
197                        }
198                        if (!CloseStatus.SESSION_NOT_RELIABLE.equals(status)) {
199                                try {
200                                        checkSessionLimits();
201                                }
202                                catch (SessionLimitExceededException ex) {
203                                        // Ignore
204                                }
205                                if (this.limitExceeded) {
206                                        if (logger.isDebugEnabled()) {
207                                                logger.debug("Changing close status " + status + " to SESSION_NOT_RELIABLE.");
208                                        }
209                                        status = CloseStatus.SESSION_NOT_RELIABLE;
210                                }
211                        }
212                        this.closeInProgress = true;
213                        super.close(status);
214                }
215                finally {
216                        this.closeLock.unlock();
217                }
218        }
219
220
221        @Override
222        public String toString() {
223                return getDelegate().toString();
224        }
225
226}