001/*
002 * Copyright 2012-2018 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.boot.test.mock.mockito;
018
019import java.beans.PropertyDescriptor;
020import java.lang.reflect.Field;
021import java.util.Arrays;
022import java.util.Collection;
023import java.util.HashMap;
024import java.util.LinkedHashMap;
025import java.util.LinkedHashSet;
026import java.util.Map;
027import java.util.Set;
028import java.util.TreeSet;
029
030import org.springframework.aop.scope.ScopedProxyUtils;
031import org.springframework.beans.BeansException;
032import org.springframework.beans.PropertyValues;
033import org.springframework.beans.factory.BeanClassLoaderAware;
034import org.springframework.beans.factory.BeanCreationException;
035import org.springframework.beans.factory.BeanFactory;
036import org.springframework.beans.factory.BeanFactoryAware;
037import org.springframework.beans.factory.BeanFactoryUtils;
038import org.springframework.beans.factory.FactoryBean;
039import org.springframework.beans.factory.NoUniqueBeanDefinitionException;
040import org.springframework.beans.factory.config.BeanDefinition;
041import org.springframework.beans.factory.config.BeanFactoryPostProcessor;
042import org.springframework.beans.factory.config.BeanPostProcessor;
043import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
044import org.springframework.beans.factory.config.ConstructorArgumentValues;
045import org.springframework.beans.factory.config.ConstructorArgumentValues.ValueHolder;
046import org.springframework.beans.factory.config.InstantiationAwareBeanPostProcessorAdapter;
047import org.springframework.beans.factory.config.RuntimeBeanReference;
048import org.springframework.beans.factory.support.BeanDefinitionRegistry;
049import org.springframework.beans.factory.support.BeanNameGenerator;
050import org.springframework.beans.factory.support.DefaultBeanNameGenerator;
051import org.springframework.beans.factory.support.RootBeanDefinition;
052import org.springframework.context.ApplicationContext;
053import org.springframework.context.annotation.ConfigurationClassPostProcessor;
054import org.springframework.core.Conventions;
055import org.springframework.core.Ordered;
056import org.springframework.core.PriorityOrdered;
057import org.springframework.core.ResolvableType;
058import org.springframework.test.context.junit4.SpringRunner;
059import org.springframework.util.Assert;
060import org.springframework.util.ClassUtils;
061import org.springframework.util.ObjectUtils;
062import org.springframework.util.ReflectionUtils;
063import org.springframework.util.StringUtils;
064
065/**
066 * A {@link BeanFactoryPostProcessor} used to register and inject
067 * {@link MockBean @MockBeans} with the {@link ApplicationContext}. An initial set of
068 * definitions can be passed to the processor with additional definitions being
069 * automatically created from {@code @Configuration} classes that use
070 * {@link MockBean @MockBean}.
071 *
072 * @author Phillip Webb
073 * @author Andy Wilkinson
074 * @author Stephane Nicoll
075 * @author Andreas Neiser
076 * @since 1.4.0
077 */
078public class MockitoPostProcessor extends InstantiationAwareBeanPostProcessorAdapter
079                implements BeanClassLoaderAware, BeanFactoryAware, BeanFactoryPostProcessor,
080                Ordered {
081
082        private static final String FACTORY_BEAN_OBJECT_TYPE = "factoryBeanObjectType";
083
084        private static final String BEAN_NAME = MockitoPostProcessor.class.getName();
085
086        private static final String CONFIGURATION_CLASS_ATTRIBUTE = Conventions
087                        .getQualifiedAttributeName(ConfigurationClassPostProcessor.class,
088                                        "configurationClass");
089
090        private static final BeanNameGenerator beanNameGenerator = new DefaultBeanNameGenerator();
091
092        private final Set<Definition> definitions;
093
094        private ClassLoader classLoader;
095
096        private BeanFactory beanFactory;
097
098        private final MockitoBeans mockitoBeans = new MockitoBeans();
099
100        private Map<Definition, String> beanNameRegistry = new HashMap<>();
101
102        private Map<Field, String> fieldRegistry = new HashMap<>();
103
104        private Map<String, SpyDefinition> spies = new HashMap<>();
105
        /**
         * Create a new {@link MockitoPostProcessor} instance with the given initial
         * definitions. The set is stored as given; no copy is made.
         * @param definitions the initial definitions
         */
        public MockitoPostProcessor(Set<Definition> definitions) {
                this.definitions = definitions;
        }
114
        /**
         * Store the class loader that is later used to resolve configuration class names.
         * @param classLoader the bean class loader
         */
        @Override
        public void setBeanClassLoader(ClassLoader classLoader) {
                this.classLoader = classLoader;
        }
119
        /**
         * Store the bean factory used for field injection, rejecting any factory that is
         * not a {@link ConfigurableListableBeanFactory}.
         * @param beanFactory the owning bean factory
         * @throws BeansException declared by the implemented interface
         */
        @Override
        public void setBeanFactory(BeanFactory beanFactory) throws BeansException {
                // Fail fast: the assertion runs before the factory is stored.
                Assert.isInstanceOf(ConfigurableListableBeanFactory.class, beanFactory,
                                "Mock beans can only be used with a ConfigurableListableBeanFactory");
                this.beanFactory = beanFactory;
        }
126
127        @Override
128        public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory)
129                        throws BeansException {
130                Assert.isInstanceOf(BeanDefinitionRegistry.class, beanFactory,
131                                "@MockBean can only be used on bean factories that "
132                                                + "implement BeanDefinitionRegistry");
133                postProcessBeanFactory(beanFactory, (BeanDefinitionRegistry) beanFactory);
134        }
135
136        private void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory,
137                        BeanDefinitionRegistry registry) {
138                beanFactory.registerSingleton(MockitoBeans.class.getName(), this.mockitoBeans);
139                DefinitionsParser parser = new DefinitionsParser(this.definitions);
140                for (Class<?> configurationClass : getConfigurationClasses(beanFactory)) {
141                        parser.parse(configurationClass);
142                }
143                Set<Definition> definitions = parser.getDefinitions();
144                for (Definition definition : definitions) {
145                        Field field = parser.getField(definition);
146                        register(beanFactory, registry, definition, field);
147                }
148        }
149
150        private Set<Class<?>> getConfigurationClasses(
151                        ConfigurableListableBeanFactory beanFactory) {
152                Set<Class<?>> configurationClasses = new LinkedHashSet<>();
153                for (BeanDefinition beanDefinition : getConfigurationBeanDefinitions(beanFactory)
154                                .values()) {
155                        configurationClasses.add(ClassUtils.resolveClassName(
156                                        beanDefinition.getBeanClassName(), this.classLoader));
157                }
158                return configurationClasses;
159        }
160
161        private Map<String, BeanDefinition> getConfigurationBeanDefinitions(
162                        ConfigurableListableBeanFactory beanFactory) {
163                Map<String, BeanDefinition> definitions = new LinkedHashMap<>();
164                for (String beanName : beanFactory.getBeanDefinitionNames()) {
165                        BeanDefinition definition = beanFactory.getBeanDefinition(beanName);
166                        if (definition.getAttribute(CONFIGURATION_CLASS_ATTRIBUTE) != null) {
167                                definitions.put(beanName, definition);
168                        }
169                }
170                return definitions;
171        }
172
173        private void register(ConfigurableListableBeanFactory beanFactory,
174                        BeanDefinitionRegistry registry, Definition definition, Field field) {
175                if (definition instanceof MockDefinition) {
176                        registerMock(beanFactory, registry, (MockDefinition) definition, field);
177                }
178                else if (definition instanceof SpyDefinition) {
179                        registerSpy(beanFactory, registry, (SpyDefinition) definition, field);
180                }
181        }
182
        /**
         * Register a mock for the given definition, replacing any bean definition that
         * already exists under the resolved bean name.
         * @param beanFactory the bean factory that receives the mock singleton
         * @param registry the registry that receives the mock bean definition
         * @param definition the mock definition to register
         * @param field the field the definition came from, or {@code null}
         */
        private void registerMock(ConfigurableListableBeanFactory beanFactory,
                        BeanDefinitionRegistry registry, MockDefinition definition, Field field) {
                RootBeanDefinition beanDefinition = createBeanDefinition(definition);
                String beanName = getBeanName(beanFactory, registry, definition, beanDefinition);
                // The bean definition and singleton are keyed by the transformed name,
                // while the name registries below keep the name exactly as resolved.
                String transformedBeanName = BeanFactoryUtils.transformedBeanName(beanName);
                if (registry.containsBeanDefinition(transformedBeanName)) {
                        // Carry the primary flag of the replaced definition over to the mock's.
                        BeanDefinition existing = registry.getBeanDefinition(transformedBeanName);
                        copyBeanDefinitionDetails(existing, beanDefinition);
                        registry.removeBeanDefinition(transformedBeanName);
                }
                registry.registerBeanDefinition(transformedBeanName, beanDefinition);
                Object mock = definition.createMock(beanName + " bean");
                // The mock instance is registered directly as the singleton for that name.
                beanFactory.registerSingleton(transformedBeanName, mock);
                this.mockitoBeans.add(mock);
                this.beanNameRegistry.put(definition, beanName);
                if (field != null) {
                        this.fieldRegistry.put(field, beanName);
                }
        }
