001/*
002 * Copyright 2012-2018 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.boot.test.context;
018
019import java.lang.annotation.Annotation;
020import java.util.Collections;
021import java.util.LinkedHashMap;
022import java.util.Map;
023import java.util.Set;
024
025import org.springframework.beans.factory.config.BeanDefinition;
026import org.springframework.context.annotation.ClassPathScanningCandidateComponentProvider;
027import org.springframework.core.type.filter.AnnotationTypeFilter;
028import org.springframework.util.Assert;
029import org.springframework.util.ClassUtils;
030
031/**
032 * Utility class to find a class annotated with a particular annotation in a hierarchy.
033 *
034 * @author Phillip Webb
035 * @author Artsiom Yudovin
036 * @author Stephane Nicoll
037 * @since 2.1.0
038 */
039public final class AnnotatedClassFinder {
040
041        private static final Map<String, Class<?>> cache = Collections
042                        .synchronizedMap(new Cache(40));
043
044        private final Class<? extends Annotation> annotationType;
045
046        private final ClassPathScanningCandidateComponentProvider scanner;
047
048        /**
049         * Create a new instance with the {@code annotationType} to find.
050         * @param annotationType the annotation to find
051         */
052        public AnnotatedClassFinder(Class<? extends Annotation> annotationType) {
053                Assert.notNull(annotationType, "AnnotationType must not be null");
054                this.annotationType = annotationType;
055                this.scanner = new ClassPathScanningCandidateComponentProvider(false);
056                this.scanner.addIncludeFilter(new AnnotationTypeFilter(annotationType));
057                this.scanner.setResourcePattern("*.class");
058        }
059
060        /**
061         * Find the first {@link Class} that is annotated with the target annotation, starting
062         * from the package defined by the given {@code source} up to the root.
063         * @param source the source class to use to initiate the search
064         * @return the first {@link Class} annotated with the target annotation within the
065         * hierarchy defined by the given {@code source} or {@code null} if none is found.
066         */
067        public Class<?> findFromClass(Class<?> source) {
068                Assert.notNull(source, "Source must not be null");
069                return findFromPackage(ClassUtils.getPackageName(source));
070        }
071
072        /**
073         * Find the first {@link Class} that is annotated with the target annotation, starting
074         * from the package defined by the given {@code source} up to the root.
075         * @param source the source package to use to initiate the search
076         * @return the first {@link Class} annotated with the target annotation within the
077         * hierarchy defined by the given {@code source} or {@code null} if none is found.
078         */
079        public Class<?> findFromPackage(String source) {
080                Assert.notNull(source, "Source must not be null");
081                Class<?> configuration = cache.get(source);
082                if (configuration == null) {
083                        configuration = scanPackage(source);
084                        cache.put(source, configuration);
085                }
086                return configuration;
087        }
088
089        private Class<?> scanPackage(String source) {
090                while (!source.isEmpty()) {
091                        Set<BeanDefinition> components = this.scanner.findCandidateComponents(source);
092                        if (!components.isEmpty()) {
093                                Assert.state(components.size() == 1,
094                                                () -> "Found multiple @" + this.annotationType.getSimpleName()
095                                                                + " annotated classes " + components);
096                                return ClassUtils.resolveClassName(
097                                                components.iterator().next().getBeanClassName(), null);
098                        }
099                        source = getParentPackage(source);
100                }
101                return null;
102        }
103
104        private String getParentPackage(String sourcePackage) {
105                int lastDot = sourcePackage.lastIndexOf('.');
106                return (lastDot != -1) ? sourcePackage.substring(0, lastDot) : "";
107        }
108
109        /**
110         * Cache implementation based on {@link LinkedHashMap}.
111         */
112        private static class Cache extends LinkedHashMap<String, Class<?>> {
113
114                private final int maxSize;
115
116                Cache(int maxSize) {
117                        super(16, 0.75f, true);
118                        this.maxSize = maxSize;
119                }
120
121                @Override
122                protected boolean removeEldestEntry(Map.Entry<String, Class<?>> eldest) {
123                        return size() > this.maxSize;
124                }
125
126        }
127
128}