/*
 * Copyright 2012-2018 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
016
017package org.springframework.boot.context;
018
019import java.util.ArrayList;
020import java.util.Collections;
021import java.util.HashSet;
022import java.util.LinkedHashSet;
023import java.util.List;
024import java.util.Set;
025
026import org.apache.commons.logging.Log;
027import org.apache.commons.logging.LogFactory;
028
029import org.springframework.beans.BeansException;
030import org.springframework.beans.factory.annotation.AnnotatedBeanDefinition;
031import org.springframework.beans.factory.config.BeanDefinition;
032import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
033import org.springframework.beans.factory.support.BeanDefinitionRegistry;
034import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor;
035import org.springframework.context.ApplicationContextInitializer;
036import org.springframework.context.ConfigurableApplicationContext;
037import org.springframework.context.annotation.ComponentScan;
038import org.springframework.core.Ordered;
039import org.springframework.core.PriorityOrdered;
040import org.springframework.core.annotation.AnnotationAttributes;
041import org.springframework.core.type.AnnotationMetadata;
042import org.springframework.util.ClassUtils;
043import org.springframework.util.StringUtils;
044
/**
 * {@link ApplicationContextInitializer} to report warnings for common configuration
 * mistakes.
 *
 * @author Phillip Webb
 * @since 1.2.0
 */
052public class ConfigurationWarningsApplicationContextInitializer
053                implements ApplicationContextInitializer<ConfigurableApplicationContext> {
054
055        private static final Log logger = LogFactory
056                        .getLog(ConfigurationWarningsApplicationContextInitializer.class);
057
058        @Override
059        public void initialize(ConfigurableApplicationContext context) {
060                context.addBeanFactoryPostProcessor(
061                                new ConfigurationWarningsPostProcessor(getChecks()));
062        }
063
064        /**
065         * Returns the checks that should be applied.
066         * @return the checks to apply
067         */
068        protected Check[] getChecks() {
069                return new Check[] { new ComponentScanPackageCheck() };
070        }
071
072        /**
073         * {@link BeanDefinitionRegistryPostProcessor} to report warnings.
074         */
075        protected static final class ConfigurationWarningsPostProcessor
076                        implements PriorityOrdered, BeanDefinitionRegistryPostProcessor {
077
078                private Check[] checks;
079
080                public ConfigurationWarningsPostProcessor(Check[] checks) {
081                        this.checks = checks;
082                }
083
084                @Override
085                public int getOrder() {
086                        return Ordered.LOWEST_PRECEDENCE - 1;
087                }
088
089                @Override
090                public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory)
091                                throws BeansException {
092                }
093
094                @Override
095                public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry)
096                                throws BeansException {
097                        for (Check check : this.checks) {
098                                String message = check.getWarning(registry);
099                                if (StringUtils.hasLength(message)) {
100                                        warn(message);
101                                }
102                        }
103
104                }
105
106                private void warn(String message) {
107                        if (logger.isWarnEnabled()) {
108                                logger.warn(String.format("%n%n** WARNING ** : %s%n%n", message));
109                        }
110                }
111
112        }
113
114        /**
115         * A single check that can be applied.
116         */
117        @FunctionalInterface
118        protected interface Check {
119
120                /**
121                 * Returns a warning if the check fails or {@code null} if there are no problems.
122                 * @param registry the {@link BeanDefinitionRegistry}
123                 * @return a warning message or {@code null}
124                 */
125                String getWarning(BeanDefinitionRegistry registry);
126
127        }
128
129        /**
130         * {@link Check} for {@code @ComponentScan} on problematic package.
131         */
132        protected static class ComponentScanPackageCheck implements Check {
133
134                private static final Set<String> PROBLEM_PACKAGES;
135
136                static {
137                        Set<String> packages = new HashSet<>();
138                        packages.add("org.springframework");
139                        packages.add("org");
140                        PROBLEM_PACKAGES = Collections.unmodifiableSet(packages);
141                }
142
143                @Override
144                public String getWarning(BeanDefinitionRegistry registry) {
145                        Set<String> scannedPackages = getComponentScanningPackages(registry);
146                        List<String> problematicPackages = getProblematicPackages(scannedPackages);
147                        if (problematicPackages.isEmpty()) {
148                                return null;
149                        }
150                        return "Your ApplicationContext is unlikely to "
151                                        + "start due to a @ComponentScan of "
152                                        + StringUtils.collectionToDelimitedString(problematicPackages, ", ")
153                                        + ".";
154                }
155
156                protected Set<String> getComponentScanningPackages(
157                                BeanDefinitionRegistry registry) {
158                        Set<String> packages = new LinkedHashSet<>();
159                        String[] names = registry.getBeanDefinitionNames();
160                        for (String name : names) {
161                                BeanDefinition definition = registry.getBeanDefinition(name);
162                                if (definition instanceof AnnotatedBeanDefinition) {
163                                        AnnotatedBeanDefinition annotatedDefinition = (AnnotatedBeanDefinition) definition;
164                                        addComponentScanningPackages(packages,
165                                                        annotatedDefinition.getMetadata());
166                                }
167                        }
168                        return packages;
169                }
170
171                private void addComponentScanningPackages(Set<String> packages,
172                                AnnotationMetadata metadata) {
173                        AnnotationAttributes attributes = AnnotationAttributes.fromMap(metadata
174                                        .getAnnotationAttributes(ComponentScan.class.getName(), true));
175                        if (attributes != null) {
176                                addPackages(packages, attributes.getStringArray("value"));
177                                addPackages(packages, attributes.getStringArray("basePackages"));
178                                addClasses(packages, attributes.getStringArray("basePackageClasses"));
179                                if (packages.isEmpty()) {
180                                        packages.add(ClassUtils.getPackageName(metadata.getClassName()));
181                                }
182                        }
183                }
184
185                private void addPackages(Set<String> packages, String[] values) {
186                        if (values != null) {
187                                Collections.addAll(packages, values);
188                        }
189                }
190
191                private void addClasses(Set<String> packages, String[] values) {
192                        if (values != null) {
193                                for (String value : values) {
194                                        packages.add(ClassUtils.getPackageName(value));
195                                }
196                        }
197                }
198
199                private List<String> getProblematicPackages(Set<String> scannedPackages) {
200                        List<String> problematicPackages = new ArrayList<>();
201                        for (String scannedPackage : scannedPackages) {
202                                if (isProblematicPackage(scannedPackage)) {
203                                        problematicPackages.add(getDisplayName(scannedPackage));
204                                }
205                        }
206                        return problematicPackages;
207                }
208
209                private boolean isProblematicPackage(String scannedPackage) {
210                        if (scannedPackage == null || scannedPackage.isEmpty()) {
211                                return true;
212                        }
213                        return PROBLEM_PACKAGES.contains(scannedPackage);
214                }
215
216                private String getDisplayName(String scannedPackage) {
217                        if (scannedPackage == null || scannedPackage.isEmpty()) {
218                                return "the default package";
219                        }
220                        return "'" + scannedPackage + "'";
221                }
222
223        }
224
225}