001/* 002 * Copyright 2002-2018 the original author or authors. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * https://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.springframework.web.socket.server.support; 018 019import java.util.Collection; 020import java.util.Collections; 021import java.util.LinkedHashSet; 022import java.util.Map; 023import java.util.Set; 024 025import org.apache.commons.logging.Log; 026import org.apache.commons.logging.LogFactory; 027 028import org.springframework.http.HttpStatus; 029import org.springframework.http.server.ServerHttpRequest; 030import org.springframework.http.server.ServerHttpResponse; 031import org.springframework.util.Assert; 032import org.springframework.web.socket.WebSocketHandler; 033import org.springframework.web.socket.server.HandshakeInterceptor; 034import org.springframework.web.util.WebUtils; 035 036/** 037 * An interceptor to check request {@code Origin} header value against a 038 * collection of allowed origins. 039 * 040 * @author Sebastien Deleuze 041 * @since 4.1.2 042 */ 043public class OriginHandshakeInterceptor implements HandshakeInterceptor { 044 045 protected final Log logger = LogFactory.getLog(getClass()); 046 047 private final Set<String> allowedOrigins = new LinkedHashSet<String>(); 048 049 050 /** 051 * Default constructor with only same origin requests allowed. 052 */ 053 public OriginHandshakeInterceptor() { 054 } 055 056 /** 057 * Constructor using the specified allowed origin values. 058 * @see #setAllowedOrigins(Collection) 059 */ 060 public OriginHandshakeInterceptor(Collection<String> allowedOrigins) { 061 setAllowedOrigins(allowedOrigins); 062 } 063 064 065 /** 066 * Configure allowed {@code Origin} header values. This check is mostly 067 * designed for browsers. There is nothing preventing other types of client 068 * to modify the {@code Origin} header value. 069 * <p>Each provided allowed origin must have a scheme, and optionally a port 070 * (e.g. "https://example.org", "https://example.org:9090"). An allowed origin 071 * string may also be "*" in which case all origins are allowed. 072 * @see <a href="https://tools.ietf.org/html/rfc6454">RFC 6454: The Web Origin Concept</a> 073 */ 074 public void setAllowedOrigins(Collection<String> allowedOrigins) { 075 Assert.notNull(allowedOrigins, "Allowed origins Collection must not be null"); 076 this.allowedOrigins.clear(); 077 this.allowedOrigins.addAll(allowedOrigins); 078 } 079 080 /** 081 * Return the allowed {@code Origin} header values. 082 * @since 4.1.5 083 * @see #setAllowedOrigins 084 */ 085 public Collection<String> getAllowedOrigins() { 086 return Collections.unmodifiableSet(this.allowedOrigins); 087 } 088 089 090 @Override 091 public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, 092 WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception { 093 094 if (!WebUtils.isSameOrigin(request) && !WebUtils.isValidOrigin(request, this.allowedOrigins)) { 095 response.setStatusCode(HttpStatus.FORBIDDEN); 096 if (logger.isDebugEnabled()) { 097 logger.debug("Handshake request rejected, Origin header value " + 098 request.getHeaders().getOrigin() + " not allowed"); 099 } 100 return false; 101 } 102 return true; 103 } 104 105 @Override 106 public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, 107 WebSocketHandler wsHandler, Exception exception) { 108 } 109 110}