001/*
002 * Copyright 2002-2018 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.web.reactive.socket.adapter;
018
019import java.nio.charset.StandardCharsets;
020import java.util.Map;
021import java.util.concurrent.ConcurrentHashMap;
022import java.util.function.Function;
023
024import org.apache.commons.logging.Log;
025import org.apache.commons.logging.LogFactory;
026import org.reactivestreams.Publisher;
027import reactor.core.publisher.Flux;
028import reactor.core.publisher.Mono;
029
030import org.springframework.core.io.buffer.DataBuffer;
031import org.springframework.core.io.buffer.DataBufferFactory;
032import org.springframework.util.Assert;
033import org.springframework.web.reactive.socket.HandshakeInfo;
034import org.springframework.web.reactive.socket.WebSocketMessage;
035import org.springframework.web.reactive.socket.WebSocketSession;
036
037/**
038 * Convenient base class for {@link WebSocketSession} implementations that
039 * holds common fields and exposes accessors. Also implements the
040 * {@code WebSocketMessage} factory methods.
041 *
042 * @author Rossen Stoyanchev
043 * @since 5.0
044 * @param <T> the native delegate type
045 */
046public abstract class AbstractWebSocketSession<T> implements WebSocketSession {
047
048        protected final Log logger = LogFactory.getLog(getClass());
049
050        private final T delegate;
051
052        private final String id;
053
054        private final HandshakeInfo handshakeInfo;
055
056        private final DataBufferFactory bufferFactory;
057
058        private final Map<String, Object> attributes = new ConcurrentHashMap<>();
059
060        private final String logPrefix;
061
062
063        /**
064         * Create a new WebSocket session.
065         */
066        protected AbstractWebSocketSession(T delegate, String id, HandshakeInfo info, DataBufferFactory bufferFactory) {
067                Assert.notNull(delegate, "Native session is required.");
068                Assert.notNull(id, "Session id is required.");
069                Assert.notNull(info, "HandshakeInfo is required.");
070                Assert.notNull(bufferFactory, "DataBuffer factory is required.");
071
072                this.delegate = delegate;
073                this.id = id;
074                this.handshakeInfo = info;
075                this.bufferFactory = bufferFactory;
076                this.attributes.putAll(info.getAttributes());
077                this.logPrefix = initLogPrefix(info, id);
078
079                if (logger.isDebugEnabled()) {
080                        logger.debug(getLogPrefix() + "Session id \"" + getId() + "\" for " + getHandshakeInfo().getUri());
081                }
082        }
083
084        private static String initLogPrefix(HandshakeInfo info, String id) {
085                return info.getLogPrefix() != null ? info.getLogPrefix() : "[" + id + "] ";
086        }
087
088
089        protected T getDelegate() {
090                return this.delegate;
091        }
092
093        @Override
094        public String getId() {
095                return this.id;
096        }
097
098        @Override
099        public HandshakeInfo getHandshakeInfo() {
100                return this.handshakeInfo;
101        }
102
103        @Override
104        public DataBufferFactory bufferFactory() {
105                return this.bufferFactory;
106        }
107
108        @Override
109        public Map<String, Object> getAttributes() {
110                return this.attributes;
111        }
112
113        protected String getLogPrefix() {
114                return this.logPrefix;
115        }
116
117
118        @Override
119        public abstract Flux<WebSocketMessage> receive();
120
121        @Override
122        public abstract Mono<Void> send(Publisher<WebSocketMessage> messages);
123
124
125        // WebSocketMessage factory methods
126
127        @Override
128        public WebSocketMessage textMessage(String payload) {
129                byte[] bytes = payload.getBytes(StandardCharsets.UTF_8);
130                DataBuffer buffer = bufferFactory().wrap(bytes);
131                return new WebSocketMessage(WebSocketMessage.Type.TEXT, buffer);
132        }
133
134        @Override
135        public WebSocketMessage binaryMessage(Function<DataBufferFactory, DataBuffer> payloadFactory) {
136                DataBuffer payload = payloadFactory.apply(bufferFactory());
137                return new WebSocketMessage(WebSocketMessage.Type.BINARY, payload);
138        }
139
140        @Override
141        public WebSocketMessage pingMessage(Function<DataBufferFactory, DataBuffer> payloadFactory) {
142                DataBuffer payload = payloadFactory.apply(bufferFactory());
143                return new WebSocketMessage(WebSocketMessage.Type.PING, payload);
144        }
145
146        @Override
147        public WebSocketMessage pongMessage(Function<DataBufferFactory, DataBuffer> payloadFactory) {
148                DataBuffer payload = payloadFactory.apply(bufferFactory());
149                return new WebSocketMessage(WebSocketMessage.Type.PONG, payload);
150        }
151
152
153        @Override
154        public String toString() {
155                return getClass().getSimpleName() + "[id=" + getId() + ", uri=" + getHandshakeInfo().getUri() + "]";
156        }
157
158}