001/*
002 * Copyright 2012-2018 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.boot.autoconfigure;
018
019import java.util.Arrays;
020import java.util.HashSet;
021import java.util.Set;
022
023import org.springframework.beans.factory.BeanFactory;
024import org.springframework.beans.factory.BeanFactoryUtils;
025import org.springframework.beans.factory.FactoryBean;
026import org.springframework.beans.factory.ListableBeanFactory;
027import org.springframework.beans.factory.NoSuchBeanDefinitionException;
028import org.springframework.beans.factory.config.BeanDefinition;
029import org.springframework.beans.factory.config.BeanFactoryPostProcessor;
030import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
031import org.springframework.util.StringUtils;
032
033/**
034 * Abstract base class for a {@link BeanFactoryPostProcessor} that can be used to
035 * dynamically declare that all beans of a specific type should depend on one or more
036 * specific beans.
037 *
038 * @author Marcel Overdijk
039 * @author Dave Syer
040 * @author Phillip Webb
041 * @author Andy Wilkinson
042 * @since 1.3.0
043 * @see BeanDefinition#setDependsOn(String[])
044 */
045public abstract class AbstractDependsOnBeanFactoryPostProcessor
046                implements BeanFactoryPostProcessor {
047
048        private final Class<?> beanClass;
049
050        private final Class<? extends FactoryBean<?>> factoryBeanClass;
051
052        private final String[] dependsOn;
053
054        protected AbstractDependsOnBeanFactoryPostProcessor(Class<?> beanClass,
055                        Class<? extends FactoryBean<?>> factoryBeanClass, String... dependsOn) {
056                this.beanClass = beanClass;
057                this.factoryBeanClass = factoryBeanClass;
058                this.dependsOn = dependsOn;
059        }
060
061        /**
062         * Create an instance with target bean class and dependencies.
063         * @param beanClass target bean class
064         * @param dependsOn dependencies
065         * @since 2.0.4
066         */
067        protected AbstractDependsOnBeanFactoryPostProcessor(Class<?> beanClass,
068                        String... dependsOn) {
069                this(beanClass, null, dependsOn);
070        }
071
072        @Override
073        public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) {
074                for (String beanName : getBeanNames(beanFactory)) {
075                        BeanDefinition definition = getBeanDefinition(beanName, beanFactory);
076                        String[] dependencies = definition.getDependsOn();
077                        for (String bean : this.dependsOn) {
078                                dependencies = StringUtils.addStringToArray(dependencies, bean);
079                        }
080                        definition.setDependsOn(dependencies);
081                }
082        }
083
084        private Iterable<String> getBeanNames(ListableBeanFactory beanFactory) {
085                Set<String> names = new HashSet<>();
086                names.addAll(Arrays.asList(BeanFactoryUtils.beanNamesForTypeIncludingAncestors(
087                                beanFactory, this.beanClass, true, false)));
088                if (this.factoryBeanClass != null) {
089                        for (String factoryBeanName : BeanFactoryUtils
090                                        .beanNamesForTypeIncludingAncestors(beanFactory,
091                                                        this.factoryBeanClass, true, false)) {
092                                names.add(BeanFactoryUtils.transformedBeanName(factoryBeanName));
093                        }
094                }
095                return names;
096        }
097
098        private static BeanDefinition getBeanDefinition(String beanName,
099                        ConfigurableListableBeanFactory beanFactory) {
100                try {
101                        return beanFactory.getBeanDefinition(beanName);
102                }
103                catch (NoSuchBeanDefinitionException ex) {
104                        BeanFactory parentBeanFactory = beanFactory.getParentBeanFactory();
105                        if (parentBeanFactory instanceof ConfigurableListableBeanFactory) {
106                                return getBeanDefinition(beanName,
107                                                (ConfigurableListableBeanFactory) parentBeanFactory);
108                        }
109                        throw ex;
110                }
111        }
112
113}