001/*
002 * Copyright 2002-2017 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.web.socket.adapter.standard;
018
019import java.io.IOException;
020import java.net.InetSocketAddress;
021import java.net.URI;
022import java.security.Principal;
023import java.util.ArrayList;
024import java.util.Collections;
025import java.util.List;
026import java.util.Map;
027import javax.websocket.CloseReason;
028import javax.websocket.CloseReason.CloseCodes;
029import javax.websocket.Extension;
030import javax.websocket.Session;
031
032import org.springframework.http.HttpHeaders;
033import org.springframework.util.CollectionUtils;
034import org.springframework.web.socket.BinaryMessage;
035import org.springframework.web.socket.CloseStatus;
036import org.springframework.web.socket.PingMessage;
037import org.springframework.web.socket.PongMessage;
038import org.springframework.web.socket.TextMessage;
039import org.springframework.web.socket.WebSocketExtension;
040import org.springframework.web.socket.WebSocketSession;
041import org.springframework.web.socket.adapter.AbstractWebSocketSession;
042
043/**
044 * A {@link WebSocketSession} for use with the standard WebSocket for Java API.
045 *
046 * @author Rossen Stoyanchev
047 * @since 4.0
048 */
049public class StandardWebSocketSession extends AbstractWebSocketSession<Session> {
050
051        private String id;
052
053        private URI uri;
054
055        private final HttpHeaders handshakeHeaders;
056
057        private String acceptedProtocol;
058
059        private List<WebSocketExtension> extensions;
060
061        private Principal user;
062
063        private final InetSocketAddress localAddress;
064
065        private final InetSocketAddress remoteAddress;
066
067
068        /**
069         * Constructor for a standard WebSocket session.
070         * @param headers the headers of the handshake request
071         * @param attributes attributes from the HTTP handshake to associate with the WebSocket
072         * session; the provided attributes are copied, the original map is not used.
073         * @param localAddress the address on which the request was received
074         * @param remoteAddress the address of the remote client
075         */
076        public StandardWebSocketSession(HttpHeaders headers, Map<String, Object> attributes,
077                        InetSocketAddress localAddress, InetSocketAddress remoteAddress) {
078
079                this(headers, attributes, localAddress, remoteAddress, null);
080        }
081
082        /**
083         * Constructor that associates a user with the WebSocket session.
084         * @param headers the headers of the handshake request
085         * @param attributes attributes from the HTTP handshake to associate with the WebSocket session
086         * @param localAddress the address on which the request was received
087         * @param remoteAddress the address of the remote client
088         * @param user the user associated with the session; if {@code null} we'll
089         *      fallback on the user available in the underlying WebSocket session
090         */
091        public StandardWebSocketSession(HttpHeaders headers, Map<String, Object> attributes,
092                        InetSocketAddress localAddress, InetSocketAddress remoteAddress, Principal user) {
093
094                super(attributes);
095                headers = (headers != null) ? headers : new HttpHeaders();
096                this.handshakeHeaders = HttpHeaders.readOnlyHttpHeaders(headers);
097                this.user = user;
098                this.localAddress = localAddress;
099                this.remoteAddress = remoteAddress;
100        }
101
102
103        @Override
104        public String getId() {
105                checkNativeSessionInitialized();
106                return this.id;
107        }
108
109        @Override
110        public URI getUri() {
111                checkNativeSessionInitialized();
112                return this.uri;
113        }
114
115        @Override
116        public HttpHeaders getHandshakeHeaders() {
117                return this.handshakeHeaders;
118        }
119
120        @Override
121        public String getAcceptedProtocol() {
122                checkNativeSessionInitialized();
123                return this.acceptedProtocol;
124        }
125
126        @Override
127        public List<WebSocketExtension> getExtensions() {
128                checkNativeSessionInitialized();
129                return this.extensions;
130        }
131
132        public Principal getPrincipal() {
133                return this.user;
134        }
135
136        @Override
137        public InetSocketAddress getLocalAddress() {
138                return this.localAddress;
139        }
140
141        @Override
142        public InetSocketAddress getRemoteAddress() {
143                return this.remoteAddress;
144        }
145
146        @Override
147        public void setTextMessageSizeLimit(int messageSizeLimit) {
148                checkNativeSessionInitialized();
149                getNativeSession().setMaxTextMessageBufferSize(messageSizeLimit);
150        }
151
152        @Override
153        public int getTextMessageSizeLimit() {
154                checkNativeSessionInitialized();
155                return getNativeSession().getMaxTextMessageBufferSize();
156        }
157
158        @Override
159        public void setBinaryMessageSizeLimit(int messageSizeLimit) {
160                checkNativeSessionInitialized();
161                getNativeSession().setMaxBinaryMessageBufferSize(messageSizeLimit);
162        }
163
164        @Override
165        public int getBinaryMessageSizeLimit() {
166                checkNativeSessionInitialized();
167                return getNativeSession().getMaxBinaryMessageBufferSize();
168        }
169
170        @Override
171        public boolean isOpen() {
172                return (getNativeSession() != null && getNativeSession().isOpen());
173        }
174
175        @Override
176        public void initializeNativeSession(Session session) {
177                super.initializeNativeSession(session);
178
179                this.id = session.getId();
180                this.uri = session.getRequestURI();
181
182                this.acceptedProtocol = session.getNegotiatedSubprotocol();
183
184                List<Extension> standardExtensions = getNativeSession().getNegotiatedExtensions();
185                if (!CollectionUtils.isEmpty(standardExtensions)) {
186                        this.extensions = new ArrayList<WebSocketExtension>(standardExtensions.size());
187                        for (Extension standardExtension : standardExtensions) {
188                                this.extensions.add(new StandardToWebSocketExtensionAdapter(standardExtension));
189                        }
190                        this.extensions = Collections.unmodifiableList(this.extensions);
191                }
192                else {
193                        this.extensions = Collections.emptyList();
194                }
195
196                if (this.user == null) {
197                        this.user = session.getUserPrincipal();
198                }
199        }
200
201        @Override
202        protected void sendTextMessage(TextMessage message) throws IOException {
203                getNativeSession().getBasicRemote().sendText(message.getPayload(), message.isLast());
204        }
205
206        @Override
207        protected void sendBinaryMessage(BinaryMessage message) throws IOException {
208                getNativeSession().getBasicRemote().sendBinary(message.getPayload(), message.isLast());
209        }
210
211        @Override
212        protected void sendPingMessage(PingMessage message) throws IOException {
213                getNativeSession().getBasicRemote().sendPing(message.getPayload());
214        }
215
216        @Override
217        protected void sendPongMessage(PongMessage message) throws IOException {
218                getNativeSession().getBasicRemote().sendPong(message.getPayload());
219        }
220
221        @Override
222        protected void closeInternal(CloseStatus status) throws IOException {
223                getNativeSession().close(new CloseReason(CloseCodes.getCloseCode(status.getCode()), status.getReason()));
224        }
225
226}