001/*
002 * Copyright 2002-2017 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.web.socket.server.standard;
018
019import java.net.InetSocketAddress;
020import java.security.Principal;
021import java.util.ArrayList;
022import java.util.List;
023import java.util.Map;
024
025import javax.servlet.ServletContext;
026import javax.servlet.http.HttpServletRequest;
027import javax.servlet.http.HttpServletResponse;
028import javax.websocket.Endpoint;
029import javax.websocket.Extension;
030import javax.websocket.WebSocketContainer;
031import javax.websocket.server.ServerContainer;
032
033import org.apache.commons.logging.Log;
034import org.apache.commons.logging.LogFactory;
035
036import org.springframework.http.HttpHeaders;
037import org.springframework.http.server.ServerHttpRequest;
038import org.springframework.http.server.ServerHttpResponse;
039import org.springframework.http.server.ServletServerHttpRequest;
040import org.springframework.http.server.ServletServerHttpResponse;
041import org.springframework.lang.Nullable;
042import org.springframework.util.Assert;
043import org.springframework.web.socket.WebSocketExtension;
044import org.springframework.web.socket.WebSocketHandler;
045import org.springframework.web.socket.adapter.standard.StandardToWebSocketExtensionAdapter;
046import org.springframework.web.socket.adapter.standard.StandardWebSocketHandlerAdapter;
047import org.springframework.web.socket.adapter.standard.StandardWebSocketSession;
048import org.springframework.web.socket.adapter.standard.WebSocketToStandardExtensionAdapter;
049import org.springframework.web.socket.server.HandshakeFailureException;
050import org.springframework.web.socket.server.RequestUpgradeStrategy;
051
052/**
053 * A base class for {@link RequestUpgradeStrategy} implementations that build
054 * on the standard WebSocket API for Java (JSR-356).
055 *
056 * @author Rossen Stoyanchev
057 * @since 4.0
058 */
059public abstract class AbstractStandardUpgradeStrategy implements RequestUpgradeStrategy {
060
061        protected final Log logger = LogFactory.getLog(getClass());
062
063        @Nullable
064        private volatile List<WebSocketExtension> extensions;
065
066
067        protected ServerContainer getContainer(HttpServletRequest request) {
068                ServletContext servletContext = request.getServletContext();
069                String attrName = "javax.websocket.server.ServerContainer";
070                ServerContainer container = (ServerContainer) servletContext.getAttribute(attrName);
071                Assert.notNull(container, "No 'javax.websocket.server.ServerContainer' ServletContext attribute. " +
072                                "Are you running in a Servlet container that supports JSR-356?");
073                return container;
074        }
075
076        protected final HttpServletRequest getHttpServletRequest(ServerHttpRequest request) {
077                Assert.isInstanceOf(ServletServerHttpRequest.class, request, "ServletServerHttpRequest required");
078                return ((ServletServerHttpRequest) request).getServletRequest();
079        }
080
081        protected final HttpServletResponse getHttpServletResponse(ServerHttpResponse response) {
082                Assert.isInstanceOf(ServletServerHttpResponse.class, response, "ServletServerHttpResponse required");
083                return ((ServletServerHttpResponse) response).getServletResponse();
084        }
085
086
087        @Override
088        public List<WebSocketExtension> getSupportedExtensions(ServerHttpRequest request) {
089                List<WebSocketExtension> extensions = this.extensions;
090                if (extensions == null) {
091                        HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest();
092                        extensions = getInstalledExtensions(getContainer(servletRequest));
093                        this.extensions = extensions;
094                }
095                return extensions;
096        }
097
098        protected List<WebSocketExtension> getInstalledExtensions(WebSocketContainer container) {
099                List<WebSocketExtension> result = new ArrayList<>();
100                for (Extension extension : container.getInstalledExtensions()) {
101                        result.add(new StandardToWebSocketExtensionAdapter(extension));
102                }
103                return result;
104        }
105
106
107        @Override
108        public void upgrade(ServerHttpRequest request, ServerHttpResponse response,
109                        @Nullable String selectedProtocol, List<WebSocketExtension> selectedExtensions,
110                        @Nullable Principal user, WebSocketHandler wsHandler, Map<String, Object> attrs)
111                        throws HandshakeFailureException {
112
113                HttpHeaders headers = request.getHeaders();
114                InetSocketAddress localAddr = null;
115                try {
116                        localAddr = request.getLocalAddress();
117                }
118                catch (Exception ex) {
119                        // Ignore
120                }
121                InetSocketAddress remoteAddr = null;
122                try {
123                        remoteAddr = request.getRemoteAddress();
124                }
125                catch (Exception ex) {
126                        // Ignore
127                }
128
129                StandardWebSocketSession session = new StandardWebSocketSession(headers, attrs, localAddr, remoteAddr, user);
130                StandardWebSocketHandlerAdapter endpoint = new StandardWebSocketHandlerAdapter(wsHandler, session);
131
132                List<Extension> extensions = new ArrayList<>();
133                for (WebSocketExtension extension : selectedExtensions) {
134                        extensions.add(new WebSocketToStandardExtensionAdapter(extension));
135                }
136
137                upgradeInternal(request, response, selectedProtocol, extensions, endpoint);
138        }
139
140        protected abstract void upgradeInternal(ServerHttpRequest request, ServerHttpResponse response,
141                        @Nullable String selectedProtocol, List<Extension> selectedExtensions, Endpoint endpoint)
142                        throws HandshakeFailureException;
143
144}