001/* 002 * Copyright 2002-2017 the original author or authors. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * https://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.springframework.web.socket.server.standard; 018 019import java.net.InetSocketAddress; 020import java.security.Principal; 021import java.util.ArrayList; 022import java.util.List; 023import java.util.Map; 024 025import javax.servlet.ServletContext; 026import javax.servlet.http.HttpServletRequest; 027import javax.servlet.http.HttpServletResponse; 028import javax.websocket.Endpoint; 029import javax.websocket.Extension; 030import javax.websocket.WebSocketContainer; 031import javax.websocket.server.ServerContainer; 032 033import org.apache.commons.logging.Log; 034import org.apache.commons.logging.LogFactory; 035 036import org.springframework.http.HttpHeaders; 037import org.springframework.http.server.ServerHttpRequest; 038import org.springframework.http.server.ServerHttpResponse; 039import org.springframework.http.server.ServletServerHttpRequest; 040import org.springframework.http.server.ServletServerHttpResponse; 041import org.springframework.lang.Nullable; 042import org.springframework.util.Assert; 043import org.springframework.web.socket.WebSocketExtension; 044import org.springframework.web.socket.WebSocketHandler; 045import org.springframework.web.socket.adapter.standard.StandardToWebSocketExtensionAdapter; 046import org.springframework.web.socket.adapter.standard.StandardWebSocketHandlerAdapter; 047import org.springframework.web.socket.adapter.standard.StandardWebSocketSession; 048import org.springframework.web.socket.adapter.standard.WebSocketToStandardExtensionAdapter; 049import org.springframework.web.socket.server.HandshakeFailureException; 050import org.springframework.web.socket.server.RequestUpgradeStrategy; 051 052/** 053 * A base class for {@link RequestUpgradeStrategy} implementations that build 054 * on the standard WebSocket API for Java (JSR-356). 055 * 056 * @author Rossen Stoyanchev 057 * @since 4.0 058 */ 059public abstract class AbstractStandardUpgradeStrategy implements RequestUpgradeStrategy { 060 061 protected final Log logger = LogFactory.getLog(getClass()); 062 063 @Nullable 064 private volatile List<WebSocketExtension> extensions; 065 066 067 protected ServerContainer getContainer(HttpServletRequest request) { 068 ServletContext servletContext = request.getServletContext(); 069 String attrName = "javax.websocket.server.ServerContainer"; 070 ServerContainer container = (ServerContainer) servletContext.getAttribute(attrName); 071 Assert.notNull(container, "No 'javax.websocket.server.ServerContainer' ServletContext attribute. " + 072 "Are you running in a Servlet container that supports JSR-356?"); 073 return container; 074 } 075 076 protected final HttpServletRequest getHttpServletRequest(ServerHttpRequest request) { 077 Assert.isInstanceOf(ServletServerHttpRequest.class, request, "ServletServerHttpRequest required"); 078 return ((ServletServerHttpRequest) request).getServletRequest(); 079 } 080 081 protected final HttpServletResponse getHttpServletResponse(ServerHttpResponse response) { 082 Assert.isInstanceOf(ServletServerHttpResponse.class, response, "ServletServerHttpResponse required"); 083 return ((ServletServerHttpResponse) response).getServletResponse(); 084 } 085 086 087 @Override 088 public List<WebSocketExtension> getSupportedExtensions(ServerHttpRequest request) { 089 List<WebSocketExtension> extensions = this.extensions; 090 if (extensions == null) { 091 HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest(); 092 extensions = getInstalledExtensions(getContainer(servletRequest)); 093 this.extensions = extensions; 094 } 095 return extensions; 096 } 097 098 protected List<WebSocketExtension> getInstalledExtensions(WebSocketContainer container) { 099 List<WebSocketExtension> result = new ArrayList<>(); 100 for (Extension extension : container.getInstalledExtensions()) { 101 result.add(new StandardToWebSocketExtensionAdapter(extension)); 102 } 103 return result; 104 } 105 106 107 @Override 108 public void upgrade(ServerHttpRequest request, ServerHttpResponse response, 109 @Nullable String selectedProtocol, List<WebSocketExtension> selectedExtensions, 110 @Nullable Principal user, WebSocketHandler wsHandler, Map<String, Object> attrs) 111 throws HandshakeFailureException { 112 113 HttpHeaders headers = request.getHeaders(); 114 InetSocketAddress localAddr = null; 115 try { 116 localAddr = request.getLocalAddress(); 117 } 118 catch (Exception ex) { 119 // Ignore 120 } 121 InetSocketAddress remoteAddr = null; 122 try { 123 remoteAddr = request.getRemoteAddress(); 124 } 125 catch (Exception ex) { 126 // Ignore 127 } 128 129 StandardWebSocketSession session = new StandardWebSocketSession(headers, attrs, localAddr, remoteAddr, user); 130 StandardWebSocketHandlerAdapter endpoint = new StandardWebSocketHandlerAdapter(wsHandler, session); 131 132 List<Extension> extensions = new ArrayList<>(); 133 for (WebSocketExtension extension : selectedExtensions) { 134 extensions.add(new WebSocketToStandardExtensionAdapter(extension)); 135 } 136 137 upgradeInternal(request, response, selectedProtocol, extensions, endpoint); 138 } 139 140 protected abstract void upgradeInternal(ServerHttpRequest request, ServerHttpResponse response, 141 @Nullable String selectedProtocol, List<Extension> selectedExtensions, Endpoint endpoint) 142 throws HandshakeFailureException; 143 144}