001/*
002 * Copyright 2012-2018 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.boot.test.mock.mockito;
018
019import java.util.Arrays;
020import java.util.HashSet;
021import java.util.Set;
022
023import org.mockito.Mockito;
024
025import org.springframework.beans.factory.NoSuchBeanDefinitionException;
026import org.springframework.beans.factory.config.BeanDefinition;
027import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
028import org.springframework.context.ApplicationContext;
029import org.springframework.context.ConfigurableApplicationContext;
030import org.springframework.core.Ordered;
031import org.springframework.test.context.TestContext;
032import org.springframework.test.context.TestExecutionListener;
033import org.springframework.test.context.support.AbstractTestExecutionListener;
034import org.springframework.util.ClassUtils;
035
036/**
037 * {@link TestExecutionListener} to reset any mock beans that have been marked with a
038 * {@link MockReset}.
039 *
040 * @author Phillip Webb
041 * @since 1.4.0
042 */
043public class ResetMocksTestExecutionListener extends AbstractTestExecutionListener {
044
045        private static final boolean MOCKITO_IS_PRESENT = ClassUtils.isPresent(
046                        "org.mockito.MockSettings",
047                        ResetMocksTestExecutionListener.class.getClassLoader());
048
049        @Override
050        public int getOrder() {
051                return Ordered.LOWEST_PRECEDENCE - 100;
052        }
053
054        @Override
055        public void beforeTestMethod(TestContext testContext) throws Exception {
056                if (MOCKITO_IS_PRESENT) {
057                        resetMocks(testContext.getApplicationContext(), MockReset.BEFORE);
058                }
059        }
060
061        @Override
062        public void afterTestMethod(TestContext testContext) throws Exception {
063                if (MOCKITO_IS_PRESENT) {
064                        resetMocks(testContext.getApplicationContext(), MockReset.AFTER);
065                }
066        }
067
068        private void resetMocks(ApplicationContext applicationContext, MockReset reset) {
069                if (applicationContext instanceof ConfigurableApplicationContext) {
070                        resetMocks((ConfigurableApplicationContext) applicationContext, reset);
071                }
072        }
073
074        private void resetMocks(ConfigurableApplicationContext applicationContext,
075                        MockReset reset) {
076                ConfigurableListableBeanFactory beanFactory = applicationContext.getBeanFactory();
077                String[] names = beanFactory.getBeanDefinitionNames();
078                Set<String> instantiatedSingletons = new HashSet<>(
079                                Arrays.asList(beanFactory.getSingletonNames()));
080                for (String name : names) {
081                        BeanDefinition definition = beanFactory.getBeanDefinition(name);
082                        if (definition.isSingleton() && instantiatedSingletons.contains(name)) {
083                                Object bean = beanFactory.getSingleton(name);
084                                if (reset.equals(MockReset.get(bean))) {
085                                        Mockito.reset(bean);
086                                }
087                        }
088                }
089                try {
090                        MockitoBeans mockedBeans = beanFactory.getBean(MockitoBeans.class);
091                        for (Object mockedBean : mockedBeans) {
092                                if (reset.equals(MockReset.get(mockedBean))) {
093                                        Mockito.reset(mockedBean);
094                                }
095                        }
096                }
097                catch (NoSuchBeanDefinitionException ex) {
098                        // Continue
099                }
100                if (applicationContext.getParent() != null) {
101                        resetMocks(applicationContext.getParent(), reset);
102                }
103        }
104
105}