001/* 002 * Copyright 2002-2018 the original author or authors. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * https://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.springframework.web.socket.server.support; 018 019import java.util.Collection; 020import java.util.Collections; 021import java.util.LinkedHashSet; 022import java.util.Map; 023import java.util.Set; 024 025import org.apache.commons.logging.Log; 026import org.apache.commons.logging.LogFactory; 027 028import org.springframework.http.HttpStatus; 029import org.springframework.http.server.ServerHttpRequest; 030import org.springframework.http.server.ServerHttpResponse; 031import org.springframework.lang.Nullable; 032import org.springframework.util.Assert; 033import org.springframework.web.socket.WebSocketHandler; 034import org.springframework.web.socket.server.HandshakeInterceptor; 035import org.springframework.web.util.WebUtils; 036 037/** 038 * An interceptor to check request {@code Origin} header value against a 039 * collection of allowed origins. 040 * 041 * @author Sebastien Deleuze 042 * @since 4.1.2 043 */ 044public class OriginHandshakeInterceptor implements HandshakeInterceptor { 045 046 protected final Log logger = LogFactory.getLog(getClass()); 047 048 private final Set<String> allowedOrigins = new LinkedHashSet<>(); 049 050 051 /** 052 * Default constructor with only same origin requests allowed. 053 */ 054 public OriginHandshakeInterceptor() { 055 } 056 057 /** 058 * Constructor using the specified allowed origin values. 059 * @see #setAllowedOrigins(Collection) 060 */ 061 public OriginHandshakeInterceptor(Collection<String> allowedOrigins) { 062 setAllowedOrigins(allowedOrigins); 063 } 064 065 066 /** 067 * Configure allowed {@code Origin} header values. This check is mostly 068 * designed for browsers. There is nothing preventing other types of client 069 * to modify the {@code Origin} header value. 070 * <p>Each provided allowed origin must have a scheme, and optionally a port 071 * (e.g. "https://example.org", "https://example.org:9090"). An allowed origin 072 * string may also be "*" in which case all origins are allowed. 073 * @see <a href="https://tools.ietf.org/html/rfc6454">RFC 6454: The Web Origin Concept</a> 074 */ 075 public void setAllowedOrigins(Collection<String> allowedOrigins) { 076 Assert.notNull(allowedOrigins, "Allowed origins Collection must not be null"); 077 this.allowedOrigins.clear(); 078 this.allowedOrigins.addAll(allowedOrigins); 079 } 080 081 /** 082 * Return the allowed {@code Origin} header values. 083 * @since 4.1.5 084 * @see #setAllowedOrigins 085 */ 086 public Collection<String> getAllowedOrigins() { 087 return Collections.unmodifiableSet(this.allowedOrigins); 088 } 089 090 091 @Override 092 public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, 093 WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception { 094 095 if (!WebUtils.isSameOrigin(request) && !WebUtils.isValidOrigin(request, this.allowedOrigins)) { 096 response.setStatusCode(HttpStatus.FORBIDDEN); 097 if (logger.isDebugEnabled()) { 098 logger.debug("Handshake request rejected, Origin header value " + 099 request.getHeaders().getOrigin() + " not allowed"); 100 } 101 return false; 102 } 103 return true; 104 } 105 106 @Override 107 public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, 108 WebSocketHandler wsHandler, @Nullable Exception exception) { 109 } 110 111}