001/* 002 * Copyright 2002-2018 the original author or authors. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * https://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.springframework.web.socket.handler; 018 019import java.io.IOException; 020import java.util.Queue; 021import java.util.concurrent.LinkedBlockingQueue; 022import java.util.concurrent.atomic.AtomicInteger; 023import java.util.concurrent.locks.Lock; 024import java.util.concurrent.locks.ReentrantLock; 025 026import org.apache.commons.logging.Log; 027import org.apache.commons.logging.LogFactory; 028 029import org.springframework.web.socket.CloseStatus; 030import org.springframework.web.socket.WebSocketMessage; 031import org.springframework.web.socket.WebSocketSession; 032 033/** 034 * Wrap a {@link org.springframework.web.socket.WebSocketSession WebSocketSession} 035 * to guarantee only one thread can send messages at a time. 036 * 037 * <p>If a send is slow, subsequent attempts to send more messages from other threads 038 * will not be able to acquire the flush lock and messages will be buffered instead. 039 * At that time, the specified buffer-size limit and send-time limit will be checked 040 * and the session will be closed if the limits are exceeded. 041 * 042 * @author Rossen Stoyanchev 043 * @author Juergen Hoeller 044 * @since 4.0.3 045 */ 046public class ConcurrentWebSocketSessionDecorator extends WebSocketSessionDecorator { 047 048 private static final Log logger = LogFactory.getLog(ConcurrentWebSocketSessionDecorator.class); 049 050 051 private final int sendTimeLimit; 052 053 private final int bufferSizeLimit; 054 055 private final OverflowStrategy overflowStrategy; 056 057 private final Queue<WebSocketMessage<?>> buffer = new LinkedBlockingQueue<>(); 058 059 private final AtomicInteger bufferSize = new AtomicInteger(); 060 061 private volatile long sendStartTime; 062 063 private volatile boolean limitExceeded; 064 065 private volatile boolean closeInProgress; 066 067 private final Lock flushLock = new ReentrantLock(); 068 069 private final Lock closeLock = new ReentrantLock(); 070 071 072 /** 073 * Basic constructor. 074 * @param delegate the {@code WebSocketSession} to delegate to 075 * @param sendTimeLimit the send-time limit (milliseconds) 076 * @param bufferSizeLimit the buffer-size limit (number of bytes) 077 */ 078 public ConcurrentWebSocketSessionDecorator(WebSocketSession delegate, int sendTimeLimit, int bufferSizeLimit) { 079 this(delegate, sendTimeLimit, bufferSizeLimit, OverflowStrategy.TERMINATE); 080 } 081 082 /** 083 * Constructor that also specifies the overflow strategy to use. 084 * @param delegate the {@code WebSocketSession} to delegate to 085 * @param sendTimeLimit the send-time limit (milliseconds) 086 * @param bufferSizeLimit the buffer-size limit (number of bytes) 087 * @param overflowStrategy the overflow strategy to use; by default the 088 * session is terminated. 089 * @since 5.1 090 */ 091 public ConcurrentWebSocketSessionDecorator( 092 WebSocketSession delegate, int sendTimeLimit, int bufferSizeLimit, OverflowStrategy overflowStrategy) { 093 094 super(delegate); 095 this.sendTimeLimit = sendTimeLimit; 096 this.bufferSizeLimit = bufferSizeLimit; 097 this.overflowStrategy = overflowStrategy; 098 } 099 100 101 /** 102 * Return the configured send-time limit (milliseconds). 103 * @since 4.3.13 104 */ 105 public int getSendTimeLimit() { 106 return this.sendTimeLimit; 107 } 108 109 /** 110 * Return the configured buffer-size limit (number of bytes). 111 * @since 4.3.13 112 */ 113 public int getBufferSizeLimit() { 114 return this.bufferSizeLimit; 115 } 116 117 /** 118 * Return the current buffer size (number of bytes). 119 */ 120 public int getBufferSize() { 121 return this.bufferSize.get(); 122 } 123 124 /** 125 * Return the time (milliseconds) since the current send started, 126 * or 0 if no send is currently in progress. 127 */ 128 public long getTimeSinceSendStarted() { 129 long start = this.sendStartTime; 130 return (start > 0 ? (System.currentTimeMillis() - start) : 0); 131 } 132 133 134 @Override 135 public void sendMessage(WebSocketMessage<?> message) throws IOException { 136 if (shouldNotSend()) { 137 return; 138 } 139 140 this.buffer.add(message); 141 this.bufferSize.addAndGet(message.getPayloadLength()); 142 143 do { 144 if (!tryFlushMessageBuffer()) { 145 if (logger.isTraceEnabled()) { 146 logger.trace(String.format("Another send already in progress: " + 147 "session id '%s':, \"in-progress\" send time %d (ms), buffer size %d bytes", 148 getId(), getTimeSinceSendStarted(), getBufferSize())); 149 } 150 checkSessionLimits(); 151 break; 152 } 153 } 154 while (!this.buffer.isEmpty() && !shouldNotSend()); 155 } 156 157 private boolean shouldNotSend() { 158 return (this.limitExceeded || this.closeInProgress); 159 } 160 161 private boolean tryFlushMessageBuffer() throws IOException { 162 if (this.flushLock.tryLock()) { 163 try { 164 while (true) { 165 WebSocketMessage<?> message = this.buffer.poll(); 166 if (message == null || shouldNotSend()) { 167 break; 168 } 169 this.bufferSize.addAndGet(-message.getPayloadLength()); 170 this.sendStartTime = System.currentTimeMillis(); 171 getDelegate().sendMessage(message); 172 this.sendStartTime = 0; 173 } 174 } 175 finally { 176 this.sendStartTime = 0; 177 this.flushLock.unlock(); 178 } 179 return true; 180 } 181 return false; 182 } 183 184 private void checkSessionLimits() { 185 if (!shouldNotSend() && this.closeLock.tryLock()) { 186 try { 187 if (getTimeSinceSendStarted() > getSendTimeLimit()) { 188 String format = "Send time %d (ms) for session '%s' exceeded the allowed limit %d"; 189 String reason = String.format(format, getTimeSinceSendStarted(), getId(), getSendTimeLimit()); 190 limitExceeded(reason); 191 } 192 else if (getBufferSize() > getBufferSizeLimit()) { 193 switch (this.overflowStrategy) { 194 case TERMINATE: 195 String format = "Buffer size %d bytes for session '%s' exceeds the allowed limit %d"; 196 String reason = String.format(format, getBufferSize(), getId(), getBufferSizeLimit()); 197 limitExceeded(reason); 198 break; 199 case DROP: 200 int i = 0; 201 while (getBufferSize() > getBufferSizeLimit()) { 202 WebSocketMessage<?> message = this.buffer.poll(); 203 if (message == null) { 204 break; 205 } 206 this.bufferSize.addAndGet(-message.getPayloadLength()); 207 i++; 208 } 209 if (logger.isDebugEnabled()) { 210 logger.debug("Dropped " + i + " messages, buffer size: " + getBufferSize()); 211 } 212 break; 213 default: 214 // Should never happen.. 215 throw new IllegalStateException("Unexpected OverflowStrategy: " + this.overflowStrategy); 216 } 217 } 218 } 219 finally { 220 this.closeLock.unlock(); 221 } 222 } 223 } 224 225 private void limitExceeded(String reason) { 226 this.limitExceeded = true; 227 throw new SessionLimitExceededException(reason, CloseStatus.SESSION_NOT_RELIABLE); 228 } 229 230 @Override 231 public void close(CloseStatus status) throws IOException { 232 this.closeLock.lock(); 233 try { 234 if (this.closeInProgress) { 235 return; 236 } 237 if (!CloseStatus.SESSION_NOT_RELIABLE.equals(status)) { 238 try { 239 checkSessionLimits(); 240 } 241 catch (SessionLimitExceededException ex) { 242 // Ignore 243 } 244 if (this.limitExceeded) { 245 if (logger.isDebugEnabled()) { 246 logger.debug("Changing close status " + status + " to SESSION_NOT_RELIABLE."); 247 } 248 status = CloseStatus.SESSION_NOT_RELIABLE; 249 } 250 } 251 this.closeInProgress = true; 252 super.close(status); 253 } 254 finally { 255 this.closeLock.unlock(); 256 } 257 } 258 259 260 @Override 261 public String toString() { 262 return getDelegate().toString(); 263 } 264 265 266 /** 267 * Enum for options of what to do when the buffer fills up. 268 * @since 5.1 269 */ 270 public enum OverflowStrategy { 271 272 /** 273 * Throw {@link SessionLimitExceededException} that would will result 274 * in the session being terminated. 275 */ 276 TERMINATE, 277 278 /** 279 * Drop the oldest messages from the buffer. 280 */ 281 DROP 282 } 283 284}