001/*
002 * Copyright 2012-2018 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.boot.web.servlet;
018
019import java.util.AbstractCollection;
020import java.util.ArrayList;
021import java.util.Arrays;
022import java.util.Collections;
023import java.util.EventListener;
024import java.util.HashSet;
025import java.util.Iterator;
026import java.util.LinkedHashMap;
027import java.util.List;
028import java.util.Map;
029import java.util.Map.Entry;
030import java.util.Set;
031import java.util.stream.Collectors;
032
033import javax.servlet.Filter;
034import javax.servlet.MultipartConfigElement;
035import javax.servlet.Servlet;
036
037import org.apache.commons.logging.Log;
038import org.apache.commons.logging.LogFactory;
039
040import org.springframework.aop.scope.ScopedProxyUtils;
041import org.springframework.beans.factory.ListableBeanFactory;
042import org.springframework.beans.factory.support.BeanDefinitionRegistry;
043import org.springframework.core.annotation.AnnotationAwareOrderComparator;
044import org.springframework.util.LinkedMultiValueMap;
045import org.springframework.util.MultiValueMap;
046
047/**
048 * A collection {@link ServletContextInitializer}s obtained from a
049 * {@link ListableBeanFactory}. Includes all {@link ServletContextInitializer} beans and
050 * also adapts {@link Servlet}, {@link Filter} and certain {@link EventListener} beans.
051 * <p>
052 * Items are sorted so that adapted beans are top ({@link Servlet}, {@link Filter} then
053 * {@link EventListener}) and direct {@link ServletContextInitializer} beans are at the
054 * end. Further sorting is applied within these groups using the
055 * {@link AnnotationAwareOrderComparator}.
056 *
057 * @author Dave Syer
058 * @author Phillip Webb
059 * @author Brian Clozel
060 * @since 1.4.0
061 */
062public class ServletContextInitializerBeans
063                extends AbstractCollection<ServletContextInitializer> {
064
065        private static final String DISPATCHER_SERVLET_NAME = "dispatcherServlet";
066
067        private static final Log logger = LogFactory
068                        .getLog(ServletContextInitializerBeans.class);
069
070        /**
071         * Seen bean instances or bean names.
072         */
073        private final Set<Object> seen = new HashSet<>();
074
075        private final MultiValueMap<Class<?>, ServletContextInitializer> initializers;
076
077        private final List<Class<? extends ServletContextInitializer>> initializerTypes;
078
079        private List<ServletContextInitializer> sortedList;
080
081        @SafeVarargs
082        public ServletContextInitializerBeans(ListableBeanFactory beanFactory,
083                        Class<? extends ServletContextInitializer>... initializerTypes) {
084                this.initializers = new LinkedMultiValueMap<>();
085                this.initializerTypes = (initializerTypes.length != 0)
086                                ? Arrays.asList(initializerTypes)
087                                : Collections.singletonList(ServletContextInitializer.class);
088                addServletContextInitializerBeans(beanFactory);
089                addAdaptableBeans(beanFactory);
090                List<ServletContextInitializer> sortedInitializers = this.initializers.values()
091                                .stream()
092                                .flatMap((value) -> value.stream()
093                                                .sorted(AnnotationAwareOrderComparator.INSTANCE))
094                                .collect(Collectors.toList());
095                this.sortedList = Collections.unmodifiableList(sortedInitializers);
096                logMappings(this.initializers);
097        }
098
099        private void addServletContextInitializerBeans(ListableBeanFactory beanFactory) {
100                for (Class<? extends ServletContextInitializer> initializerType : this.initializerTypes) {
101                        for (Entry<String, ? extends ServletContextInitializer> initializerBean : getOrderedBeansOfType(
102                                        beanFactory, initializerType)) {
103                                addServletContextInitializerBean(initializerBean.getKey(),
104                                                initializerBean.getValue(), beanFactory);
105                        }
106                }
107        }
108
109        private void addServletContextInitializerBean(String beanName,
110                        ServletContextInitializer initializer, ListableBeanFactory beanFactory) {
111                if (initializer instanceof ServletRegistrationBean) {
112                        Servlet source = ((ServletRegistrationBean<?>) initializer).getServlet();
113                        addServletContextInitializerBean(Servlet.class, beanName, initializer,
114                                        beanFactory, source);
115                }
116                else if (initializer instanceof FilterRegistrationBean) {
117                        Filter source = ((FilterRegistrationBean<?>) initializer).getFilter();
118                        addServletContextInitializerBean(Filter.class, beanName, initializer,
119                                        beanFactory, source);
120                }
121                else if (initializer instanceof DelegatingFilterProxyRegistrationBean) {
122                        String source = ((DelegatingFilterProxyRegistrationBean) initializer)
123                                        .getTargetBeanName();
124                        addServletContextInitializerBean(Filter.class, beanName, initializer,
125                                        beanFactory, source);
126                }
127                else if (initializer instanceof ServletListenerRegistrationBean) {
128                        EventListener source = ((ServletListenerRegistrationBean<?>) initializer)
129                                        .getListener();
130                        addServletContextInitializerBean(EventListener.class, beanName, initializer,
131                                        beanFactory, source);
132                }
133                else {
134                        addServletContextInitializerBean(ServletContextInitializer.class, beanName,
135                                        initializer, beanFactory, initializer);
136                }
137        }
138
139        private void addServletContextInitializerBean(Class<?> type, String beanName,
140                        ServletContextInitializer initializer, ListableBeanFactory beanFactory,
141                        Object source) {
142                this.initializers.add(type, initializer);
143                if (source != null) {
144                        // Mark the underlying source as seen in case it wraps an existing bean
145                        this.seen.add(source);
146                }
147                if (logger.isTraceEnabled()) {
148                        String resourceDescription = getResourceDescription(beanName, beanFactory);
149                        int order = getOrder(initializer);
150                        logger.trace("Added existing " + type.getSimpleName() + " initializer bean '"
151                                        + beanName + "'; order=" + order + ", resource="
152                                        + resourceDescription);
153                }
154        }
155
156        private String getResourceDescription(String beanName,
157                        ListableBeanFactory beanFactory) {
158                if (beanFactory instanceof BeanDefinitionRegistry) {
159                        BeanDefinitionRegistry registry = (BeanDefinitionRegistry) beanFactory;
160                        return registry.getBeanDefinition(beanName).getResourceDescription();
161                }
162                return "unknown";
163        }
164
165        @SuppressWarnings("unchecked")
166        protected void addAdaptableBeans(ListableBeanFactory beanFactory) {
167                MultipartConfigElement multipartConfig = getMultipartConfig(beanFactory);
168                addAsRegistrationBean(beanFactory, Servlet.class,
169                                new ServletRegistrationBeanAdapter(multipartConfig));
170                addAsRegistrationBean(beanFactory, Filter.class,
171                                new FilterRegistrationBeanAdapter());
172                for (Class<?> listenerType : ServletListenerRegistrationBean
173                                .getSupportedTypes()) {
174                        addAsRegistrationBean(beanFactory, EventListener.class,
175                                        (Class<EventListener>) listenerType,
176                                        new ServletListenerRegistrationBeanAdapter());
177                }
178        }
179
180        private MultipartConfigElement getMultipartConfig(ListableBeanFactory beanFactory) {
181                List<Entry<String, MultipartConfigElement>> beans = getOrderedBeansOfType(
182                                beanFactory, MultipartConfigElement.class);
183                return beans.isEmpty() ? null : beans.get(0).getValue();
184        }
185
186        protected <T> void addAsRegistrationBean(ListableBeanFactory beanFactory,
187                        Class<T> type, RegistrationBeanAdapter<T> adapter) {
188                addAsRegistrationBean(beanFactory, type, type, adapter);
189        }
190
191        private <T, B extends T> void addAsRegistrationBean(ListableBeanFactory beanFactory,
192                        Class<T> type, Class<B> beanType, RegistrationBeanAdapter<T> adapter) {
193                List<Map.Entry<String, B>> entries = getOrderedBeansOfType(beanFactory, beanType,
194                                this.seen);
195                for (Entry<String, B> entry : entries) {
196                        String beanName = entry.getKey();
197                        B bean = entry.getValue();
198                        if (this.seen.add(bean)) {
199                                // One that we haven't already seen
200                                RegistrationBean registration = adapter.createRegistrationBean(beanName,
201                                                bean, entries.size());
202                                int order = getOrder(bean);
203                                registration.setOrder(order);
204                                this.initializers.add(type, registration);
205                                if (logger.isTraceEnabled()) {
206                                        logger.trace(
207                                                        "Created " + type.getSimpleName() + " initializer for bean '"
208                                                                        + beanName + "'; order=" + order + ", resource="
209                                                                        + getResourceDescription(beanName, beanFactory));
210                                }
211                        }
212                }
213        }
214
215        private int getOrder(Object value) {
216                return new AnnotationAwareOrderComparator() {
217                        @Override
218                        public int getOrder(Object obj) {
219                                return super.getOrder(obj);
220                        }
221                }.getOrder(value);
222        }
223
224        private <T> List<Entry<String, T>> getOrderedBeansOfType(
225                        ListableBeanFactory beanFactory, Class<T> type) {
226                return getOrderedBeansOfType(beanFactory, type, Collections.emptySet());
227        }
228
229        private <T> List<Entry<String, T>> getOrderedBeansOfType(
230                        ListableBeanFactory beanFactory, Class<T> type, Set<?> excludes) {
231                String[] names = beanFactory.getBeanNamesForType(type, true, false);
232                Map<String, T> map = new LinkedHashMap<>();
233                for (String name : names) {
234                        if (!excludes.contains(name) && !ScopedProxyUtils.isScopedTarget(name)) {
235                                T bean = beanFactory.getBean(name, type);
236                                if (!excludes.contains(bean)) {
237                                        map.put(name, bean);
238                                }
239                        }
240                }
241                List<Entry<String, T>> beans = new ArrayList<>();
242                beans.addAll(map.entrySet());
243                beans.sort((o1, o2) -> AnnotationAwareOrderComparator.INSTANCE
244                                .compare(o1.getValue(), o2.getValue()));
245                return beans;
246        }
247
248        private void logMappings(
249                        MultiValueMap<Class<?>, ServletContextInitializer> initializers) {
250                if (logger.isDebugEnabled()) {
251                        logMappings("filters", initializers, Filter.class,
252                                        FilterRegistrationBean.class);
253                        logMappings("servlets", initializers, Servlet.class,
254                                        ServletRegistrationBean.class);
255                }
256        }
257
258        private void logMappings(String name,
259                        MultiValueMap<Class<?>, ServletContextInitializer> initializers,
260                        Class<?> type, Class<? extends RegistrationBean> registrationType) {
261                List<ServletContextInitializer> registrations = new ArrayList<>();
262                registrations.addAll(
263                                initializers.getOrDefault(registrationType, Collections.emptyList()));
264                registrations.addAll(initializers.getOrDefault(type, Collections.emptyList()));
265                String info = registrations.stream().map(Object::toString)
266                                .collect(Collectors.joining(", "));
267                logger.debug("Mapping " + name + ": " + info);
268        }
269
270        @Override
271        public Iterator<ServletContextInitializer> iterator() {
272                return this.sortedList.iterator();
273        }
274
275        @Override
276        public int size() {
277                return this.sortedList.size();
278        }
279
280        /**
281         * Adapter to convert a given Bean type into a {@link RegistrationBean} (and hence a
282         * {@link ServletContextInitializer}).
283         *
284         * @param <T> the type of the Bean to adapt
285         */
286        @FunctionalInterface
287        protected interface RegistrationBeanAdapter<T> {
288
289                RegistrationBean createRegistrationBean(String name, T source,
290                                int totalNumberOfSourceBeans);
291
292        }
293
294        /**
295         * {@link RegistrationBeanAdapter} for {@link Servlet} beans.
296         */
297        private static class ServletRegistrationBeanAdapter
298                        implements RegistrationBeanAdapter<Servlet> {
299
300                private final MultipartConfigElement multipartConfig;
301
302                ServletRegistrationBeanAdapter(MultipartConfigElement multipartConfig) {
303                        this.multipartConfig = multipartConfig;
304                }
305
306                @Override
307                public RegistrationBean createRegistrationBean(String name, Servlet source,
308                                int totalNumberOfSourceBeans) {
309                        String url = (totalNumberOfSourceBeans != 1) ? "/" + name + "/" : "/";
310                        if (name.equals(DISPATCHER_SERVLET_NAME)) {
311                                url = "/"; // always map the main dispatcherServlet to "/"
312                        }
313                        ServletRegistrationBean<Servlet> bean = new ServletRegistrationBean<>(source,
314                                        url);
315                        bean.setName(name);
316                        bean.setMultipartConfig(this.multipartConfig);
317                        return bean;
318                }
319
320        }
321
322        /**
323         * {@link RegistrationBeanAdapter} for {@link Filter} beans.
324         */
325        private static class FilterRegistrationBeanAdapter
326                        implements RegistrationBeanAdapter<Filter> {
327
328                @Override
329                public RegistrationBean createRegistrationBean(String name, Filter source,
330                                int totalNumberOfSourceBeans) {
331                        FilterRegistrationBean<Filter> bean = new FilterRegistrationBean<>(source);
332                        bean.setName(name);
333                        return bean;
334                }
335
336        }
337
338        /**
339         * {@link RegistrationBeanAdapter} for certain {@link EventListener} beans.
340         */
341        private static class ServletListenerRegistrationBeanAdapter
342                        implements RegistrationBeanAdapter<EventListener> {
343
344                @Override
345                public RegistrationBean createRegistrationBean(String name, EventListener source,
346                                int totalNumberOfSourceBeans) {
347                        return new ServletListenerRegistrationBean<>(source);
348                }
349
350        }
351
352}