001/*
002 * Copyright 2002-2017 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      https://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.web.socket.server.standard;
018
019import java.net.InetSocketAddress;
020import java.security.Principal;
021import java.util.ArrayList;
022import java.util.List;
023import java.util.Map;
024import javax.servlet.ServletContext;
025import javax.servlet.http.HttpServletRequest;
026import javax.servlet.http.HttpServletResponse;
027import javax.websocket.Endpoint;
028import javax.websocket.Extension;
029import javax.websocket.WebSocketContainer;
030import javax.websocket.server.ServerContainer;
031
032import org.apache.commons.logging.Log;
033import org.apache.commons.logging.LogFactory;
034
035import org.springframework.http.HttpHeaders;
036import org.springframework.http.server.ServerHttpRequest;
037import org.springframework.http.server.ServerHttpResponse;
038import org.springframework.http.server.ServletServerHttpRequest;
039import org.springframework.http.server.ServletServerHttpResponse;
040import org.springframework.util.Assert;
041import org.springframework.web.socket.WebSocketExtension;
042import org.springframework.web.socket.WebSocketHandler;
043import org.springframework.web.socket.adapter.standard.StandardToWebSocketExtensionAdapter;
044import org.springframework.web.socket.adapter.standard.StandardWebSocketHandlerAdapter;
045import org.springframework.web.socket.adapter.standard.StandardWebSocketSession;
046import org.springframework.web.socket.adapter.standard.WebSocketToStandardExtensionAdapter;
047import org.springframework.web.socket.server.HandshakeFailureException;
048import org.springframework.web.socket.server.RequestUpgradeStrategy;
049
050/**
051 * A base class for {@link RequestUpgradeStrategy} implementations that build
052 * on the standard WebSocket API for Java (JSR-356).
053 *
054 * @author Rossen Stoyanchev
055 * @since 4.0
056 */
057public abstract class AbstractStandardUpgradeStrategy implements RequestUpgradeStrategy {
058
059        protected final Log logger = LogFactory.getLog(getClass());
060
061        private volatile List<WebSocketExtension> extensions;
062
063
064        protected ServerContainer getContainer(HttpServletRequest request) {
065                ServletContext servletContext = request.getServletContext();
066                String attrName = "javax.websocket.server.ServerContainer";
067                ServerContainer container = (ServerContainer) servletContext.getAttribute(attrName);
068                Assert.notNull(container, "No 'javax.websocket.server.ServerContainer' ServletContext attribute. " +
069                                "Are you running in a Servlet container that supports JSR-356?");
070                return container;
071        }
072
073        protected final HttpServletRequest getHttpServletRequest(ServerHttpRequest request) {
074                Assert.isInstanceOf(ServletServerHttpRequest.class, request, "ServletServerHttpRequest required");
075                return ((ServletServerHttpRequest) request).getServletRequest();
076        }
077
078        protected final HttpServletResponse getHttpServletResponse(ServerHttpResponse response) {
079                Assert.isInstanceOf(ServletServerHttpResponse.class, response, "ServletServerHttpResponse required");
080                return ((ServletServerHttpResponse) response).getServletResponse();
081        }
082
083
084        @Override
085        public List<WebSocketExtension> getSupportedExtensions(ServerHttpRequest request) {
086                if (this.extensions == null) {
087                        HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest();
088                        this.extensions = getInstalledExtensions(getContainer(servletRequest));
089                }
090                return this.extensions;
091        }
092
093        protected List<WebSocketExtension> getInstalledExtensions(WebSocketContainer container) {
094                List<WebSocketExtension> result = new ArrayList<WebSocketExtension>();
095                for (Extension extension : container.getInstalledExtensions()) {
096                        result.add(new StandardToWebSocketExtensionAdapter(extension));
097                }
098                return result;
099        }
100
101
102        @Override
103        public void upgrade(ServerHttpRequest request, ServerHttpResponse response,
104                        String selectedProtocol, List<WebSocketExtension> selectedExtensions, Principal user,
105                        WebSocketHandler wsHandler, Map<String, Object> attrs) throws HandshakeFailureException {
106
107                HttpHeaders headers = request.getHeaders();
108                InetSocketAddress localAddr = null;
109                try {
110                        localAddr = request.getLocalAddress();
111                }
112                catch (Exception ex) {
113                        // Ignore
114                }
115                InetSocketAddress remoteAddr = null;
116                try {
117                        remoteAddr = request.getRemoteAddress();
118                }
119                catch (Exception ex) {
120                        // Ignore
121                }
122
123                StandardWebSocketSession session = new StandardWebSocketSession(headers, attrs, localAddr, remoteAddr, user);
124                StandardWebSocketHandlerAdapter endpoint = new StandardWebSocketHandlerAdapter(wsHandler, session);
125
126                List<Extension> extensions = new ArrayList<Extension>();
127                for (WebSocketExtension extension : selectedExtensions) {
128                        extensions.add(new WebSocketToStandardExtensionAdapter(extension));
129                }
130
131                upgradeInternal(request, response, selectedProtocol, extensions, endpoint);
132        }
133
134        protected abstract void upgradeInternal(ServerHttpRequest request, ServerHttpResponse response,
135                        String selectedProtocol, List<Extension> selectedExtensions, Endpoint endpoint)
136                        throws HandshakeFailureException;
137
138}