Home | History | Annotate | Download | only in mockitoutil
      1 /*
      2  * Copyright (c) 2017 Mockito contributors
      3  * This program is made available under the terms of the MIT License.
      4  */
      5 package org.mockitoutil;
      6 
      7 import java.io.ByteArrayInputStream;
      8 import java.io.File;
      9 import java.io.IOException;
     10 import java.io.InputStream;
     11 import java.lang.reflect.Field;
     12 import java.lang.reflect.Modifier;
     13 import java.net.MalformedURLException;
     14 import java.net.URI;
     15 import java.net.URISyntaxException;
     16 import java.net.URL;
     17 import java.net.URLClassLoader;
     18 import java.net.URLConnection;
     19 import java.net.URLStreamHandler;
     20 import java.util.ArrayList;
     21 import java.util.Arrays;
     22 import java.util.Collections;
     23 import java.util.Enumeration;
     24 import java.util.HashMap;
     25 import java.util.HashSet;
     26 import java.util.Iterator;
     27 import java.util.List;
     28 import java.util.Map;
     29 import java.util.Set;
     30 import java.util.concurrent.ExecutionException;
     31 import java.util.concurrent.ExecutorService;
     32 import java.util.concurrent.Executors;
     33 import java.util.concurrent.Future;
     34 import java.util.concurrent.ThreadFactory;
     35 import org.objenesis.Objenesis;
     36 import org.objenesis.ObjenesisStd;
     37 import org.objenesis.instantiator.ObjectInstantiator;
     38 
     39 import static java.lang.String.format;
     40 import static java.util.Arrays.asList;
     41 
     42 public abstract class ClassLoaders {
     43     protected ClassLoader parent = currentClassLoader();
     44 
     45     protected ClassLoaders() {
     46     }
     47 
     48     public static IsolatedURLClassLoaderBuilder isolatedClassLoader() {
     49         return new IsolatedURLClassLoaderBuilder();
     50     }
     51 
     52     public static ExcludingURLClassLoaderBuilder excludingClassLoader() {
     53         return new ExcludingURLClassLoaderBuilder();
     54     }
     55 
     56     public static InMemoryClassLoaderBuilder inMemoryClassLoader() {
     57         return new InMemoryClassLoaderBuilder();
     58     }
     59 
     60     public static ReachableClassesFinder in(ClassLoader classLoader) {
     61         return new ReachableClassesFinder(classLoader);
     62     }
     63 
     64     public static ClassLoader jdkClassLoader() {
     65         return String.class.getClassLoader();
     66     }
     67 
     68     public static ClassLoader systemClassLoader() {
     69         return ClassLoader.getSystemClassLoader();
     70     }
     71 
     72     public static ClassLoader currentClassLoader() {
     73         return ClassLoaders.class.getClassLoader();
     74     }
     75 
     76     public abstract ClassLoader build();
     77 
     78     public static Class<?>[] coverageTool() {
     79         HashSet<Class<?>> classes = new HashSet<Class<?>>();
     80         classes.add(safeGetClass("net.sourceforge.cobertura.coveragedata.TouchCollector"));
     81         classes.add(safeGetClass("org.slf4j.LoggerFactory"));
     82 
     83         classes.remove(null);
     84         return classes.toArray(new Class<?>[classes.size()]);
     85     }
     86 
     87     private static Class<?> safeGetClass(String className) {
     88         try {
     89             return Class.forName(className);
     90         } catch (ClassNotFoundException e) {
     91             return null;
     92         }
     93     }
     94 
     95     public static ClassLoaderExecutor using(final ClassLoader classLoader) {
     96         return new ClassLoaderExecutor(classLoader);
     97     }
     98 
     99     public static class ClassLoaderExecutor {
    100         private ClassLoader classLoader;
    101 
    102         public ClassLoaderExecutor(ClassLoader classLoader) {
    103             this.classLoader = classLoader;
    104         }
    105 
    106         public void execute(final Runnable task) throws Exception {
    107             ExecutorService executorService = Executors.newSingleThreadExecutor(new ThreadFactory() {
    108                 @Override
    109                 public Thread newThread(Runnable r) {
    110                     Thread thread = Executors.defaultThreadFactory().newThread(r);
    111                     thread.setContextClassLoader(classLoader);
    112                     return thread;
    113                 }
    114             });
    115             try {
    116                 Future<?> taskFuture = executorService.submit(new Runnable() {
    117                     @Override
    118                     public void run() {
    119                         try {
    120                             reloadTaskInClassLoader(task).run();
    121                         } catch (Throwable throwable) {
    122                             throw new IllegalStateException(format("Given task could not be loaded properly in the given classloader '%s', error '%s",
    123                                                                    task,
    124                                                                    throwable.getMessage()),
    125                                                             throwable);
    126                         }
    127                     }
    128                 });
    129                 taskFuture.get();
    130                 executorService.shutdownNow();
    131             } catch (InterruptedException e) {
    132                 Thread.currentThread().interrupt();
    133             } catch (ExecutionException e) {
    134                 throw this.<Exception>unwrapAndThrows(e);
    135             }
    136         }
    137 
    138         @SuppressWarnings("unchecked")
    139         private <T extends Throwable> T unwrapAndThrows(ExecutionException ex) throws T {
    140             throw (T) ex.getCause();
    141         }
    142 
    143         Runnable reloadTaskInClassLoader(Runnable task) {
    144             try {
    145                 @SuppressWarnings("unchecked")
    146                 Class<Runnable> taskClassReloaded = (Class<Runnable>) classLoader.loadClass(task.getClass().getName());
    147 
    148                 Objenesis objenesis = new ObjenesisStd();
    149                 ObjectInstantiator<Runnable> thingyInstantiator = objenesis.getInstantiatorOf(taskClassReloaded);
    150                 Runnable reloaded = thingyInstantiator.newInstance();
    151 
    152                 // lenient shallow copy of class compatible fields
    153                 for (Field field : task.getClass().getDeclaredFields()) {
    154                     Field declaredField = taskClassReloaded.getDeclaredField(field.getName());
    155                     int modifiers = declaredField.getModifiers();
    156                     if(Modifier.isStatic(modifiers) && Modifier.isFinal(modifiers)) {
    157                         // Skip static final fields (e.g. jacoco fields)
    158                         // otherwise IllegalAccessException (can be bypassed with Unsafe though)
    159                         // We may also miss coverage data.
    160                         continue;
    161                     }
    162                     if (declaredField.getType() == field.getType()) { // don't copy this
    163                         field.setAccessible(true);
    164                         declaredField.setAccessible(true);
    165                         declaredField.set(reloaded, field.get(task));
    166                     }
    167                 }
    168 
    169                 return reloaded;
    170             } catch (ClassNotFoundException e) {
    171                 throw new IllegalStateException(e);
    172             } catch (IllegalAccessException e) {
    173                 throw new IllegalStateException(e);
    174             } catch (NoSuchFieldException e) {
    175                 throw new IllegalStateException(e);
    176             }
    177         }
    178     }
    179 
    180     public static class IsolatedURLClassLoaderBuilder extends ClassLoaders {
    181         private final ArrayList<String> excludedPrefixes = new ArrayList<String>();
    182         private final ArrayList<String> privateCopyPrefixes = new ArrayList<String>();
    183         private final ArrayList<URL> codeSourceUrls = new ArrayList<URL>();
    184 
    185         public IsolatedURLClassLoaderBuilder withPrivateCopyOf(String... privatePrefixes) {
    186             privateCopyPrefixes.addAll(asList(privatePrefixes));
    187             return this;
    188         }
    189 
    190         public IsolatedURLClassLoaderBuilder withCodeSourceUrls(String... urls) {
    191             codeSourceUrls.addAll(pathsToURLs(urls));
    192             return this;
    193         }
    194 
    195         public IsolatedURLClassLoaderBuilder withCodeSourceUrlOf(Class<?>... classes) {
    196             for (Class<?> clazz : classes) {
    197                 codeSourceUrls.add(obtainCurrentClassPathOf(clazz.getName()));
    198             }
    199             return this;
    200         }
    201 
    202         public IsolatedURLClassLoaderBuilder withCurrentCodeSourceUrls() {
    203             codeSourceUrls.add(obtainCurrentClassPathOf(ClassLoaders.class.getName()));
    204             return this;
    205         }
    206 
    207         public IsolatedURLClassLoaderBuilder without(String... privatePrefixes) {
    208             excludedPrefixes.addAll(asList(privatePrefixes));
    209             return this;
    210         }
    211 
    212         public ClassLoader build() {
    213             return new LocalIsolatedURLClassLoader(
    214                     jdkClassLoader(),
    215                     codeSourceUrls.toArray(new URL[codeSourceUrls.size()]),
    216                     privateCopyPrefixes,
    217                     excludedPrefixes
    218             );
    219         }
    220     }
    221 
    222     static class LocalIsolatedURLClassLoader extends URLClassLoader {
    223         private final ArrayList<String> privateCopyPrefixes;
    224         private final ArrayList<String> excludedPrefixes;
    225 
    226         LocalIsolatedURLClassLoader(ClassLoader classLoader,
    227                                     URL[] urls,
    228                                     ArrayList<String> privateCopyPrefixes,
    229                                     ArrayList<String> excludedPrefixes) {
    230             super(urls, classLoader);
    231             this.privateCopyPrefixes = privateCopyPrefixes;
    232             this.excludedPrefixes = excludedPrefixes;
    233         }
    234 
    235         @Override
    236         public Class<?> findClass(String name) throws ClassNotFoundException {
    237             if (!classShouldBePrivate(name) || classShouldBeExcluded(name)) {
    238                 throw new ClassNotFoundException(format("Can only load classes with prefixes : %s, but not : %s",
    239                                                         privateCopyPrefixes,
    240                                                         excludedPrefixes));
    241             }
    242             try {
    243                 return super.findClass(name);
    244             } catch (ClassNotFoundException cnfe) {
    245                 throw new ClassNotFoundException(format("%s%n%s%n",
    246                                                         cnfe.getMessage(),
    247                                                         "    Did you forgot to add the code source url 'withCodeSourceUrlOf' / 'withCurrentCodeSourceUrls' ?"),
    248                                                  cnfe);
    249             }
    250         }
    251 
    252         private boolean classShouldBePrivate(String name) {
    253             for (String prefix : privateCopyPrefixes) {
    254                 if (name.startsWith(prefix)) return true;
    255             }
    256             return false;
    257         }
    258 
    259         private boolean classShouldBeExcluded(String name) {
    260             for (String prefix : excludedPrefixes) {
    261                 if (name.startsWith(prefix)) return true;
    262             }
    263             return false;
    264         }
    265     }
    266 
    267     public static class ExcludingURLClassLoaderBuilder extends ClassLoaders {
    268         private final ArrayList<String> excludedPrefixes = new ArrayList<String>();
    269         private final ArrayList<URL> codeSourceUrls = new ArrayList<URL>();
    270 
    271         public ExcludingURLClassLoaderBuilder without(String... privatePrefixes) {
    272             excludedPrefixes.addAll(asList(privatePrefixes));
    273             return this;
    274         }
    275 
    276         public ExcludingURLClassLoaderBuilder withCodeSourceUrls(String... urls) {
    277             codeSourceUrls.addAll(pathsToURLs(urls));
    278             return this;
    279         }
    280 
    281         public ExcludingURLClassLoaderBuilder withCodeSourceUrlOf(Class<?>... classes) {
    282             for (Class<?> clazz : classes) {
    283                 codeSourceUrls.add(obtainCurrentClassPathOf(clazz.getName()));
    284             }
    285             return this;
    286         }
    287 
    288         public ExcludingURLClassLoaderBuilder withCurrentCodeSourceUrls() {
    289             codeSourceUrls.add(obtainCurrentClassPathOf(ClassLoaders.class.getName()));
    290             return this;
    291         }
    292 
    293         public ClassLoader build() {
    294             return new LocalExcludingURLClassLoader(
    295                     jdkClassLoader(),
    296                     codeSourceUrls.toArray(new URL[codeSourceUrls.size()]),
    297                     excludedPrefixes
    298             );
    299         }
    300     }
    301 
    302     static class LocalExcludingURLClassLoader extends URLClassLoader {
    303         private final ArrayList<String> excludedPrefixes;
    304 
    305         LocalExcludingURLClassLoader(ClassLoader classLoader,
    306                                      URL[] urls,
    307                                      ArrayList<String> excludedPrefixes) {
    308             super(urls, classLoader);
    309             this.excludedPrefixes = excludedPrefixes;
    310         }
    311 
    312         @Override
    313         public Class<?> findClass(String name) throws ClassNotFoundException {
    314             if (classShouldBeExcluded(name))
    315                 throw new ClassNotFoundException("classes with prefix : " + excludedPrefixes + " are excluded");
    316             return super.findClass(name);
    317         }
    318 
    319         private boolean classShouldBeExcluded(String name) {
    320             for (String prefix : excludedPrefixes) {
    321                 if (name.startsWith(prefix)) return true;
    322             }
    323             return false;
    324         }
    325     }
    326 
    327     public static class InMemoryClassLoaderBuilder extends ClassLoaders {
    328         private Map<String, byte[]> inMemoryClassObjects = new HashMap<String, byte[]>();
    329 
    330         public InMemoryClassLoaderBuilder withParent(ClassLoader parent) {
    331             this.parent = parent;
    332             return this;
    333         }
    334 
    335         public InMemoryClassLoaderBuilder withClassDefinition(String name, byte[] classDefinition) {
    336             inMemoryClassObjects.put(name, classDefinition);
    337             return this;
    338         }
    339 
    340         public ClassLoader build() {
    341             return new InMemoryClassLoader(parent, inMemoryClassObjects);
    342         }
    343     }
    344 
    345     static class InMemoryClassLoader extends ClassLoader {
    346         public static final String SCHEME = "mem";
    347         private Map<String, byte[]> inMemoryClassObjects = new HashMap<String, byte[]>();
    348 
    349         public InMemoryClassLoader(ClassLoader parent, Map<String, byte[]> inMemoryClassObjects) {
    350             super(parent);
    351             this.inMemoryClassObjects = inMemoryClassObjects;
    352         }
    353 
    354         protected Class<?> findClass(String name) throws ClassNotFoundException {
    355             byte[] classDefinition = inMemoryClassObjects.get(name);
    356             if (classDefinition != null) {
    357                 return defineClass(name, classDefinition, 0, classDefinition.length);
    358             }
    359             throw new ClassNotFoundException(name);
    360         }
    361 
    362         @Override
    363         public Enumeration<URL> getResources(String ignored) throws IOException {
    364             return inMemoryOnly();
    365         }
    366 
    367         private Enumeration<URL> inMemoryOnly() {
    368             final Set<String> names = inMemoryClassObjects.keySet();
    369             return new Enumeration<URL>() {
    370                 private final MemHandler memHandler = new MemHandler(InMemoryClassLoader.this);
    371                 private final Iterator<String> it = names.iterator();
    372 
    373                 public boolean hasMoreElements() {
    374                     return it.hasNext();
    375                 }
    376 
    377                 public URL nextElement() {
    378                     try {
    379                         return new URL(null, SCHEME + ":" + it.next(), memHandler);
    380                     } catch (MalformedURLException rethrown) {
    381                         throw new IllegalStateException(rethrown);
    382                     }
    383                 }
    384             };
    385         }
    386     }
    387 
    388     public static class MemHandler extends URLStreamHandler {
    389         private InMemoryClassLoader inMemoryClassLoader;
    390 
    391         public MemHandler(InMemoryClassLoader inMemoryClassLoader) {
    392             this.inMemoryClassLoader = inMemoryClassLoader;
    393         }
    394 
    395         @Override
    396         protected URLConnection openConnection(URL url) throws IOException {
    397             return new MemURLConnection(url, inMemoryClassLoader);
    398         }
    399 
    400         private static class MemURLConnection extends URLConnection {
    401             private final InMemoryClassLoader inMemoryClassLoader;
    402             private String qualifiedName;
    403 
    404             public MemURLConnection(URL url, InMemoryClassLoader inMemoryClassLoader) {
    405                 super(url);
    406                 this.inMemoryClassLoader = inMemoryClassLoader;
    407                 qualifiedName = url.getPath();
    408             }
    409 
    410             @Override
    411             public void connect() throws IOException {
    412             }
    413 
    414             @Override
    415             public InputStream getInputStream() throws IOException {
    416                 return new ByteArrayInputStream(inMemoryClassLoader.inMemoryClassObjects.get(qualifiedName));
    417             }
    418         }
    419     }
    420 
    421     URL obtainCurrentClassPathOf(String className) {
    422         String path = className.replace('.', '/') + ".class";
    423         String url = ClassLoaders.class.getClassLoader().getResource(path).toExternalForm();
    424 
    425         try {
    426             return new URL(url.substring(0, url.length() - path.length()));
    427         } catch (MalformedURLException e) {
    428             throw new RuntimeException("Classloader couldn't obtain a proper classpath URL", e);
    429         }
    430     }
    431 
    432     List<URL> pathsToURLs(String... codeSourceUrls) {
    433         return pathsToURLs(Arrays.asList(codeSourceUrls));
    434     }
    435 
    436     private List<URL> pathsToURLs(List<String> codeSourceUrls) {
    437         ArrayList<URL> urls = new ArrayList<URL>(codeSourceUrls.size());
    438         for (String codeSourceUrl : codeSourceUrls) {
    439             URL url = pathToUrl(codeSourceUrl);
    440             urls.add(url);
    441         }
    442         return urls;
    443     }
    444 
    445     private URL pathToUrl(String path) {
    446         try {
    447             return new File(path).getAbsoluteFile().toURI().toURL();
    448         } catch (MalformedURLException e) {
    449             throw new IllegalArgumentException("Path is malformed", e);
    450         }
    451     }
    452 
    453     public static class ReachableClassesFinder {
    454         private ClassLoader classLoader;
    455         private Set<String> qualifiedNameSubstring = new HashSet<String>();
    456 
    457         ReachableClassesFinder(ClassLoader classLoader) {
    458             this.classLoader = classLoader;
    459         }
    460 
    461         public ReachableClassesFinder omit(String... qualifiedNameSubstring) {
    462             this.qualifiedNameSubstring.addAll(Arrays.asList(qualifiedNameSubstring));
    463             return this;
    464         }
    465 
    466         public Set<String> listOwnedClasses() throws IOException, URISyntaxException {
    467             Enumeration<URL> roots = classLoader.getResources("");
    468 
    469             Set<String> classes = new HashSet<String>();
    470             while (roots.hasMoreElements()) {
    471                 URI uri = roots.nextElement().toURI();
    472 
    473                 if (uri.getScheme().equalsIgnoreCase("file")) {
    474                     addFromFileBasedClassLoader(classes, uri);
    475                 } else if (uri.getScheme().equalsIgnoreCase(InMemoryClassLoader.SCHEME)) {
    476                     addFromInMemoryBasedClassLoader(classes, uri);
    477                 } else {
    478                     throw new IllegalArgumentException(format("Given ClassLoader '%s' don't have reachable by File or vi ClassLoaders.inMemory", classLoader));
    479                 }
    480             }
    481             return classes;
    482         }
    483 
    484         private void addFromFileBasedClassLoader(Set<String> classes, URI uri) {
    485             File root = new File(uri);
    486             classes.addAll(findClassQualifiedNames(root, root, qualifiedNameSubstring));
    487         }
    488 
    489         private void addFromInMemoryBasedClassLoader(Set<String> classes, URI uri) {
    490             String qualifiedName = uri.getSchemeSpecificPart();
    491             if (excludes(qualifiedName, qualifiedNameSubstring)) {
    492                 classes.add(qualifiedName);
    493             }
    494         }
    495 
    496 
    497         private Set<String> findClassQualifiedNames(File root, File file, Set<String> packageFilters) {
    498             if (file.isDirectory()) {
    499                 File[] files = file.listFiles();
    500                 Set<String> classes = new HashSet<String>();
    501                 for (File children : files) {
    502                     classes.addAll(findClassQualifiedNames(root, children, packageFilters));
    503                 }
    504                 return classes;
    505             } else {
    506                 if (file.getName().endsWith(".class")) {
    507                     String qualifiedName = classNameFor(root, file);
    508                     if (excludes(qualifiedName, packageFilters)) {
    509                         return Collections.singleton(qualifiedName);
    510                     }
    511                 }
    512             }
    513             return Collections.emptySet();
    514         }
    515 
    516         private boolean excludes(String qualifiedName, Set<String> packageFilters) {
    517             for (String filter : packageFilters) {
    518                 if (qualifiedName.contains(filter)) return false;
    519             }
    520             return true;
    521         }
    522 
    523         private String classNameFor(File root, File file) {
    524             String temp = file.getAbsolutePath().substring(root.getAbsolutePath().length() + 1).
    525                     replace(File.separatorChar, '.');
    526             return temp.subSequence(0, temp.indexOf(".class")).toString();
    527         }
    528 
    529     }
    530 }
    531