001/*
002 * Copyright 2002-2019 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.web.socket.sockjs.client;
018
019import java.io.IOException;
020import java.net.URI;
021import java.security.Principal;
022import java.util.Map;
023import java.util.concurrent.ConcurrentHashMap;
024
025import org.apache.commons.logging.Log;
026import org.apache.commons.logging.LogFactory;
027
028import org.springframework.http.HttpHeaders;
029import org.springframework.lang.Nullable;
030import org.springframework.util.Assert;
031import org.springframework.util.concurrent.SettableListenableFuture;
032import org.springframework.web.socket.CloseStatus;
033import org.springframework.web.socket.TextMessage;
034import org.springframework.web.socket.WebSocketHandler;
035import org.springframework.web.socket.WebSocketMessage;
036import org.springframework.web.socket.WebSocketSession;
037import org.springframework.web.socket.sockjs.frame.SockJsFrame;
038import org.springframework.web.socket.sockjs.frame.SockJsMessageCodec;
039
040/**
041 * Base class for SockJS client implementations of {@link WebSocketSession}.
042 * Provides processing of incoming SockJS message frames and delegates lifecycle
043 * events and messages to the (application) {@link WebSocketHandler}.
044 * Sub-classes implement actual send as well as disconnect logic.
045 *
046 * @author Rossen Stoyanchev
047 * @author Juergen Hoeller
048 * @since 4.1
049 */
050public abstract class AbstractClientSockJsSession implements WebSocketSession {
051
052        protected final Log logger = LogFactory.getLog(getClass());
053
054        private final TransportRequest request;
055
056        private final WebSocketHandler webSocketHandler;
057
058        private final SettableListenableFuture<WebSocketSession> connectFuture;
059
060        private final Map<String, Object> attributes = new ConcurrentHashMap<>();
061
062        @Nullable
063        private volatile State state = State.NEW;
064
065        @Nullable
066        private volatile CloseStatus closeStatus;
067
068
069        protected AbstractClientSockJsSession(TransportRequest request, WebSocketHandler handler,
070                        SettableListenableFuture<WebSocketSession> connectFuture) {
071
072                Assert.notNull(request, "'request' is required");
073                Assert.notNull(handler, "'handler' is required");
074                Assert.notNull(connectFuture, "'connectFuture' is required");
075                this.request = request;
076                this.webSocketHandler = handler;
077                this.connectFuture = connectFuture;
078        }
079
080
081        @Override
082        public String getId() {
083                return this.request.getSockJsUrlInfo().getSessionId();
084        }
085
086        @Override
087        public URI getUri() {
088                return this.request.getSockJsUrlInfo().getSockJsUrl();
089        }
090
091        @Override
092        public HttpHeaders getHandshakeHeaders() {
093                return this.request.getHandshakeHeaders();
094        }
095
096        @Override
097        public Map<String, Object> getAttributes() {
098                return this.attributes;
099        }
100
101        @Override
102        public Principal getPrincipal() {
103                return this.request.getUser();
104        }
105
106        public SockJsMessageCodec getMessageCodec() {
107                return this.request.getMessageCodec();
108        }
109
110        public WebSocketHandler getWebSocketHandler() {
111                return this.webSocketHandler;
112        }
113
114        /**
115         * Return a timeout cleanup task to invoke if the SockJS sessions is not
116         * fully established within the retransmission timeout period calculated in
117         * {@code SockJsRequest} based on the duration of the initial SockJS "Info"
118         * request.
119         */
120        Runnable getTimeoutTask() {
121                return new Runnable() {
122                        @Override
123                        public void run() {
124                                try {
125                                        closeInternal(new CloseStatus(2007, "Transport timed out"));
126                                }
127                                catch (Throwable ex) {
128                                        if (logger.isWarnEnabled()) {
129                                                logger.warn("Failed to close " + this + " after transport timeout", ex);
130                                        }
131                                }
132                        }
133                };
134        }
135
136        @Override
137        public boolean isOpen() {
138                return (this.state == State.OPEN);
139        }
140
141        public boolean isDisconnected() {
142                return (this.state == State.CLOSING || this.state == State.CLOSED);
143        }
144
145        @Override
146        public final void sendMessage(WebSocketMessage<?> message) throws IOException {
147                if (!(message instanceof TextMessage)) {
148                        throw new IllegalArgumentException(this + " supports text messages only.");
149                }
150                if (this.state != State.OPEN) {
151                        throw new IllegalStateException(this + " is not open: current state " + this.state);
152                }
153
154                String payload = ((TextMessage) message).getPayload();
155                payload = getMessageCodec().encode(payload);
156                payload = payload.substring(1);  // the client-side doesn't need message framing (letter "a")
157
158                TextMessage messageToSend = new TextMessage(payload);
159                if (logger.isTraceEnabled()) {
160                        logger.trace("Sending message " + messageToSend + " in " + this);
161                }
162                sendInternal(messageToSend);
163        }
164
165        protected abstract void sendInternal(TextMessage textMessage) throws IOException;
166
167        @Override
168        public final void close() throws IOException {
169                close(CloseStatus.NORMAL);
170        }
171
172        @Override
173        public final void close(CloseStatus status) throws IOException {
174                if (!isUserSetStatus(status)) {
175                        throw new IllegalArgumentException("Invalid close status: " + status);
176                }
177                if (logger.isDebugEnabled()) {
178                        logger.debug("Closing session with " +  status + " in " + this);
179                }
180                closeInternal(status);
181        }
182
183        private boolean isUserSetStatus(@Nullable CloseStatus status) {
184                return (status != null && (status.getCode() == 1000 ||
185                                (status.getCode() >= 3000 && status.getCode() <= 4999)));
186        }
187
188        private void silentClose(CloseStatus status) {
189                try {
190                        closeInternal(status);
191                }
192                catch (Throwable ex) {
193                        if (logger.isWarnEnabled()) {
194                                logger.warn("Failed to close " + this, ex);
195                        }
196                }
197        }
198
199        protected void closeInternal(CloseStatus status) throws IOException {
200                if (this.state == null) {
201                        logger.warn("Ignoring close since connect() was never invoked");
202                        return;
203                }
204                if (isDisconnected()) {
205                        if (logger.isDebugEnabled()) {
206                                logger.debug("Ignoring close (already closing or closed): current state " + this.state);
207                        }
208                        return;
209                }
210
211                this.state = State.CLOSING;
212                this.closeStatus = status;
213                disconnect(status);
214        }
215
216        protected abstract void disconnect(CloseStatus status) throws IOException;
217
218        public void handleFrame(String payload) {
219                SockJsFrame frame = new SockJsFrame(payload);
220                switch (frame.getType()) {
221                        case OPEN:
222                                handleOpenFrame();
223                                break;
224                        case HEARTBEAT:
225                                if (logger.isTraceEnabled()) {
226                                        logger.trace("Received heartbeat in " + this);
227                                }
228                                break;
229                        case MESSAGE:
230                                handleMessageFrame(frame);
231                                break;
232                        case CLOSE:
233                                handleCloseFrame(frame);
234                }
235        }
236
237        private void handleOpenFrame() {
238                if (logger.isDebugEnabled()) {
239                        logger.debug("Processing SockJS open frame in " + this);
240                }
241                if (this.state == State.NEW) {
242                        this.state = State.OPEN;
243                        try {
244                                this.webSocketHandler.afterConnectionEstablished(this);
245                                this.connectFuture.set(this);
246                        }
247                        catch (Exception ex) {
248                                if (logger.isErrorEnabled()) {
249                                        logger.error("WebSocketHandler.afterConnectionEstablished threw exception in " + this, ex);
250                                }
251                        }
252                }
253                else {
254                        if (logger.isDebugEnabled()) {
255                                logger.debug("Open frame received in " + getId() + " but we're not connecting (current state " +
256                                                this.state + "). The server might have been restarted and lost track of the session.");
257                        }
258                        silentClose(new CloseStatus(1006, "Server lost session"));
259                }
260        }
261
262        private void handleMessageFrame(SockJsFrame frame) {
263                if (!isOpen()) {
264                        if (logger.isErrorEnabled()) {
265                                logger.error("Ignoring received message due to state " + this.state + " in " + this);
266                        }
267                        return;
268                }
269
270                String[] messages = null;
271                String frameData = frame.getFrameData();
272                if (frameData != null) {
273                        try {
274                                messages = getMessageCodec().decode(frameData);
275                        }
276                        catch (IOException ex) {
277                                if (logger.isErrorEnabled()) {
278                                        logger.error("Failed to decode data for SockJS \"message\" frame: " + frame + " in " + this, ex);
279                                }
280                                silentClose(CloseStatus.BAD_DATA);
281                                return;
282                        }
283                }
284                if (messages == null) {
285                        return;
286                }
287
288                if (logger.isTraceEnabled()) {
289                        logger.trace("Processing SockJS message frame " + frame.getContent() + " in " + this);
290                }
291                for (String message : messages) {
292                        if (isOpen()) {
293                                try {
294                                        this.webSocketHandler.handleMessage(this, new TextMessage(message));
295                                }
296                                catch (Exception ex) {
297                                        logger.error("WebSocketHandler.handleMessage threw an exception on " + frame + " in " + this, ex);
298                                }
299                        }
300                }
301        }
302
303        private void handleCloseFrame(SockJsFrame frame) {
304                CloseStatus closeStatus = CloseStatus.NO_STATUS_CODE;
305                try {
306                        String frameData = frame.getFrameData();
307                        if (frameData != null) {
308                                String[] data = getMessageCodec().decode(frameData);
309                                if (data != null && data.length == 2) {
310                                        closeStatus = new CloseStatus(Integer.parseInt(data[0]), data[1]);
311                                }
312                                if (logger.isDebugEnabled()) {
313                                        logger.debug("Processing SockJS close frame with " + closeStatus + " in " + this);
314                                }
315                        }
316                }
317                catch (IOException ex) {
318                        if (logger.isErrorEnabled()) {
319                                logger.error("Failed to decode data for " + frame + " in " + this, ex);
320                        }
321                }
322                silentClose(closeStatus);
323        }
324
325        public void handleTransportError(Throwable error) {
326                try {
327                        if (logger.isErrorEnabled()) {
328                                logger.error("Transport error in " + this, error);
329                        }
330                        this.webSocketHandler.handleTransportError(this, error);
331                }
332                catch (Throwable ex) {
333                        logger.error("WebSocketHandler.handleTransportError threw an exception", ex);
334                }
335        }
336
337        public void afterTransportClosed(@Nullable CloseStatus closeStatus) {
338                CloseStatus cs = this.closeStatus;
339                if (cs == null) {
340                        cs = closeStatus;
341                        this.closeStatus = closeStatus;
342                }
343                Assert.state(cs != null, "CloseStatus not available");
344                if (logger.isDebugEnabled()) {
345                        logger.debug("Transport closed with " + cs + " in " + this);
346                }
347
348                this.state = State.CLOSED;
349                try {
350                        this.webSocketHandler.afterConnectionClosed(this, cs);
351                }
352                catch (Throwable ex) {
353                        logger.error("WebSocketHandler.afterConnectionClosed threw an exception", ex);
354                }
355        }
356
357        @Override
358        public String toString() {
359                return getClass().getSimpleName() + "[id='" + getId() + ", url=" + getUri() + "]";
360        }
361
362
363        private enum State { NEW, OPEN, CLOSING, CLOSED }
364
365}