001/*
002 * Copyright 2012-2017 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.boot.loader.tools;
018
019import java.io.BufferedInputStream;
020import java.io.File;
021import java.io.FileFilter;
022import java.io.FileInputStream;
023import java.io.IOException;
024import java.io.InputStream;
025import java.util.ArrayDeque;
026import java.util.ArrayList;
027import java.util.Arrays;
028import java.util.Collection;
029import java.util.Collections;
030import java.util.Comparator;
031import java.util.Deque;
032import java.util.Enumeration;
033import java.util.HashSet;
034import java.util.LinkedHashSet;
035import java.util.List;
036import java.util.Set;
037import java.util.jar.JarEntry;
038import java.util.jar.JarFile;
039
040import org.springframework.asm.AnnotationVisitor;
041import org.springframework.asm.ClassReader;
042import org.springframework.asm.ClassVisitor;
043import org.springframework.asm.MethodVisitor;
044import org.springframework.asm.Opcodes;
045import org.springframework.asm.Type;
046
047/**
048 * Finds any class with a {@code public static main} method by performing a breadth first
049 * search.
050 *
051 * @author Phillip Webb
052 * @author Andy Wilkinson
053 */
054public abstract class MainClassFinder {
055
056        private static final String DOT_CLASS = ".class";
057
058        private static final Type STRING_ARRAY_TYPE = Type.getType(String[].class);
059
060        private static final Type MAIN_METHOD_TYPE = Type.getMethodType(Type.VOID_TYPE,
061                        STRING_ARRAY_TYPE);
062
063        private static final String MAIN_METHOD_NAME = "main";
064
065        private static final FileFilter CLASS_FILE_FILTER = new FileFilter() {
066                @Override
067                public boolean accept(File file) {
068                        return (file.isFile() && file.getName().endsWith(DOT_CLASS));
069                }
070        };
071
072        private static final FileFilter PACKAGE_FOLDER_FILTER = new FileFilter() {
073                @Override
074                public boolean accept(File file) {
075                        return file.isDirectory() && !file.getName().startsWith(".");
076                }
077        };
078
079        /**
080         * Find the main class from a given folder.
081         * @param rootFolder the root folder to search
082         * @return the main class or {@code null}
083         * @throws IOException if the folder cannot be read
084         */
085        public static String findMainClass(File rootFolder) throws IOException {
086                return doWithMainClasses(rootFolder, new MainClassCallback<String>() {
087                        @Override
088                        public String doWith(MainClass mainClass) {
089                                return mainClass.getName();
090                        }
091                });
092        }
093
094        /**
095         * Find a single main class from the given {@code rootFolder}.
096         * @param rootFolder the root folder to search
097         * @return the main class or {@code null}
098         * @throws IOException if the folder cannot be read
099         */
100        public static String findSingleMainClass(File rootFolder) throws IOException {
101                return findSingleMainClass(rootFolder, null);
102        }
103
104        /**
105         * Find a single main class from the given {@code rootFolder}. A main class annotated
106         * with an annotation with the given {@code annotationName} will be preferred over a
107         * main class with no such annotation.
108         * @param rootFolder the root folder to search
109         * @param annotationName the name of the annotation that may be present on the main
110         * class
111         * @return the main class or {@code null}
112         * @throws IOException if the folder cannot be read
113         */
114        public static String findSingleMainClass(File rootFolder, String annotationName)
115                        throws IOException {
116                SingleMainClassCallback callback = new SingleMainClassCallback(annotationName);
117                MainClassFinder.doWithMainClasses(rootFolder, callback);
118                return callback.getMainClassName();
119        }
120
121        /**
122         * Find a single main class from the given {@code rootFolders}. A main class annotated
123         * with an annotation with the given {@code annotationName} will be preferred over a
124         * main class with no such annotation.
125         * @param rootFolders the root folders to search
126         * @param annotationName the name of the annotation that may be present on the main
127         * class
128         * @return the main class or {@code null}
129         * @throws IOException if a root folder cannot be read
130         * @since 1.5.5
131         */
132        public static String findSingleMainClass(Collection<File> rootFolders,
133                        String annotationName) throws IOException {
134                SingleMainClassCallback callback = new SingleMainClassCallback(annotationName);
135                doWithMainClasses(rootFolders, callback);
136                return callback.getMainClassName();
137        }
138
139        /**
140         * Perform the given callback operation on all main classes from the given root
141         * folder.
142         * @param <T> the result type
143         * @param rootFolders the root folders
144         * @param callback the callback
145         * @return the first callback result or {@code null}
146         * @throws IOException in case of I/O errors
147         */
148        static <T> T doWithMainClasses(Collection<File> rootFolders,
149                        MainClassCallback<T> callback) throws IOException {
150                for (File rootFolder : rootFolders) {
151                        T result = doWithMainClasses(rootFolder, callback);
152                        if (result != null) {
153                                return result;
154                        }
155                }
156                return null;
157        }
158
159        /**
160         * Perform the given callback operation on all main classes from the given root
161         * folder.
162         * @param <T> the result type
163         * @param rootFolder the root folder
164         * @param callback the callback
165         * @return the first callback result or {@code null}
166         * @throws IOException in case of I/O errors
167         */
168        static <T> T doWithMainClasses(File rootFolder, MainClassCallback<T> callback)
169                        throws IOException {
170                if (!rootFolder.exists()) {
171                        return null; // nothing to do
172                }
173                if (!rootFolder.isDirectory()) {
174                        throw new IllegalArgumentException(
175                                        "Invalid root folder '" + rootFolder + "'");
176                }
177                String prefix = rootFolder.getAbsolutePath() + "/";
178                Deque<File> stack = new ArrayDeque<File>();
179                stack.push(rootFolder);
180                while (!stack.isEmpty()) {
181                        File file = stack.pop();
182                        if (file.isFile()) {
183                                InputStream inputStream = new FileInputStream(file);
184                                try {
185                                        ClassDescriptor classDescriptor = createClassDescriptor(inputStream);
186                                        if (classDescriptor != null && classDescriptor.isMainMethodFound()) {
187                                                String className = convertToClassName(file.getAbsolutePath(),
188                                                                prefix);
189                                                T result = callback.doWith(new MainClass(className,
190                                                                classDescriptor.getAnnotationNames()));
191                                                if (result != null) {
192                                                        return result;
193                                                }
194                                        }
195                                }
196                                finally {
197                                        inputStream.close();
198                                }
199                        }
200                        if (file.isDirectory()) {
201                                pushAllSorted(stack, file.listFiles(PACKAGE_FOLDER_FILTER));
202                                pushAllSorted(stack, file.listFiles(CLASS_FILE_FILTER));
203                        }
204                }
205                return null;
206        }
207
208        private static void pushAllSorted(Deque<File> stack, File[] files) {
209                Arrays.sort(files, new Comparator<File>() {
210                        @Override
211                        public int compare(File o1, File o2) {
212                                return o1.getName().compareTo(o2.getName());
213                        }
214                });
215                for (File file : files) {
216                        stack.push(file);
217                }
218        }
219
220        /**
221         * Find the main class in a given jar file.
222         * @param jarFile the jar file to search
223         * @param classesLocation the location within the jar containing classes
224         * @return the main class or {@code null}
225         * @throws IOException if the jar file cannot be read
226         */
227        public static String findMainClass(JarFile jarFile, String classesLocation)
228                        throws IOException {
229                return doWithMainClasses(jarFile, classesLocation,
230                                new MainClassCallback<String>() {
231                                        @Override
232                                        public String doWith(MainClass mainClass) {
233                                                return mainClass.getName();
234                                        }
235                                });
236        }
237
238        /**
239         * Find a single main class in a given jar file.
240         * @param jarFile the jar file to search
241         * @param classesLocation the location within the jar containing classes
242         * @return the main class or {@code null}
243         * @throws IOException if the jar file cannot be read
244         */
245        public static String findSingleMainClass(JarFile jarFile, String classesLocation)
246                        throws IOException {
247                return findSingleMainClass(jarFile, classesLocation, null);
248        }
249
250        /**
251         * Find a single main class in a given jar file. A main class annotated with an
252         * annotation with the given {@code annotationName} will be preferred over a main
253         * class with no such annotation.
254         * @param jarFile the jar file to search
255         * @param classesLocation the location within the jar containing classes
256         * @param annotationName the name of the annotation that may be present on the main
257         * class
258         * @return the main class or {@code null}
259         * @throws IOException if the jar file cannot be read
260         */
261        public static String findSingleMainClass(JarFile jarFile, String classesLocation,
262                        String annotationName) throws IOException {
263                SingleMainClassCallback callback = new SingleMainClassCallback(annotationName);
264                MainClassFinder.doWithMainClasses(jarFile, classesLocation, callback);
265                return callback.getMainClassName();
266        }
267
268        /**
269         * Perform the given callback operation on all main classes from the given jar.
270         * @param <T> the result type
271         * @param jarFile the jar file to search
272         * @param classesLocation the location within the jar containing classes
273         * @param callback the callback
274         * @return the first callback result or {@code null}
275         * @throws IOException in case of I/O errors
276         */
277        static <T> T doWithMainClasses(JarFile jarFile, String classesLocation,
278                        MainClassCallback<T> callback) throws IOException {
279                List<JarEntry> classEntries = getClassEntries(jarFile, classesLocation);
280                Collections.sort(classEntries, new ClassEntryComparator());
281                for (JarEntry entry : classEntries) {
282                        InputStream inputStream = new BufferedInputStream(
283                                        jarFile.getInputStream(entry));
284                        try {
285                                ClassDescriptor classDescriptor = createClassDescriptor(inputStream);
286                                if (classDescriptor != null && classDescriptor.isMainMethodFound()) {
287                                        String className = convertToClassName(entry.getName(),
288                                                        classesLocation);
289                                        T result = callback.doWith(new MainClass(className,
290                                                        classDescriptor.getAnnotationNames()));
291                                        if (result != null) {
292                                                return result;
293                                        }
294                                }
295                        }
296                        finally {
297                                inputStream.close();
298                        }
299                }
300                return null;
301        }
302
303        private static String convertToClassName(String name, String prefix) {
304                name = name.replace('/', '.');
305                name = name.replace('\\', '.');
306                name = name.substring(0, name.length() - DOT_CLASS.length());
307                if (prefix != null) {
308                        name = name.substring(prefix.length());
309                }
310                return name;
311        }
312
313        private static List<JarEntry> getClassEntries(JarFile source,
314                        String classesLocation) {
315                classesLocation = (classesLocation != null ? classesLocation : "");
316                Enumeration<JarEntry> sourceEntries = source.entries();
317                List<JarEntry> classEntries = new ArrayList<JarEntry>();
318                while (sourceEntries.hasMoreElements()) {
319                        JarEntry entry = sourceEntries.nextElement();
320                        if (entry.getName().startsWith(classesLocation)
321                                        && entry.getName().endsWith(DOT_CLASS)) {
322                                classEntries.add(entry);
323                        }
324                }
325                return classEntries;
326        }
327
328        private static ClassDescriptor createClassDescriptor(InputStream inputStream) {
329                try {
330                        ClassReader classReader = new ClassReader(inputStream);
331                        ClassDescriptor classDescriptor = new ClassDescriptor();
332                        classReader.accept(classDescriptor, ClassReader.SKIP_CODE);
333                        return classDescriptor;
334                }
335                catch (IOException ex) {
336                        return null;
337                }
338        }
339
340        private static class ClassEntryComparator implements Comparator<JarEntry> {
341
342                @Override
343                public int compare(JarEntry o1, JarEntry o2) {
344                        Integer d1 = getDepth(o1);
345                        Integer d2 = getDepth(o2);
346                        int depthCompare = d1.compareTo(d2);
347                        if (depthCompare != 0) {
348                                return depthCompare;
349                        }
350                        return o1.getName().compareTo(o2.getName());
351                }
352
353                private int getDepth(JarEntry entry) {
354                        return entry.getName().split("/").length;
355                }
356
357        }
358
359        private static class ClassDescriptor extends ClassVisitor {
360
361                private final Set<String> annotationNames = new LinkedHashSet<String>();
362
363                private boolean mainMethodFound;
364
365                ClassDescriptor() {
366                        super(Opcodes.ASM4);
367                }
368
369                @Override
370                public AnnotationVisitor visitAnnotation(String desc, boolean visible) {
371                        this.annotationNames.add(Type.getType(desc).getClassName());
372                        return null;
373                }
374
375                @Override
376                public MethodVisitor visitMethod(int access, String name, String desc,
377                                String signature, String[] exceptions) {
378                        if (isAccess(access, Opcodes.ACC_PUBLIC, Opcodes.ACC_STATIC)
379                                        && MAIN_METHOD_NAME.equals(name)
380                                        && MAIN_METHOD_TYPE.getDescriptor().equals(desc)) {
381                                this.mainMethodFound = true;
382                        }
383                        return null;
384                }
385
386                private boolean isAccess(int access, int... requiredOpsCodes) {
387                        for (int requiredOpsCode : requiredOpsCodes) {
388                                if ((access & requiredOpsCode) == 0) {
389                                        return false;
390                                }
391                        }
392                        return true;
393                }
394
395                boolean isMainMethodFound() {
396                        return this.mainMethodFound;
397                }
398
399                Set<String> getAnnotationNames() {
400                        return this.annotationNames;
401                }
402
403        }
404
        /**
         * Callback for handling {@link MainClass MainClasses}. The
         * {@code doWithMainClasses} methods invoke it once for each class in which a main
         * method was found.
         *
         * @param <T> the callback's return type
         */
        interface MainClassCallback<T> {

                /**
                 * Handle the specified main class.
                 * @param mainClass the main class
                 * @return a non-null value if processing should end (the value becomes the
                 * result of the search) or {@code null} to continue with the next main class
                 */
                T doWith(MainClass mainClass);

        }
420
421        /**
422         * A class with a {@code main} method.
423         */
424        static final class MainClass {
425
426                private final String name;
427
428                private final Set<String> annotationNames;
429
430                /**
431                 * Creates a new {@code MainClass} rather represents the main class with the given
432                 * {@code name}. The class is annotated with the annotations with the given
433                 * {@code annotationNames}.
434                 * @param name the name of the class
435                 * @param annotationNames the names of the annotations on the class
436                 */
437                MainClass(String name, Set<String> annotationNames) {
438                        this.name = name;
439                        this.annotationNames = Collections
440                                        .unmodifiableSet(new HashSet<String>(annotationNames));
441                }
442
443                String getName() {
444                        return this.name;
445                }
446
447                Set<String> getAnnotationNames() {
448                        return this.annotationNames;
449                }
450
451                @Override
452                public String toString() {
453                        return this.name;
454                }
455
456                @Override
457                public int hashCode() {
458                        return this.name.hashCode();
459                }
460
461                @Override
462                public boolean equals(Object obj) {
463                        if (this == obj) {
464                                return true;
465                        }
466                        if (obj == null) {
467                                return false;
468                        }
469                        if (getClass() != obj.getClass()) {
470                                return false;
471                        }
472                        MainClass other = (MainClass) obj;
473                        if (!this.name.equals(other.name)) {
474                                return false;
475                        }
476                        return true;
477                }
478
479        }
480
481        /**
482         * Find a single main class, throwing an {@link IllegalStateException} if multiple
483         * candidates exist.
484         */
485        private static final class SingleMainClassCallback
486                        implements MainClassCallback<Object> {
487
488                private final Set<MainClass> mainClasses = new LinkedHashSet<MainClass>();
489
490                private final String annotationName;
491
492                private SingleMainClassCallback(String annotationName) {
493                        this.annotationName = annotationName;
494                }
495
496                @Override
497                public Object doWith(MainClass mainClass) {
498                        this.mainClasses.add(mainClass);
499                        return null;
500                }
501
502                private String getMainClassName() {
503                        Set<MainClass> matchingMainClasses = new LinkedHashSet<MainClass>();
504                        if (this.annotationName != null) {
505                                for (MainClass mainClass : this.mainClasses) {
506                                        if (mainClass.getAnnotationNames().contains(this.annotationName)) {
507                                                matchingMainClasses.add(mainClass);
508                                        }
509                                }
510                        }
511                        if (matchingMainClasses.isEmpty()) {
512                                matchingMainClasses.addAll(this.mainClasses);
513                        }
514                        if (matchingMainClasses.size() > 1) {
515                                throw new IllegalStateException(
516                                                "Unable to find a single main class from the following candidates "
517                                                                + matchingMainClasses);
518                        }
519                        return matchingMainClasses.isEmpty() ? null
520                                        : matchingMainClasses.iterator().next().getName();
521                }
522
523        }
524
525}