001/* 002 * Copyright 2012-2018 the original author or authors. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.springframework.boot.test.mock.mockito; 018 019import java.beans.PropertyDescriptor; 020import java.lang.reflect.Field; 021import java.util.Arrays; 022import java.util.Collection; 023import java.util.HashMap; 024import java.util.LinkedHashMap; 025import java.util.LinkedHashSet; 026import java.util.Map; 027import java.util.Set; 028import java.util.TreeSet; 029 030import org.springframework.aop.scope.ScopedProxyUtils; 031import org.springframework.beans.BeansException; 032import org.springframework.beans.PropertyValues; 033import org.springframework.beans.factory.BeanClassLoaderAware; 034import org.springframework.beans.factory.BeanCreationException; 035import org.springframework.beans.factory.BeanFactory; 036import org.springframework.beans.factory.BeanFactoryAware; 037import org.springframework.beans.factory.BeanFactoryUtils; 038import org.springframework.beans.factory.FactoryBean; 039import org.springframework.beans.factory.NoUniqueBeanDefinitionException; 040import org.springframework.beans.factory.config.BeanDefinition; 041import org.springframework.beans.factory.config.BeanFactoryPostProcessor; 042import org.springframework.beans.factory.config.BeanPostProcessor; 043import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; 044import org.springframework.beans.factory.config.ConstructorArgumentValues; 045import org.springframework.beans.factory.config.ConstructorArgumentValues.ValueHolder; 046import org.springframework.beans.factory.config.InstantiationAwareBeanPostProcessorAdapter; 047import org.springframework.beans.factory.config.RuntimeBeanReference; 048import org.springframework.beans.factory.support.BeanDefinitionRegistry; 049import org.springframework.beans.factory.support.BeanNameGenerator; 050import org.springframework.beans.factory.support.DefaultBeanNameGenerator; 051import org.springframework.beans.factory.support.RootBeanDefinition; 052import org.springframework.context.ApplicationContext; 053import org.springframework.context.annotation.ConfigurationClassPostProcessor; 054import org.springframework.core.Conventions; 055import org.springframework.core.Ordered; 056import org.springframework.core.PriorityOrdered; 057import org.springframework.core.ResolvableType; 058import org.springframework.test.context.junit4.SpringRunner; 059import org.springframework.util.Assert; 060import org.springframework.util.ClassUtils; 061import org.springframework.util.ObjectUtils; 062import org.springframework.util.ReflectionUtils; 063import org.springframework.util.StringUtils; 064 065/** 066 * A {@link BeanFactoryPostProcessor} used to register and inject 067 * {@link MockBean @MockBeans} with the {@link ApplicationContext}. An initial set of 068 * definitions can be passed to the processor with additional definitions being 069 * automatically created from {@code @Configuration} classes that use 070 * {@link MockBean @MockBean}. 071 * 072 * @author Phillip Webb 073 * @author Andy Wilkinson 074 * @author Stephane Nicoll 075 * @author Andreas Neiser 076 * @since 1.4.0 077 */ 078public class MockitoPostProcessor extends InstantiationAwareBeanPostProcessorAdapter 079 implements BeanClassLoaderAware, BeanFactoryAware, BeanFactoryPostProcessor, 080 Ordered { 081 082 private static final String FACTORY_BEAN_OBJECT_TYPE = "factoryBeanObjectType"; 083 084 private static final String BEAN_NAME = MockitoPostProcessor.class.getName(); 085 086 private static final String CONFIGURATION_CLASS_ATTRIBUTE = Conventions 087 .getQualifiedAttributeName(ConfigurationClassPostProcessor.class, 088 "configurationClass"); 089 090 private static final BeanNameGenerator beanNameGenerator = new DefaultBeanNameGenerator(); 091 092 private final Set<Definition> definitions; 093 094 private ClassLoader classLoader; 095 096 private BeanFactory beanFactory; 097 098 private final MockitoBeans mockitoBeans = new MockitoBeans(); 099 100 private Map<Definition, String> beanNameRegistry = new HashMap<>(); 101 102 private Map<Field, String> fieldRegistry = new HashMap<>(); 103 104 private Map<String, SpyDefinition> spies = new HashMap<>(); 105 106 /** 107 * Create a new {@link MockitoPostProcessor} instance with the given initial 108 * definitions. 109 * @param definitions the initial definitions 110 */ 111 public MockitoPostProcessor(Set<Definition> definitions) { 112 this.definitions = definitions; 113 } 114 115 @Override 116 public void setBeanClassLoader(ClassLoader classLoader) { 117 this.classLoader = classLoader; 118 } 119 120 @Override 121 public void setBeanFactory(BeanFactory beanFactory) throws BeansException { 122 Assert.isInstanceOf(ConfigurableListableBeanFactory.class, beanFactory, 123 "Mock beans can only be used with a ConfigurableListableBeanFactory"); 124 this.beanFactory = beanFactory; 125 } 126 127 @Override 128 public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) 129 throws BeansException { 130 Assert.isInstanceOf(BeanDefinitionRegistry.class, beanFactory, 131 "@MockBean can only be used on bean factories that " 132 + "implement BeanDefinitionRegistry"); 133 postProcessBeanFactory(beanFactory, (BeanDefinitionRegistry) beanFactory); 134 } 135 136 private void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory, 137 BeanDefinitionRegistry registry) { 138 beanFactory.registerSingleton(MockitoBeans.class.getName(), this.mockitoBeans); 139 DefinitionsParser parser = new DefinitionsParser(this.definitions); 140 for (Class<?> configurationClass : getConfigurationClasses(beanFactory)) { 141 parser.parse(configurationClass); 142 } 143 Set<Definition> definitions = parser.getDefinitions(); 144 for (Definition definition : definitions) { 145 Field field = parser.getField(definition); 146 register(beanFactory, registry, definition, field); 147 } 148 } 149 150 private Set<Class<?>> getConfigurationClasses( 151 ConfigurableListableBeanFactory beanFactory) { 152 Set<Class<?>> configurationClasses = new LinkedHashSet<>(); 153 for (BeanDefinition beanDefinition : getConfigurationBeanDefinitions(beanFactory) 154 .values()) { 155 configurationClasses.add(ClassUtils.resolveClassName( 156 beanDefinition.getBeanClassName(), this.classLoader)); 157 } 158 return configurationClasses; 159 } 160 161 private Map<String, BeanDefinition> getConfigurationBeanDefinitions( 162 ConfigurableListableBeanFactory beanFactory) { 163 Map<String, BeanDefinition> definitions = new LinkedHashMap<>(); 164 for (String beanName : beanFactory.getBeanDefinitionNames()) { 165 BeanDefinition definition = beanFactory.getBeanDefinition(beanName); 166 if (definition.getAttribute(CONFIGURATION_CLASS_ATTRIBUTE) != null) { 167 definitions.put(beanName, definition); 168 } 169 } 170 return definitions; 171 } 172 173 private void register(ConfigurableListableBeanFactory beanFactory, 174 BeanDefinitionRegistry registry, Definition definition, Field field) { 175 if (definition instanceof MockDefinition) { 176 registerMock(beanFactory, registry, (MockDefinition) definition, field); 177 } 178 else if (definition instanceof SpyDefinition) { 179 registerSpy(beanFactory, registry, (SpyDefinition) definition, field); 180 } 181 } 182 183 private void registerMock(ConfigurableListableBeanFactory beanFactory, 184 BeanDefinitionRegistry registry, MockDefinition definition, Field field) { 185 RootBeanDefinition beanDefinition = createBeanDefinition(definition); 186 String beanName = getBeanName(beanFactory, registry, definition, beanDefinition); 187 String transformedBeanName = BeanFactoryUtils.transformedBeanName(beanName); 188 if (registry.containsBeanDefinition(transformedBeanName)) { 189 BeanDefinition existing = registry.getBeanDefinition(transformedBeanName); 190 copyBeanDefinitionDetails(existing, beanDefinition); 191 registry.removeBeanDefinition(transformedBeanName); 192 } 193 registry.registerBeanDefinition(transformedBeanName, beanDefinition); 194 Object mock = definition.createMock(beanName + " bean"); 195 beanFactory.registerSingleton(transformedBeanName, mock); 196 this.mockitoBeans.add(mock); 197 this.beanNameRegistry.put(definition, beanName); 198 if (field != null) { 199 this.fieldRegistry.put(field, beanName); 200 } 201 } 202 203 private RootBeanDefinition createBeanDefinition(MockDefinition mockDefinition) { 204 RootBeanDefinition definition = new RootBeanDefinition( 205 mockDefinition.getTypeToMock().resolve()); 206 definition.setTargetType(mockDefinition.getTypeToMock()); 207 if (mockDefinition.getQualifier() != null) { 208 mockDefinition.getQualifier().applyTo(definition); 209 } 210 return definition; 211 } 212 213 private String getBeanName(ConfigurableListableBeanFactory beanFactory, 214 BeanDefinitionRegistry registry, MockDefinition mockDefinition, 215 RootBeanDefinition beanDefinition) { 216 if (StringUtils.hasLength(mockDefinition.getName())) { 217 return mockDefinition.getName(); 218 } 219 Set<String> existingBeans = getExistingBeans(beanFactory, 220 mockDefinition.getTypeToMock(), mockDefinition.getQualifier()); 221 if (existingBeans.isEmpty()) { 222 return MockitoPostProcessor.beanNameGenerator.generateBeanName(beanDefinition, 223 registry); 224 } 225 if (existingBeans.size() == 1) { 226 return existingBeans.iterator().next(); 227 } 228 String primaryCandidate = determinePrimaryCandidate(registry, existingBeans, 229 mockDefinition.getTypeToMock()); 230 if (primaryCandidate != null) { 231 return primaryCandidate; 232 } 233 throw new IllegalStateException( 234 "Unable to register mock bean " + mockDefinition.getTypeToMock() 235 + " expected a single matching bean to replace but found " 236 + existingBeans); 237 } 238 239 private void copyBeanDefinitionDetails(BeanDefinition from, RootBeanDefinition to) { 240 to.setPrimary(from.isPrimary()); 241 } 242 243 private void registerSpy(ConfigurableListableBeanFactory beanFactory, 244 BeanDefinitionRegistry registry, SpyDefinition spyDefinition, Field field) { 245 Set<String> existingBeans = getExistingBeans(beanFactory, 246 spyDefinition.getTypeToSpy(), spyDefinition.getQualifier()); 247 if (ObjectUtils.isEmpty(existingBeans)) { 248 createSpy(registry, spyDefinition, field); 249 } 250 else { 251 registerSpies(registry, spyDefinition, field, existingBeans); 252 } 253 } 254 255 private Set<String> getExistingBeans(ConfigurableListableBeanFactory beanFactory, 256 ResolvableType type, QualifierDefinition qualifier) { 257 Set<String> candidates = new TreeSet<>(); 258 for (String candidate : getExistingBeans(beanFactory, type)) { 259 if (qualifier == null || qualifier.matches(beanFactory, candidate)) { 260 candidates.add(candidate); 261 } 262 } 263 return candidates; 264 } 265 266 private Set<String> getExistingBeans(ConfigurableListableBeanFactory beanFactory, 267 ResolvableType type) { 268 Set<String> beans = new LinkedHashSet<>( 269 Arrays.asList(beanFactory.getBeanNamesForType(type))); 270 String typeName = type.resolve(Object.class).getName(); 271 for (String beanName : beanFactory.getBeanNamesForType(FactoryBean.class)) { 272 beanName = BeanFactoryUtils.transformedBeanName(beanName); 273 BeanDefinition beanDefinition = beanFactory.getBeanDefinition(beanName); 274 if (typeName.equals(beanDefinition.getAttribute(FACTORY_BEAN_OBJECT_TYPE))) { 275 beans.add(beanName); 276 } 277 } 278 beans.removeIf(this::isScopedTarget); 279 return beans; 280 } 281 282 private boolean isScopedTarget(String beanName) { 283 try { 284 return ScopedProxyUtils.isScopedTarget(beanName); 285 } 286 catch (Throwable ex) { 287 return false; 288 } 289 } 290 291 private void createSpy(BeanDefinitionRegistry registry, SpyDefinition spyDefinition, 292 Field field) { 293 RootBeanDefinition beanDefinition = new RootBeanDefinition( 294 spyDefinition.getTypeToSpy().resolve()); 295 String beanName = MockitoPostProcessor.beanNameGenerator 296 .generateBeanName(beanDefinition, registry); 297 registry.registerBeanDefinition(beanName, beanDefinition); 298 registerSpy(spyDefinition, field, beanName); 299 } 300 301 private void registerSpies(BeanDefinitionRegistry registry, 302 SpyDefinition spyDefinition, Field field, Collection<String> existingBeans) { 303 try { 304 String beanName = determineBeanName(existingBeans, spyDefinition, registry); 305 registerSpy(spyDefinition, field, beanName); 306 } 307 catch (RuntimeException ex) { 308 throw new IllegalStateException( 309 "Unable to register spy bean " + spyDefinition.getTypeToSpy(), ex); 310 } 311 } 312 313 private String determineBeanName(Collection<String> existingBeans, 314 SpyDefinition definition, BeanDefinitionRegistry registry) { 315 if (StringUtils.hasText(definition.getName())) { 316 return definition.getName(); 317 } 318 if (existingBeans.size() == 1) { 319 return existingBeans.iterator().next(); 320 } 321 return determinePrimaryCandidate(registry, existingBeans, 322 definition.getTypeToSpy()); 323 } 324 325 private String determinePrimaryCandidate(BeanDefinitionRegistry registry, 326 Collection<String> candidateBeanNames, ResolvableType type) { 327 String primaryBeanName = null; 328 for (String candidateBeanName : candidateBeanNames) { 329 BeanDefinition beanDefinition = registry.getBeanDefinition(candidateBeanName); 330 if (beanDefinition.isPrimary()) { 331 if (primaryBeanName != null) { 332 throw new NoUniqueBeanDefinitionException(type.resolve(), 333 candidateBeanNames.size(), 334 "more than one 'primary' bean found among candidates: " 335 + Arrays.asList(candidateBeanNames)); 336 } 337 primaryBeanName = candidateBeanName; 338 } 339 } 340 return primaryBeanName; 341 } 342 343 private void registerSpy(SpyDefinition definition, Field field, String beanName) { 344 this.spies.put(beanName, definition); 345 this.beanNameRegistry.put(definition, beanName); 346 if (field != null) { 347 this.fieldRegistry.put(field, beanName); 348 } 349 } 350 351 protected final Object createSpyIfNecessary(Object bean, String beanName) 352 throws BeansException { 353 SpyDefinition definition = this.spies.get(beanName); 354 if (definition != null) { 355 bean = definition.createSpy(beanName, bean); 356 } 357 return bean; 358 } 359 360 @Override 361 public PropertyValues postProcessPropertyValues(PropertyValues pvs, 362 PropertyDescriptor[] pds, final Object bean, String beanName) 363 throws BeansException { 364 ReflectionUtils.doWithFields(bean.getClass(), 365 (field) -> postProcessField(bean, field)); 366 return pvs; 367 } 368 369 private void postProcessField(Object bean, Field field) { 370 String beanName = this.fieldRegistry.get(field); 371 if (StringUtils.hasText(beanName)) { 372 inject(field, bean, beanName); 373 } 374 } 375 376 void inject(Field field, Object target, Definition definition) { 377 String beanName = this.beanNameRegistry.get(definition); 378 Assert.state(StringUtils.hasLength(beanName), 379 () -> "No bean found for definition " + definition); 380 inject(field, target, beanName); 381 } 382 383 private void inject(Field field, Object target, String beanName) { 384 try { 385 field.setAccessible(true); 386 Assert.state(ReflectionUtils.getField(field, target) == null, 387 () -> "The field " + field + " cannot have an existing value"); 388 Object bean = this.beanFactory.getBean(beanName, field.getType()); 389 ReflectionUtils.setField(field, target, bean); 390 } 391 catch (Throwable ex) { 392 throw new BeanCreationException("Could not inject field: " + field, ex); 393 } 394 } 395 396 @Override 397 public int getOrder() { 398 return Ordered.LOWEST_PRECEDENCE - 10; 399 } 400 401 /** 402 * Register the processor with a {@link BeanDefinitionRegistry}. Not required when 403 * using the {@link SpringRunner} as registration is automatic. 404 * @param registry the bean definition registry 405 */ 406 public static void register(BeanDefinitionRegistry registry) { 407 register(registry, null); 408 } 409 410 /** 411 * Register the processor with a {@link BeanDefinitionRegistry}. Not required when 412 * using the {@link SpringRunner} as registration is automatic. 413 * @param registry the bean definition registry 414 * @param definitions the initial mock/spy definitions 415 */ 416 public static void register(BeanDefinitionRegistry registry, 417 Set<Definition> definitions) { 418 register(registry, MockitoPostProcessor.class, definitions); 419 } 420 421 /** 422 * Register the processor with a {@link BeanDefinitionRegistry}. Not required when 423 * using the {@link SpringRunner} as registration is automatic. 424 * @param registry the bean definition registry 425 * @param postProcessor the post processor class to register 426 * @param definitions the initial mock/spy definitions 427 */ 428 @SuppressWarnings("unchecked") 429 public static void register(BeanDefinitionRegistry registry, 430 Class<? extends MockitoPostProcessor> postProcessor, 431 Set<Definition> definitions) { 432 SpyPostProcessor.register(registry); 433 BeanDefinition definition = getOrAddBeanDefinition(registry, postProcessor); 434 ValueHolder constructorArg = definition.getConstructorArgumentValues() 435 .getIndexedArgumentValue(0, Set.class); 436 Set<Definition> existing = (Set<Definition>) constructorArg.getValue(); 437 if (definitions != null) { 438 existing.addAll(definitions); 439 } 440 } 441 442 private static BeanDefinition getOrAddBeanDefinition(BeanDefinitionRegistry registry, 443 Class<? extends MockitoPostProcessor> postProcessor) { 444 if (!registry.containsBeanDefinition(BEAN_NAME)) { 445 RootBeanDefinition definition = new RootBeanDefinition(postProcessor); 446 definition.setRole(BeanDefinition.ROLE_INFRASTRUCTURE); 447 ConstructorArgumentValues constructorArguments = definition 448 .getConstructorArgumentValues(); 449 constructorArguments.addIndexedArgumentValue(0, 450 new LinkedHashSet<MockDefinition>()); 451 registry.registerBeanDefinition(BEAN_NAME, definition); 452 return definition; 453 } 454 return registry.getBeanDefinition(BEAN_NAME); 455 } 456 457 /** 458 * {@link BeanPostProcessor} to handle {@link SpyBean} definitions. Registered as a 459 * separate processor so that it can be ordered above AOP post processors. 460 */ 461 static class SpyPostProcessor extends InstantiationAwareBeanPostProcessorAdapter 462 implements PriorityOrdered { 463 464 private static final String BEAN_NAME = SpyPostProcessor.class.getName(); 465 466 private final MockitoPostProcessor mockitoPostProcessor; 467 468 SpyPostProcessor(MockitoPostProcessor mockitoPostProcessor) { 469 this.mockitoPostProcessor = mockitoPostProcessor; 470 } 471 472 @Override 473 public int getOrder() { 474 return Ordered.HIGHEST_PRECEDENCE; 475 } 476 477 @Override 478 public Object getEarlyBeanReference(Object bean, String beanName) 479 throws BeansException { 480 return this.mockitoPostProcessor.createSpyIfNecessary(bean, beanName); 481 } 482 483 @Override 484 public Object postProcessAfterInitialization(Object bean, String beanName) 485 throws BeansException { 486 if (bean instanceof FactoryBean) { 487 return bean; 488 } 489 return this.mockitoPostProcessor.createSpyIfNecessary(bean, beanName); 490 } 491 492 public static void register(BeanDefinitionRegistry registry) { 493 if (!registry.containsBeanDefinition(BEAN_NAME)) { 494 RootBeanDefinition definition = new RootBeanDefinition( 495 SpyPostProcessor.class); 496 definition.setRole(BeanDefinition.ROLE_INFRASTRUCTURE); 497 ConstructorArgumentValues constructorArguments = definition 498 .getConstructorArgumentValues(); 499 constructorArguments.addIndexedArgumentValue(0, 500 new RuntimeBeanReference(MockitoPostProcessor.BEAN_NAME)); 501 registry.registerBeanDefinition(BEAN_NAME, definition); 502 } 503 } 504 505 } 506 507}