001/* 002 * Copyright 2002-2017 the original author or authors. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * https://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.springframework.web.socket.server.standard; 018 019import java.net.InetSocketAddress; 020import java.security.Principal; 021import java.util.ArrayList; 022import java.util.List; 023import java.util.Map; 024import javax.servlet.ServletContext; 025import javax.servlet.http.HttpServletRequest; 026import javax.servlet.http.HttpServletResponse; 027import javax.websocket.Endpoint; 028import javax.websocket.Extension; 029import javax.websocket.WebSocketContainer; 030import javax.websocket.server.ServerContainer; 031 032import org.apache.commons.logging.Log; 033import org.apache.commons.logging.LogFactory; 034 035import org.springframework.http.HttpHeaders; 036import org.springframework.http.server.ServerHttpRequest; 037import org.springframework.http.server.ServerHttpResponse; 038import org.springframework.http.server.ServletServerHttpRequest; 039import org.springframework.http.server.ServletServerHttpResponse; 040import org.springframework.util.Assert; 041import org.springframework.web.socket.WebSocketExtension; 042import org.springframework.web.socket.WebSocketHandler; 043import org.springframework.web.socket.adapter.standard.StandardToWebSocketExtensionAdapter; 044import org.springframework.web.socket.adapter.standard.StandardWebSocketHandlerAdapter; 045import org.springframework.web.socket.adapter.standard.StandardWebSocketSession; 046import org.springframework.web.socket.adapter.standard.WebSocketToStandardExtensionAdapter; 047import org.springframework.web.socket.server.HandshakeFailureException; 048import org.springframework.web.socket.server.RequestUpgradeStrategy; 049 050/** 051 * A base class for {@link RequestUpgradeStrategy} implementations that build 052 * on the standard WebSocket API for Java (JSR-356). 053 * 054 * @author Rossen Stoyanchev 055 * @since 4.0 056 */ 057public abstract class AbstractStandardUpgradeStrategy implements RequestUpgradeStrategy { 058 059 protected final Log logger = LogFactory.getLog(getClass()); 060 061 private volatile List<WebSocketExtension> extensions; 062 063 064 protected ServerContainer getContainer(HttpServletRequest request) { 065 ServletContext servletContext = request.getServletContext(); 066 String attrName = "javax.websocket.server.ServerContainer"; 067 ServerContainer container = (ServerContainer) servletContext.getAttribute(attrName); 068 Assert.notNull(container, "No 'javax.websocket.server.ServerContainer' ServletContext attribute. " + 069 "Are you running in a Servlet container that supports JSR-356?"); 070 return container; 071 } 072 073 protected final HttpServletRequest getHttpServletRequest(ServerHttpRequest request) { 074 Assert.isInstanceOf(ServletServerHttpRequest.class, request, "ServletServerHttpRequest required"); 075 return ((ServletServerHttpRequest) request).getServletRequest(); 076 } 077 078 protected final HttpServletResponse getHttpServletResponse(ServerHttpResponse response) { 079 Assert.isInstanceOf(ServletServerHttpResponse.class, response, "ServletServerHttpResponse required"); 080 return ((ServletServerHttpResponse) response).getServletResponse(); 081 } 082 083 084 @Override 085 public List<WebSocketExtension> getSupportedExtensions(ServerHttpRequest request) { 086 if (this.extensions == null) { 087 HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest(); 088 this.extensions = getInstalledExtensions(getContainer(servletRequest)); 089 } 090 return this.extensions; 091 } 092 093 protected List<WebSocketExtension> getInstalledExtensions(WebSocketContainer container) { 094 List<WebSocketExtension> result = new ArrayList<WebSocketExtension>(); 095 for (Extension extension : container.getInstalledExtensions()) { 096 result.add(new StandardToWebSocketExtensionAdapter(extension)); 097 } 098 return result; 099 } 100 101 102 @Override 103 public void upgrade(ServerHttpRequest request, ServerHttpResponse response, 104 String selectedProtocol, List<WebSocketExtension> selectedExtensions, Principal user, 105 WebSocketHandler wsHandler, Map<String, Object> attrs) throws HandshakeFailureException { 106 107 HttpHeaders headers = request.getHeaders(); 108 InetSocketAddress localAddr = null; 109 try { 110 localAddr = request.getLocalAddress(); 111 } 112 catch (Exception ex) { 113 // Ignore 114 } 115 InetSocketAddress remoteAddr = null; 116 try { 117 remoteAddr = request.getRemoteAddress(); 118 } 119 catch (Exception ex) { 120 // Ignore 121 } 122 123 StandardWebSocketSession session = new StandardWebSocketSession(headers, attrs, localAddr, remoteAddr, user); 124 StandardWebSocketHandlerAdapter endpoint = new StandardWebSocketHandlerAdapter(wsHandler, session); 125 126 List<Extension> extensions = new ArrayList<Extension>(); 127 for (WebSocketExtension extension : selectedExtensions) { 128 extensions.add(new WebSocketToStandardExtensionAdapter(extension)); 129 } 130 131 upgradeInternal(request, response, selectedProtocol, extensions, endpoint); 132 } 133 134 protected abstract void upgradeInternal(ServerHttpRequest request, ServerHttpResponse response, 135 String selectedProtocol, List<Extension> selectedExtensions, Endpoint endpoint) 136 throws HandshakeFailureException; 137 138}