001/*
002 * Copyright 2012-2018 the original author or authors.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.springframework.boot.cli.compiler;
018
019import java.io.ByteArrayInputStream;
020import java.io.File;
021import java.io.IOException;
022import java.io.InputStream;
023import java.net.MalformedURLException;
024import java.net.URL;
025import java.net.URLClassLoader;
026import java.security.AccessController;
027import java.security.PrivilegedAction;
028import java.util.ArrayList;
029import java.util.Enumeration;
030import java.util.HashMap;
031import java.util.HashSet;
032import java.util.Map;
033import java.util.Set;
034
035import groovy.lang.GroovyClassLoader;
036import org.codehaus.groovy.ast.ClassNode;
037import org.codehaus.groovy.control.CompilationUnit;
038import org.codehaus.groovy.control.CompilerConfiguration;
039import org.codehaus.groovy.control.SourceUnit;
040
041import org.springframework.util.Assert;
042import org.springframework.util.FileCopyUtils;
043import org.springframework.util.StringUtils;
044
045/**
046 * Extension of the {@link GroovyClassLoader} with support for obtaining '.class' files as
047 * resources.
048 *
049 * @author Phillip Webb
050 * @author Dave Syer
051 */
052public class ExtendedGroovyClassLoader extends GroovyClassLoader {
053
054        private static final String SHARED_PACKAGE = "org.springframework.boot.groovy";
055
056        private static final URL[] NO_URLS = new URL[] {};
057
058        private final Map<String, byte[]> classResources = new HashMap<>();
059
060        private final GroovyCompilerScope scope;
061
062        private final CompilerConfiguration configuration;
063
064        public ExtendedGroovyClassLoader(GroovyCompilerScope scope) {
065                this(scope, createParentClassLoader(scope), new CompilerConfiguration());
066        }
067
068        private static ClassLoader createParentClassLoader(GroovyCompilerScope scope) {
069                ClassLoader classLoader = Thread.currentThread().getContextClassLoader();
070                if (scope == GroovyCompilerScope.DEFAULT) {
071                        classLoader = new DefaultScopeParentClassLoader(classLoader);
072                }
073                return classLoader;
074        }
075
076        private ExtendedGroovyClassLoader(GroovyCompilerScope scope, ClassLoader parent,
077                        CompilerConfiguration configuration) {
078                super(parent, configuration);
079                this.configuration = configuration;
080                this.scope = scope;
081        }
082
083        @Override
084        protected Class<?> findClass(String name) throws ClassNotFoundException {
085                try {
086                        return super.findClass(name);
087                }
088                catch (ClassNotFoundException ex) {
089                        if (this.scope == GroovyCompilerScope.DEFAULT
090                                        && name.startsWith(SHARED_PACKAGE)) {
091                                Class<?> sharedClass = findSharedClass(name);
092                                if (sharedClass != null) {
093                                        return sharedClass;
094                                }
095                        }
096                        throw ex;
097                }
098        }
099
100        private Class<?> findSharedClass(String name) {
101                try {
102                        String path = name.replace('.', '/').concat(".class");
103                        try (InputStream inputStream = getParent().getResourceAsStream(path)) {
104                                if (inputStream != null) {
105                                        return defineClass(name, FileCopyUtils.copyToByteArray(inputStream));
106                                }
107                        }
108                        return null;
109                }
110                catch (Exception ex) {
111                        return null;
112                }
113        }
114
115        @Override
116        public InputStream getResourceAsStream(String name) {
117                InputStream resourceStream = super.getResourceAsStream(name);
118                if (resourceStream == null) {
119                        byte[] bytes = this.classResources.get(name);
120                        resourceStream = (bytes != null) ? new ByteArrayInputStream(bytes) : null;
121                }
122                return resourceStream;
123        }
124
125        @Override
126        public ClassCollector createCollector(CompilationUnit unit, SourceUnit su) {
127                InnerLoader loader = AccessController.doPrivileged(getInnerLoader());
128                return new ExtendedClassCollector(loader, unit, su);
129        }
130
131        private PrivilegedAction<InnerLoader> getInnerLoader() {
132                return () -> new InnerLoader(ExtendedGroovyClassLoader.this) {
133
134                        // Don't return URLs from the inner loader so that Tomcat only
135                        // searches the parent. Fixes 'TLD skipped' issues
136                        @Override
137                        public URL[] getURLs() {
138                                return NO_URLS;
139                        }
140
141                };
142        }
143
144        public CompilerConfiguration getConfiguration() {
145                return this.configuration;
146        }
147
148        /**
149         * Inner collector class used to track as classes are added.
150         */
151        protected class ExtendedClassCollector extends ClassCollector {
152
153                protected ExtendedClassCollector(InnerLoader loader, CompilationUnit unit,
154                                SourceUnit su) {
155                        super(loader, unit, su);
156                }
157
158                @Override
159                protected Class<?> createClass(byte[] code, ClassNode classNode) {
160                        Class<?> createdClass = super.createClass(code, classNode);
161                        ExtendedGroovyClassLoader.this.classResources
162                                        .put(classNode.getName().replace('.', '/') + ".class", code);
163                        return createdClass;
164                }
165
166        }
167
168        /**
169         * ClassLoader used for a parent that filters so that only classes from groovy-all.jar
170         * are exposed.
171         */
172        private static class DefaultScopeParentClassLoader extends ClassLoader {
173
174                private static final String[] GROOVY_JARS_PREFIXES = { "groovy", "antlr", "asm" };
175
176                private final URLClassLoader groovyOnlyClassLoader;
177
178                DefaultScopeParentClassLoader(ClassLoader parent) {
179                        super(parent);
180                        this.groovyOnlyClassLoader = new URLClassLoader(getGroovyJars(parent),
181                                        getClass().getClassLoader().getParent());
182                }
183
184                private URL[] getGroovyJars(ClassLoader parent) {
185                        Set<URL> urls = new HashSet<>();
186                        findGroovyJarsDirectly(parent, urls);
187                        if (urls.isEmpty()) {
188                                findGroovyJarsFromClassPath(urls);
189                        }
190                        Assert.state(!urls.isEmpty(), "Unable to find groovy JAR");
191                        return new ArrayList<>(urls).toArray(new URL[0]);
192                }
193
194                private void findGroovyJarsDirectly(ClassLoader classLoader, Set<URL> urls) {
195                        while (classLoader != null) {
196                                if (classLoader instanceof URLClassLoader) {
197                                        for (URL url : ((URLClassLoader) classLoader).getURLs()) {
198                                                if (isGroovyJar(url.toString())) {
199                                                        urls.add(url);
200                                                }
201                                        }
202                                }
203                                classLoader = classLoader.getParent();
204                        }
205                }
206
207                private void findGroovyJarsFromClassPath(Set<URL> urls) {
208                        String classpath = System.getProperty("java.class.path");
209                        String[] entries = classpath.split(System.getProperty("path.separator"));
210                        for (String entry : entries) {
211                                if (isGroovyJar(entry)) {
212                                        File file = new File(entry);
213                                        if (file.canRead()) {
214                                                try {
215                                                        urls.add(file.toURI().toURL());
216                                                }
217                                                catch (MalformedURLException ex) {
218                                                        // Swallow and continue
219                                                }
220                                        }
221                                }
222                        }
223                }
224
225                private boolean isGroovyJar(String entry) {
226                        entry = StringUtils.cleanPath(entry);
227                        for (String jarPrefix : GROOVY_JARS_PREFIXES) {
228                                if (entry.contains("/" + jarPrefix + "-")) {
229                                        return true;
230                                }
231                        }
232                        return false;
233                }
234
235                @Override
236                public Enumeration<URL> getResources(String name) throws IOException {
237                        return this.groovyOnlyClassLoader.getResources(name);
238                }
239
240                @Override
241                protected Class<?> loadClass(String name, boolean resolve)
242                                throws ClassNotFoundException {
243                        if (!name.startsWith("java.")) {
244                                this.groovyOnlyClassLoader.loadClass(name);
245                        }
246                        return super.loadClass(name, resolve);
247                }
248
249        }
250
251}