202
203        private RootBeanDefinition createBeanDefinition(MockDefinition mockDefinition) {
204                RootBeanDefinition definition = new RootBeanDefinition(
205                                mockDefinition.getTypeToMock().resolve());
206                definition.setTargetType(mockDefinition.getTypeToMock());
207                if (mockDefinition.getQualifier() != null) {
208                        mockDefinition.getQualifier().applyTo(definition);
209                }
210                return definition;
211        }
212
213        private String getBeanName(ConfigurableListableBeanFactory beanFactory,
214                        BeanDefinitionRegistry registry, MockDefinition mockDefinition,
215                        RootBeanDefinition beanDefinition) {
216                if (StringUtils.hasLength(mockDefinition.getName())) {
217                        return mockDefinition.getName();
218                }
219                Set<String> existingBeans = getExistingBeans(beanFactory,
220                                mockDefinition.getTypeToMock(), mockDefinition.getQualifier());
221                if (existingBeans.isEmpty()) {
222                        return MockitoPostProcessor.beanNameGenerator.generateBeanName(beanDefinition,
223                                        registry);
224                }
225                if (existingBeans.size() == 1) {
226                        return existingBeans.iterator().next();
227                }
228                String primaryCandidate = determinePrimaryCandidate(registry, existingBeans,
229                                mockDefinition.getTypeToMock());
230                if (primaryCandidate != null) {
231                        return primaryCandidate;
232                }
233                throw new IllegalStateException(
234                                "Unable to register mock bean " + mockDefinition.getTypeToMock()
235                                                + " expected a single matching bean to replace but found "
236                                                + existingBeans);
237        }
238
        /**
         * Copy details from a bean definition that is about to be replaced onto its
         * replacement. Only the primary flag is copied.
         * @param from the definition being replaced
         * @param to the replacement definition
         */
        private void copyBeanDefinitionDetails(BeanDefinition from, RootBeanDefinition to) {
                to.setPrimary(from.isPrimary());
        }
