001/*
002 * Copyright 2002-2018 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.web.socket.sockjs.client;
018
019import java.io.IOException;
020import java.net.URI;
021import java.security.Principal;
022import java.util.Map;
023import java.util.concurrent.ConcurrentHashMap;
024
025import org.apache.commons.logging.Log;
026import org.apache.commons.logging.LogFactory;
027
028import org.springframework.http.HttpHeaders;
029import org.springframework.util.Assert;
030import org.springframework.util.concurrent.SettableListenableFuture;
031import org.springframework.web.socket.CloseStatus;
032import org.springframework.web.socket.TextMessage;
033import org.springframework.web.socket.WebSocketHandler;
034import org.springframework.web.socket.WebSocketMessage;
035import org.springframework.web.socket.WebSocketSession;
036import org.springframework.web.socket.sockjs.frame.SockJsFrame;
037import org.springframework.web.socket.sockjs.frame.SockJsMessageCodec;
038
039/**
040 * Base class for SockJS client implementations of {@link WebSocketSession}.
041 * Provides processing of incoming SockJS message frames and delegates lifecycle
042 * events and messages to the (application) {@link WebSocketHandler}.
043 * Sub-classes implement actual send as well as disconnect logic.
044 *
045 * @author Rossen Stoyanchev
046 * @author Juergen Hoeller
047 * @since 4.1
048 */
049public abstract class AbstractClientSockJsSession implements WebSocketSession {
050
051        protected final Log logger = LogFactory.getLog(getClass());
052
053        private final TransportRequest request;
054
055        private final WebSocketHandler webSocketHandler;
056
057        private final SettableListenableFuture<WebSocketSession> connectFuture;
058
059        private final Map<String, Object> attributes = new ConcurrentHashMap<String, Object>();
060
061        private volatile State state = State.NEW;
062
063        private volatile CloseStatus closeStatus;
064
065
066        protected AbstractClientSockJsSession(TransportRequest request, WebSocketHandler handler,
067                        SettableListenableFuture<WebSocketSession> connectFuture) {
068
069                Assert.notNull(request, "'request' is required");
070                Assert.notNull(handler, "'handler' is required");
071                Assert.notNull(connectFuture, "'connectFuture' is required");
072                this.request = request;
073                this.webSocketHandler = handler;
074                this.connectFuture = connectFuture;
075        }
076
077
078        @Override
079        public String getId() {
080                return this.request.getSockJsUrlInfo().getSessionId();
081        }
082
083        @Override
084        public URI getUri() {
085                return this.request.getSockJsUrlInfo().getSockJsUrl();
086        }
087
088        @Override
089        public HttpHeaders getHandshakeHeaders() {
090                return this.request.getHandshakeHeaders();
091        }
092
093        @Override
094        public Map<String, Object> getAttributes() {
095                return this.attributes;
096        }
097
098        @Override
099        public Principal getPrincipal() {
100                return this.request.getUser();
101        }
102
103        public SockJsMessageCodec getMessageCodec() {
104                return this.request.getMessageCodec();
105        }
106
107        public WebSocketHandler getWebSocketHandler() {
108                return this.webSocketHandler;
109        }
110
111        /**
112         * Return a timeout cleanup task to invoke if the SockJS sessions is not
113         * fully established within the retransmission timeout period calculated in
114         * {@code SockJsRequest} based on the duration of the initial SockJS "Info"
115         * request.
116         */
117        Runnable getTimeoutTask() {
118                return new Runnable() {
119                        @Override
120                        public void run() {
121                                try {
122                                        closeInternal(new CloseStatus(2007, "Transport timed out"));
123                                }
124                                catch (Throwable ex) {
125                                        if (logger.isWarnEnabled()) {
126                                                logger.warn("Failed to close " + this + " after transport timeout", ex);
127                                        }
128                                }
129                        }
130                };
131        }
132
133        @Override
134        public boolean isOpen() {
135                return (this.state == State.OPEN);
136        }
137
138        public boolean isDisconnected() {
139                return (this.state == State.CLOSING || this.state == State.CLOSED);
140        }
141
142        @Override
143        public final void sendMessage(WebSocketMessage<?> message) throws IOException {
144                if (!(message instanceof TextMessage)) {
145                        throw new IllegalArgumentException(this + " supports text messages only.");
146                }
147                if (this.state != State.OPEN) {
148                        throw new IllegalStateException(this + " is not open: current state " + this.state);
149                }
150
151                String payload = ((TextMessage) message).getPayload();
152                payload = getMessageCodec().encode(payload);
153                payload = payload.substring(1);  // the client-side doesn't need message framing (letter "a")
154
155                TextMessage messageToSend = new TextMessage(payload);
156                if (logger.isTraceEnabled()) {
157                        logger.trace("Sending message " + messageToSend + " in " + this);
158                }
159                sendInternal(messageToSend);
160        }
161
162        protected abstract void sendInternal(TextMessage textMessage) throws IOException;
163
164        @Override
165        public final void close() throws IOException {
166                close(CloseStatus.NORMAL);
167        }
168
169        @Override
170        public final void close(CloseStatus status) throws IOException {
171                if (!isUserSetStatus(status)) {
172                        throw new IllegalArgumentException("Invalid close status: " + status);
173                }
174                if (logger.isDebugEnabled()) {
175                        logger.debug("Closing session with " +  status + " in " + this);
176                }
177                closeInternal(status);
178        }
179
180        private boolean isUserSetStatus(CloseStatus status) {
181                return (status != null && (status.getCode() == 1000 ||
182                                (status.getCode() >= 3000 && status.getCode() <= 4999)));
183        }
184
185        private void silentClose(CloseStatus status) {
186                try {
187                        closeInternal(status);
188                }
189                catch (Throwable ex) {
190                        if (logger.isWarnEnabled()) {
191                                logger.warn("Failed to close " + this, ex);
192                        }
193                }
194        }
195
196        protected void closeInternal(CloseStatus status) throws IOException {
197                if (this.state == null) {
198                        logger.warn("Ignoring close since connect() was never invoked");
199                        return;
200                }
201                if (isDisconnected()) {
202                        if (logger.isDebugEnabled()) {
203                                logger.debug("Ignoring close (already closing or closed): current state " + this.state);
204                        }
205                        return;
206                }
207
208                this.state = State.CLOSING;
209                this.closeStatus = status;
210                disconnect(status);
211        }
212
213        protected abstract void disconnect(CloseStatus status) throws IOException;
214
215        public void handleFrame(String payload) {
216                SockJsFrame frame = new SockJsFrame(payload);
217                switch (frame.getType()) {
218                        case OPEN:
219                                handleOpenFrame();
220                                break;
221                        case HEARTBEAT:
222                                if (logger.isTraceEnabled()) {
223                                        logger.trace("Received heartbeat in " + this);
224                                }
225                                break;
226                        case MESSAGE:
227                                handleMessageFrame(frame);
228                                break;
229                        case CLOSE:
230                                handleCloseFrame(frame);
231                }
232        }
233
234        private void handleOpenFrame() {
235                if (logger.isDebugEnabled()) {
236                        logger.debug("Processing SockJS open frame in " + this);
237                }
238                if (this.state == State.NEW) {
239                        this.state = State.OPEN;
240                        try {
241                                this.webSocketHandler.afterConnectionEstablished(this);
242                                this.connectFuture.set(this);
243                        }
244                        catch (Throwable ex) {
245                                if (logger.isErrorEnabled()) {
246                                        logger.error("WebSocketHandler.afterConnectionEstablished threw exception in " + this, ex);
247                                }
248                        }
249                }
250                else {
251                        if (logger.isDebugEnabled()) {
252                                logger.debug("Open frame received in " + getId() + " but we're not connecting (current state " +
253                                                this.state + "). The server might have been restarted and lost track of the session.");
254                        }
255                        silentClose(new CloseStatus(1006, "Server lost session"));
256                }
257        }
258
259        private void handleMessageFrame(SockJsFrame frame) {
260                if (!isOpen()) {
261                        if (logger.isErrorEnabled()) {
262                                logger.error("Ignoring received message due to state " + this.state + " in " + this);
263                        }
264                        return;
265                }
266
267                String[] messages;
268                try {
269                        messages = getMessageCodec().decode(frame.getFrameData());
270                }
271                catch (IOException ex) {
272                        if (logger.isErrorEnabled()) {
273                                logger.error("Failed to decode data for SockJS \"message\" frame: " + frame + " in " + this, ex);
274                        }
275                        silentClose(CloseStatus.BAD_DATA);
276                        return;
277                }
278
279                if (logger.isTraceEnabled()) {
280                        logger.trace("Processing SockJS message frame " + frame.getContent() + " in " + this);
281                }
282                for (String message : messages) {
283                        if (isOpen()) {
284                                try {
285                                        this.webSocketHandler.handleMessage(this, new TextMessage(message));
286                                }
287                                catch (Throwable ex) {
288                                        logger.error("WebSocketHandler.handleMessage threw an exception on " + frame + " in " + this, ex);
289                                }
290                        }
291                }
292        }
293
294        private void handleCloseFrame(SockJsFrame frame) {
295                CloseStatus closeStatus = CloseStatus.NO_STATUS_CODE;
296                try {
297                        String[] data = getMessageCodec().decode(frame.getFrameData());
298                        if (data.length == 2) {
299                                closeStatus = new CloseStatus(Integer.valueOf(data[0]), data[1]);
300                        }
301                        if (logger.isDebugEnabled()) {
302                                logger.debug("Processing SockJS close frame with " + closeStatus + " in " + this);
303                        }
304                }
305                catch (IOException ex) {
306                        if (logger.isErrorEnabled()) {
307                                logger.error("Failed to decode data for " + frame + " in " + this, ex);
308                        }
309                }
310                silentClose(closeStatus);
311        }
312
313        public void handleTransportError(Throwable error) {
314                try {
315                        if (logger.isErrorEnabled()) {
316                                logger.error("Transport error in " + this, error);
317                        }
318                        this.webSocketHandler.handleTransportError(this, error);
319                }
320                catch (Throwable ex) {
321                        logger.error("WebSocketHandler.handleTransportError threw an exception", ex);
322                }
323        }
324
325        public void afterTransportClosed(CloseStatus closeStatus) {
326                CloseStatus cs = this.closeStatus;
327                if (cs == null) {
328                        cs = closeStatus;
329                        this.closeStatus = closeStatus;
330                }
331                Assert.state(cs != null, "CloseStatus not available");
332                if (logger.isDebugEnabled()) {
333                        logger.debug("Transport closed with " + cs + " in " + this);
334                }
335
336                this.state = State.CLOSED;
337                try {
338                        this.webSocketHandler.afterConnectionClosed(this, cs);
339                }
340                catch (Throwable ex) {
341                        logger.error("WebSocketHandler.afterConnectionClosed threw an exception", ex);
342                }
343        }
344
345        @Override
346        public String toString() {
347                return getClass().getSimpleName() + "[id='" + getId() + ", url=" + getUri() + "]";
348        }
349
350
351        private enum State { NEW, OPEN, CLOSING, CLOSED }
352
353}