001/* 002 * Copyright 2012-2018 the original author or authors. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.springframework.boot.web.servlet; 018 019import java.util.AbstractCollection; 020import java.util.ArrayList; 021import java.util.Arrays; 022import java.util.Collections; 023import java.util.EventListener; 024import java.util.HashSet; 025import java.util.Iterator; 026import java.util.LinkedHashMap; 027import java.util.List; 028import java.util.Map; 029import java.util.Map.Entry; 030import java.util.Set; 031import java.util.stream.Collectors; 032 033import javax.servlet.Filter; 034import javax.servlet.MultipartConfigElement; 035import javax.servlet.Servlet; 036 037import org.apache.commons.logging.Log; 038import org.apache.commons.logging.LogFactory; 039 040import org.springframework.aop.scope.ScopedProxyUtils; 041import org.springframework.beans.factory.ListableBeanFactory; 042import org.springframework.beans.factory.support.BeanDefinitionRegistry; 043import org.springframework.core.annotation.AnnotationAwareOrderComparator; 044import org.springframework.util.LinkedMultiValueMap; 045import org.springframework.util.MultiValueMap; 046 047/** 048 * A collection {@link ServletContextInitializer}s obtained from a 049 * {@link ListableBeanFactory}. Includes all {@link ServletContextInitializer} beans and 050 * also adapts {@link Servlet}, {@link Filter} and certain {@link EventListener} beans. 051 * <p> 052 * Items are sorted so that adapted beans are top ({@link Servlet}, {@link Filter} then 053 * {@link EventListener}) and direct {@link ServletContextInitializer} beans are at the 054 * end. Further sorting is applied within these groups using the 055 * {@link AnnotationAwareOrderComparator}. 056 * 057 * @author Dave Syer 058 * @author Phillip Webb 059 * @author Brian Clozel 060 * @since 1.4.0 061 */ 062public class ServletContextInitializerBeans 063 extends AbstractCollection<ServletContextInitializer> { 064 065 private static final String DISPATCHER_SERVLET_NAME = "dispatcherServlet"; 066 067 private static final Log logger = LogFactory 068 .getLog(ServletContextInitializerBeans.class); 069 070 /** 071 * Seen bean instances or bean names. 072 */ 073 private final Set<Object> seen = new HashSet<>(); 074 075 private final MultiValueMap<Class<?>, ServletContextInitializer> initializers; 076 077 private final List<Class<? extends ServletContextInitializer>> initializerTypes; 078 079 private List<ServletContextInitializer> sortedList; 080 081 @SafeVarargs 082 public ServletContextInitializerBeans(ListableBeanFactory beanFactory, 083 Class<? extends ServletContextInitializer>... initializerTypes) { 084 this.initializers = new LinkedMultiValueMap<>(); 085 this.initializerTypes = (initializerTypes.length != 0) 086 ? Arrays.asList(initializerTypes) 087 : Collections.singletonList(ServletContextInitializer.class); 088 addServletContextInitializerBeans(beanFactory); 089 addAdaptableBeans(beanFactory); 090 List<ServletContextInitializer> sortedInitializers = this.initializers.values() 091 .stream() 092 .flatMap((value) -> value.stream() 093 .sorted(AnnotationAwareOrderComparator.INSTANCE)) 094 .collect(Collectors.toList()); 095 this.sortedList = Collections.unmodifiableList(sortedInitializers); 096 logMappings(this.initializers); 097 } 098 099 private void addServletContextInitializerBeans(ListableBeanFactory beanFactory) { 100 for (Class<? extends ServletContextInitializer> initializerType : this.initializerTypes) { 101 for (Entry<String, ? extends ServletContextInitializer> initializerBean : getOrderedBeansOfType( 102 beanFactory, initializerType)) { 103 addServletContextInitializerBean(initializerBean.getKey(), 104 initializerBean.getValue(), beanFactory); 105 } 106 } 107 } 108 109 private void addServletContextInitializerBean(String beanName, 110 ServletContextInitializer initializer, ListableBeanFactory beanFactory) { 111 if (initializer instanceof ServletRegistrationBean) { 112 Servlet source = ((ServletRegistrationBean<?>) initializer).getServlet(); 113 addServletContextInitializerBean(Servlet.class, beanName, initializer, 114 beanFactory, source); 115 } 116 else if (initializer instanceof FilterRegistrationBean) { 117 Filter source = ((FilterRegistrationBean<?>) initializer).getFilter(); 118 addServletContextInitializerBean(Filter.class, beanName, initializer, 119 beanFactory, source); 120 } 121 else if (initializer instanceof DelegatingFilterProxyRegistrationBean) { 122 String source = ((DelegatingFilterProxyRegistrationBean) initializer) 123 .getTargetBeanName(); 124 addServletContextInitializerBean(Filter.class, beanName, initializer, 125 beanFactory, source); 126 } 127 else if (initializer instanceof ServletListenerRegistrationBean) { 128 EventListener source = ((ServletListenerRegistrationBean<?>) initializer) 129 .getListener(); 130 addServletContextInitializerBean(EventListener.class, beanName, initializer, 131 beanFactory, source); 132 } 133 else { 134 addServletContextInitializerBean(ServletContextInitializer.class, beanName, 135 initializer, beanFactory, initializer); 136 } 137 } 138 139 private void addServletContextInitializerBean(Class<?> type, String beanName, 140 ServletContextInitializer initializer, ListableBeanFactory beanFactory, 141 Object source) { 142 this.initializers.add(type, initializer); 143 if (source != null) { 144 // Mark the underlying source as seen in case it wraps an existing bean 145 this.seen.add(source); 146 } 147 if (logger.isTraceEnabled()) { 148 String resourceDescription = getResourceDescription(beanName, beanFactory); 149 int order = getOrder(initializer); 150 logger.trace("Added existing " + type.getSimpleName() + " initializer bean '" 151 + beanName + "'; order=" + order + ", resource=" 152 + resourceDescription); 153 } 154 } 155 156 private String getResourceDescription(String beanName, 157 ListableBeanFactory beanFactory) { 158 if (beanFactory instanceof BeanDefinitionRegistry) { 159 BeanDefinitionRegistry registry = (BeanDefinitionRegistry) beanFactory; 160 return registry.getBeanDefinition(beanName).getResourceDescription(); 161 } 162 return "unknown"; 163 } 164 165 @SuppressWarnings("unchecked") 166 protected void addAdaptableBeans(ListableBeanFactory beanFactory) { 167 MultipartConfigElement multipartConfig = getMultipartConfig(beanFactory); 168 addAsRegistrationBean(beanFactory, Servlet.class, 169 new ServletRegistrationBeanAdapter(multipartConfig)); 170 addAsRegistrationBean(beanFactory, Filter.class, 171 new FilterRegistrationBeanAdapter()); 172 for (Class<?> listenerType : ServletListenerRegistrationBean 173 .getSupportedTypes()) { 174 addAsRegistrationBean(beanFactory, EventListener.class, 175 (Class<EventListener>) listenerType, 176 new ServletListenerRegistrationBeanAdapter()); 177 } 178 } 179 180 private MultipartConfigElement getMultipartConfig(ListableBeanFactory beanFactory) { 181 List<Entry<String, MultipartConfigElement>> beans = getOrderedBeansOfType( 182 beanFactory, MultipartConfigElement.class); 183 return beans.isEmpty() ? null : beans.get(0).getValue(); 184 } 185 186 protected <T> void addAsRegistrationBean(ListableBeanFactory beanFactory, 187 Class<T> type, RegistrationBeanAdapter<T> adapter) { 188 addAsRegistrationBean(beanFactory, type, type, adapter); 189 } 190 191 private <T, B extends T> void addAsRegistrationBean(ListableBeanFactory beanFactory, 192 Class<T> type, Class<B> beanType, RegistrationBeanAdapter<T> adapter) { 193 List<Map.Entry<String, B>> entries = getOrderedBeansOfType(beanFactory, beanType, 194 this.seen); 195 for (Entry<String, B> entry : entries) { 196 String beanName = entry.getKey(); 197 B bean = entry.getValue(); 198 if (this.seen.add(bean)) { 199 // One that we haven't already seen 200 RegistrationBean registration = adapter.createRegistrationBean(beanName, 201 bean, entries.size()); 202 int order = getOrder(bean); 203 registration.setOrder(order); 204 this.initializers.add(type, registration); 205 if (logger.isTraceEnabled()) { 206 logger.trace( 207 "Created " + type.getSimpleName() + " initializer for bean '" 208 + beanName + "'; order=" + order + ", resource=" 209 + getResourceDescription(beanName, beanFactory)); 210 } 211 } 212 } 213 } 214 215 private int getOrder(Object value) { 216 return new AnnotationAwareOrderComparator() { 217 @Override 218 public int getOrder(Object obj) { 219 return super.getOrder(obj); 220 } 221 }.getOrder(value); 222 } 223 224 private <T> List<Entry<String, T>> getOrderedBeansOfType( 225 ListableBeanFactory beanFactory, Class<T> type) { 226 return getOrderedBeansOfType(beanFactory, type, Collections.emptySet()); 227 } 228 229 private <T> List<Entry<String, T>> getOrderedBeansOfType( 230 ListableBeanFactory beanFactory, Class<T> type, Set<?> excludes) { 231 String[] names = beanFactory.getBeanNamesForType(type, true, false); 232 Map<String, T> map = new LinkedHashMap<>(); 233 for (String name : names) { 234 if (!excludes.contains(name) && !ScopedProxyUtils.isScopedTarget(name)) { 235 T bean = beanFactory.getBean(name, type); 236 if (!excludes.contains(bean)) { 237 map.put(name, bean); 238 } 239 } 240 } 241 List<Entry<String, T>> beans = new ArrayList<>(); 242 beans.addAll(map.entrySet()); 243 beans.sort((o1, o2) -> AnnotationAwareOrderComparator.INSTANCE 244 .compare(o1.getValue(), o2.getValue())); 245 return beans; 246 } 247 248 private void logMappings( 249 MultiValueMap<Class<?>, ServletContextInitializer> initializers) { 250 if (logger.isDebugEnabled()) { 251 logMappings("filters", initializers, Filter.class, 252 FilterRegistrationBean.class); 253 logMappings("servlets", initializers, Servlet.class, 254 ServletRegistrationBean.class); 255 } 256 } 257 258 private void logMappings(String name, 259 MultiValueMap<Class<?>, ServletContextInitializer> initializers, 260 Class<?> type, Class<? extends RegistrationBean> registrationType) { 261 List<ServletContextInitializer> registrations = new ArrayList<>(); 262 registrations.addAll( 263 initializers.getOrDefault(registrationType, Collections.emptyList())); 264 registrations.addAll(initializers.getOrDefault(type, Collections.emptyList())); 265 String info = registrations.stream().map(Object::toString) 266 .collect(Collectors.joining(", ")); 267 logger.debug("Mapping " + name + ": " + info); 268 } 269 270 @Override 271 public Iterator<ServletContextInitializer> iterator() { 272 return this.sortedList.iterator(); 273 } 274 275 @Override 276 public int size() { 277 return this.sortedList.size(); 278 } 279 280 /** 281 * Adapter to convert a given Bean type into a {@link RegistrationBean} (and hence a 282 * {@link ServletContextInitializer}). 283 * 284 * @param <T> the type of the Bean to adapt 285 */ 286 @FunctionalInterface 287 protected interface RegistrationBeanAdapter<T> { 288 289 RegistrationBean createRegistrationBean(String name, T source, 290 int totalNumberOfSourceBeans); 291 292 } 293 294 /** 295 * {@link RegistrationBeanAdapter} for {@link Servlet} beans. 296 */ 297 private static class ServletRegistrationBeanAdapter 298 implements RegistrationBeanAdapter<Servlet> { 299 300 private final MultipartConfigElement multipartConfig; 301 302 ServletRegistrationBeanAdapter(MultipartConfigElement multipartConfig) { 303 this.multipartConfig = multipartConfig; 304 } 305 306 @Override 307 public RegistrationBean createRegistrationBean(String name, Servlet source, 308 int totalNumberOfSourceBeans) { 309 String url = (totalNumberOfSourceBeans != 1) ? "/" + name + "/" : "/"; 310 if (name.equals(DISPATCHER_SERVLET_NAME)) { 311 url = "/"; // always map the main dispatcherServlet to "/" 312 } 313 ServletRegistrationBean<Servlet> bean = new ServletRegistrationBean<>(source, 314 url); 315 bean.setName(name); 316 bean.setMultipartConfig(this.multipartConfig); 317 return bean; 318 } 319 320 } 321 322 /** 323 * {@link RegistrationBeanAdapter} for {@link Filter} beans. 324 */ 325 private static class FilterRegistrationBeanAdapter 326 implements RegistrationBeanAdapter<Filter> { 327 328 @Override 329 public RegistrationBean createRegistrationBean(String name, Filter source, 330 int totalNumberOfSourceBeans) { 331 FilterRegistrationBean<Filter> bean = new FilterRegistrationBean<>(source); 332 bean.setName(name); 333 return bean; 334 } 335 336 } 337 338 /** 339 * {@link RegistrationBeanAdapter} for certain {@link EventListener} beans. 340 */ 341 private static class ServletListenerRegistrationBeanAdapter 342 implements RegistrationBeanAdapter<EventListener> { 343 344 @Override 345 public RegistrationBean createRegistrationBean(String name, EventListener source, 346 int totalNumberOfSourceBeans) { 347 return new ServletListenerRegistrationBean<>(source); 348 } 349 350 } 351 352}