242
243        private void registerSpy(ConfigurableListableBeanFactory beanFactory,
244                        BeanDefinitionRegistry registry, SpyDefinition spyDefinition, Field field) {
245                Set<String> existingBeans = getExistingBeans(beanFactory,
246                                spyDefinition.getTypeToSpy(), spyDefinition.getQualifier());
247                if (ObjectUtils.isEmpty(existingBeans)) {
248                        createSpy(registry, spyDefinition, field);
249                }
250                else {
251                        registerSpies(registry, spyDefinition, field, existingBeans);
252                }
253        }
254
255        private Set<String> getExistingBeans(ConfigurableListableBeanFactory beanFactory,
256                        ResolvableType type, QualifierDefinition qualifier) {
257                Set<String> candidates = new TreeSet<>();
258                for (String candidate : getExistingBeans(beanFactory, type)) {
259                        if (qualifier == null || qualifier.matches(beanFactory, candidate)) {
260                                candidates.add(candidate);
261                        }
262                }
263                return candidates;
264        }
265
266        private Set<String> getExistingBeans(ConfigurableListableBeanFactory beanFactory,
267                        ResolvableType type) {
268                Set<String> beans = new LinkedHashSet<>(
269                                Arrays.asList(beanFactory.getBeanNamesForType(type)));
270                String typeName = type.resolve(Object.class).getName();
271                for (String beanName : beanFactory.getBeanNamesForType(FactoryBean.class)) {
272                        beanName = BeanFactoryUtils.transformedBeanName(beanName);
273                        BeanDefinition beanDefinition = beanFactory.getBeanDefinition(beanName);
274                        if (typeName.equals(beanDefinition.getAttribute(FACTORY_BEAN_OBJECT_TYPE))) {
275                                beans.add(beanName);
276                        }
277                }
278                beans.removeIf(this::isScopedTarget);
279                return beans;
280        }
281
        /**
         * Determine whether the bean name denotes a scoped proxy target, treating any
         * failure of the check as "not a scoped target".
         * @param beanName the bean name to check
         * @return {@code true} if {@link ScopedProxyUtils} reports a scoped target
         */
        private boolean isScopedTarget(String beanName) {
                try {
                        return ScopedProxyUtils.isScopedTarget(beanName);
                }
                catch (Throwable ex) {
                        // NOTE(review): the catch-all presumably covers ScopedProxyUtils being
                        // unavailable on the classpath - confirm before narrowing it.
                        return false;
                }
        }
290
291        private void createSpy(BeanDefinitionRegistry registry, SpyDefinition spyDefinition,
292                        Field field) {
293                RootBeanDefinition beanDefinition = new RootBeanDefinition(
294                                spyDefinition.getTypeToSpy().resolve());
295                String beanName = MockitoPostProcessor.beanNameGenerator
296                                .generateBeanName(beanDefinition, registry);
297                registry.registerBeanDefinition(beanName, beanDefinition);
298                registerSpy(spyDefinition, field, beanName);
299        }
300
301        private void registerSpies(BeanDefinitionRegistry registry,
302                        SpyDefinition spyDefinition, Field field, Collection<String> existingBeans) {
303                try {
304                        String beanName = determineBeanName(existingBeans, spyDefinition, registry);
305                        registerSpy(spyDefinition, field, beanName);
306                }
307                catch (RuntimeException ex) {
308                        throw new IllegalStateException(
309                                        "Unable to register spy bean " + spyDefinition.getTypeToSpy(), ex);
310                }
311        }
312
313        private String determineBeanName(Collection<String> existingBeans,
314                        SpyDefinition definition, BeanDefinitionRegistry registry) {
315                if (StringUtils.hasText(definition.getName())) {
316                        return definition.getName();
317                }
318                if (existingBeans.size() == 1) {
319                        return existingBeans.iterator().next();
320                }
321                return determinePrimaryCandidate(registry, existingBeans,
322                                definition.getTypeToSpy());
323        }
324
325        private String determinePrimaryCandidate(BeanDefinitionRegistry registry,
326                        Collection<String> candidateBeanNames, ResolvableType type) {
327                String primaryBeanName = null;
328                for (String candidateBeanName : candidateBeanNames) {
329                        BeanDefinition beanDefinition = registry.getBeanDefinition(candidateBeanName);
330                        if (beanDefinition.isPrimary()) {
331                                if (primaryBeanName != null) {
332                                        throw new NoUniqueBeanDefinitionException(type.resolve(),
333                                                        candidateBeanNames.size(),
334                                                        "more than one 'primary' bean found among candidates: "
335                                                                        + Arrays.asList(candidateBeanNames));
336                                }
337                                primaryBeanName = candidateBeanName;
338                        }
339                }
340                return primaryBeanName;
341        }
342
343        private void registerSpy(SpyDefinition definition, Field field, String beanName) {
344                this.spies.put(beanName, definition);
345                this.beanNameRegistry.put(definition, beanName);
346                if (field != null) {
347                        this.fieldRegistry.put(field, beanName);
348                }
349        }
350
351        protected final Object createSpyIfNecessary(Object bean, String beanName)
352                        throws BeansException {
353                SpyDefinition definition = this.spies.get(beanName);
354                if (definition != null) {
355                        bean = definition.createSpy(beanName, bean);
356                }
357                return bean;
358        }
359
360        @Override
361        public PropertyValues postProcessPropertyValues(PropertyValues pvs,
362                        PropertyDescriptor[] pds, final Object bean, String beanName)
363                        throws BeansException {
364                ReflectionUtils.doWithFields(bean.getClass(),
365                                (field) -> postProcessField(bean, field));
366                return pvs;
367        }
368
369        private void postProcessField(Object bean, Field field) {
370                String beanName = this.fieldRegistry.get(field);
371                if (StringUtils.hasText(beanName)) {
372                        inject(field, bean, beanName);
373                }
374        }
375
376        void inject(Field field, Object target, Definition definition) {
377                String beanName = this.beanNameRegistry.get(definition);
378                Assert.state(StringUtils.hasLength(beanName),
379                                () -> "No bean found for definition " + definition);
380                inject(field, target, beanName);
381        }
382
        /**
         * Look up the named bean and assign it to the field on the target object.
         * @param field the field to inject
         * @param target the object that owns the field
         * @param beanName the name of the bean to inject
         * @throws BeanCreationException wrapping any failure, including a field that
         * already holds a value
         */
        private void inject(Field field, Object target, String beanName) {
                try {
                        field.setAccessible(true);
                        // Refuse to overwrite a value that is already present on the field.
                        Assert.state(ReflectionUtils.getField(field, target) == null,
                                        () -> "The field " + field + " cannot have an existing value");
                        // The lookup is typed with the field's declared type.
                        Object bean = this.beanFactory.getBean(beanName, field.getType());
                        ReflectionUtils.setField(field, target, bean);
                }
                catch (Throwable ex) {
                        throw new BeanCreationException("Could not inject field: " + field, ex);
                }
        }
395
        /**
         * Return the order of this post processor.
         * @return ten positions ahead of {@link Ordered#LOWEST_PRECEDENCE}
         */
        @Override
        public int getOrder() {
                return Ordered.LOWEST_PRECEDENCE - 10;
        }
400
        /**
         * Register the processor with a {@link BeanDefinitionRegistry}. Not required when
         * using the {@link SpringRunner} as registration is automatic. No initial
         * definitions are supplied.
         * @param registry the bean definition registry
         */
        public static void register(BeanDefinitionRegistry registry) {
                register(registry, null);
        }
409
        /**
         * Register the processor with a {@link BeanDefinitionRegistry}. Not required when
         * using the {@link SpringRunner} as registration is automatic. The processor is
         * registered as a {@link MockitoPostProcessor}.
         * @param registry the bean definition registry
         * @param definitions the initial mock/spy definitions, may be {@code null}
         */
        public static void register(BeanDefinitionRegistry registry,
                        Set<Definition> definitions) {
                register(registry, MockitoPostProcessor.class, definitions);
        }
420
421        /**
422         * Register the processor with a {@link BeanDefinitionRegistry}. Not required when
423         * using the {@link SpringRunner} as registration is automatic.
424         * @param registry the bean definition registry
425         * @param postProcessor the post processor class to register
426         * @param definitions the initial mock/spy definitions
427         */
428        @SuppressWarnings("unchecked")
429        public static void register(BeanDefinitionRegistry registry,
430                        Class<? extends MockitoPostProcessor> postProcessor,
431                        Set<Definition> definitions) {
432                SpyPostProcessor.register(registry);
433                BeanDefinition definition = getOrAddBeanDefinition(registry, postProcessor);
434                ValueHolder constructorArg = definition.getConstructorArgumentValues()
435                                .getIndexedArgumentValue(0, Set.class);
436                Set<Definition> existing = (Set<Definition>) constructorArg.getValue();
437                if (definitions != null) {
438                        existing.addAll(definitions);
439                }
440        }
441
442        private static BeanDefinition getOrAddBeanDefinition(BeanDefinitionRegistry registry,
443                        Class<? extends MockitoPostProcessor> postProcessor) {
444                if (!registry.containsBeanDefinition(BEAN_NAME)) {
445                        RootBeanDefinition definition = new RootBeanDefinition(postProcessor);
446                        definition.setRole(BeanDefinition.ROLE_INFRASTRUCTURE);
447                        ConstructorArgumentValues constructorArguments = definition
448                                        .getConstructorArgumentValues();
449                        constructorArguments.addIndexedArgumentValue(0,
450                                        new LinkedHashSet<MockDefinition>());
451                        registry.registerBeanDefinition(BEAN_NAME, definition);
452                        return definition;
453                }
454                return registry.getBeanDefinition(BEAN_NAME);
455        }
456
457        /**
458         * {@link BeanPostProcessor} to handle {@link SpyBean} definitions. Registered as a
459         * separate processor so that it can be ordered above AOP post processors.
460         */
461        static class SpyPostProcessor extends InstantiationAwareBeanPostProcessorAdapter
462                        implements PriorityOrdered {
463
464                private static final String BEAN_NAME = SpyPostProcessor.class.getName();
465
466                private final MockitoPostProcessor mockitoPostProcessor;
467
468                SpyPostProcessor(MockitoPostProcessor mockitoPostProcessor) {
469                        this.mockitoPostProcessor = mockitoPostProcessor;
470                }
471
472                @Override
473                public int getOrder() {
474                        return Ordered.HIGHEST_PRECEDENCE;
475                }
476
477                @Override
478                public Object getEarlyBeanReference(Object bean, String beanName)
479                                throws BeansException {
480                        return this.mockitoPostProcessor.createSpyIfNecessary(bean, beanName);
481                }
482
483                @Override
484                public Object postProcessAfterInitialization(Object bean, String beanName)
485                                throws BeansException {
486                        if (bean instanceof FactoryBean) {
487                                return bean;
488                        }
489                        return this.mockitoPostProcessor.createSpyIfNecessary(bean, beanName);
490                }
491
492                public static void register(BeanDefinitionRegistry registry) {
493                        if (!registry.containsBeanDefinition(BEAN_NAME)) {
494                                RootBeanDefinition definition = new RootBeanDefinition(
495                                                SpyPostProcessor.class);
496                                definition.setRole(BeanDefinition.ROLE_INFRASTRUCTURE);
497                                ConstructorArgumentValues constructorArguments = definition
498                                                .getConstructorArgumentValues();
499                                constructorArguments.addIndexedArgumentValue(0,
500                                                new RuntimeBeanReference(MockitoPostProcessor.BEAN_NAME));
501                                registry.registerBeanDefinition(BEAN_NAME, definition);
502                        }
503                }
504
505        }
506
507}