Home | History | Annotate | Download | only in mockitoutil
      1 /*
      2  * Copyright (c) 2017 Mockito contributors
      3  * This program is made available under the terms of the MIT License.
      4  */
      5 package org.mockitoutil;
      6 
      7 import java.net.MalformedURLException;
      8 import java.net.URL;
      9 import java.net.URLClassLoader;
     10 import java.util.HashMap;
     11 import java.util.Map;
     12 import java.util.concurrent.Callable;
     13 
     /**
      * Custom classloader that loads classes in a hierarchical realm.
      *
      * <p>A class is defined by this loader itself (once per loader instance) when the
      * {@link ReloadClassPredicate} accepts its qualified name; every other class is resolved
      * through the inherited {@link URLClassLoader#loadClass(String)}.</p>
      *
      * <p>The loader searches the classpath roots that contain this class,
      * {@code org.mockito.Mockito} and {@code net.bytebuddy.ByteBuddy}.</p>
      */
     public class SimplePerRealmReloadingClassLoader extends URLClassLoader {

         /** Classes already defined by this loader instance, keyed by qualified class name. */
         private final Map<String, Class<?>> classHashMap = new HashMap<String, Class<?>>();

         /** Decides which qualified class names are defined by this loader rather than delegated. */
         private final ReloadClassPredicate reloadClassPredicate;

         /**
          * Creates a realm without passing an explicit parent class loader to {@link URLClassLoader}.
          *
          * @param reloadClassPredicate selects the classes this loader defines itself
          * @throws IllegalStateException if one of the required classpath roots cannot be located
          */
         public SimplePerRealmReloadingClassLoader(ReloadClassPredicate reloadClassPredicate) {
             super(getPossibleClassPathsUrls());
             this.reloadClassPredicate = reloadClassPredicate;
         }

         /**
          * Creates a realm whose delegation parent is {@code parentClassLoader}.
          *
          * @param parentClassLoader    parent handed to {@link URLClassLoader}
          * @param reloadClassPredicate selects the classes this loader defines itself
          * @throws IllegalStateException if one of the required classpath roots cannot be located
          */
         public SimplePerRealmReloadingClassLoader(ClassLoader parentClassLoader, ReloadClassPredicate reloadClassPredicate) {
             super(getPossibleClassPathsUrls(), parentClassLoader);
             this.reloadClassPredicate = reloadClassPredicate;
         }

         /** Collects the classpath roots this loader searches when it defines a class itself. */
         private static URL[] getPossibleClassPathsUrls() {
             return new URL[]{
                     obtainClassPath(),
                     obtainClassPath("org.mockito.Mockito"),
                     obtainClassPath("net.bytebuddy.ByteBuddy")
             };
         }

         /** Returns the classpath root that contains this very class. */
         private static URL obtainClassPath() {
             String className = SimplePerRealmReloadingClassLoader.class.getName();
             return obtainClassPath(className);
         }

         /**
          * Returns the classpath root (directory or archive URL) from which {@code className} is visible
          * to the class loader of this class.
          *
          * @param className qualified name of a class expected to be on the classpath
          * @return the URL of the class file with the class' relative path stripped from its end
          * @throws IllegalStateException if no class file can be found for {@code className}
          * @throws RuntimeException      if the stripped location is not a well-formed URL
          */
         private static URL obtainClassPath(String className) {
             String path = className.replace('.', '/') + ".class";
             URL resource = SimplePerRealmReloadingClassLoader.class.getClassLoader().getResource(path);
             if (resource == null) {
                 // getResource signals "not found" with null; fail with the offending name instead of a bare NPE.
                 throw new IllegalStateException("Classloader couldn't locate the class file of '" + className + "' on the classpath");
             }
             String url = resource.toExternalForm();

             try {
                 // Cutting the relative class file path off the end leaves the classpath root.
                 return new URL(url.substring(0, url.length() - path.length()));
             } catch (MalformedURLException e) {
                 throw new RuntimeException("Classloader couldn't obtain a proper classpath URL", e);
             }
         }

         /**
          * Loads {@code qualifiedClassName}, defining it in this realm when the predicate accepts it.
          *
          * <p>An accepted class is looked up with {@link #findClass(String)} on first request and served
          * from the internal map afterwards; a rejected class goes through the inherited lookup.
          * This method performs no synchronization of its own.</p>
          *
          * @param qualifiedClassName qualified name of the class to load
          * @return the loaded class
          * @throws ClassNotFoundException if the class cannot be found
          */
         @Override
         public Class<?> loadClass(String qualifiedClassName) throws ClassNotFoundException {
             if (!reloadClassPredicate.acceptReloadOf(qualifiedClassName)) {
                 return useParentClassLoaderFor(qualifiedClassName);
             }

             Class<?> realmClass = classHashMap.get(qualifiedClassName);
             if (realmClass == null) {
                 realmClass = findClass(qualifiedClassName);
                 saveFoundClass(qualifiedClassName, realmClass);
             }
             return realmClass;
         }

         /** Remembers a class defined by this loader so later requests return the same instance. */
         private void saveFoundClass(String qualifiedClassName, Class<?> foundClass) {
             classHashMap.put(qualifiedClassName, foundClass);
         }

         /** Resolves a class through the inherited {@code loadClass} implementation. */
         private Class<?> useParentClassLoaderFor(String qualifiedName) throws ClassNotFoundException {
             return super.loadClass(qualifiedName);
         }

         /**
          * Instantiates the named class through its public no-argument constructor and, if the instance
          * is a {@link Callable}, invokes it while this loader is the thread context class loader.
          *
          * @param callableCalledInClassLoaderRealm qualified name of a class implementing {@link Callable}
          * @return the value returned by {@link Callable#call()}
          * @throws IllegalArgumentException if the instantiated object is not a {@link Callable}
          * @throws Exception                anything raised while loading, instantiating or calling
          */
         public Object doInRealm(String callableCalledInClassLoaderRealm) throws Exception {
             return instantiateAndCall(callableCalledInClassLoaderRealm, new Class<?>[0], new Object[0]);
         }

         /**
          * Instantiates the named class through the public constructor matching {@code argTypes} and,
          * if the instance is a {@link Callable}, invokes it while this loader is the thread context
          * class loader.
          *
          * @param callableCalledInClassLoaderRealm qualified name of a class implementing {@link Callable}
          * @param argTypes                         parameter types identifying the constructor
          * @param args                             arguments passed to that constructor
          * @return the value returned by {@link Callable#call()}
          * @throws IllegalArgumentException if the instantiated object is not a {@link Callable}
          * @throws Exception                anything raised while loading, instantiating or calling
          */
         public Object doInRealm(String callableCalledInClassLoaderRealm, Class<?>[] argTypes, Object[] args) throws Exception {
             return instantiateAndCall(callableCalledInClassLoaderRealm, argTypes, args);
         }

         /**
          * Shared implementation of both {@code doInRealm} overloads: swaps the thread context class
          * loader to this loader, instantiates and calls the class, then restores the previous loader.
          */
         private Object instantiateAndCall(String qualifiedName, Class<?>[] argTypes, Object[] args) throws Exception {
             Thread thread = Thread.currentThread();
             ClassLoader previous = thread.getContextClassLoader();
             thread.setContextClassLoader(this);
             try {
                 Object instance = this.loadClass(qualifiedName).getConstructor(argTypes).newInstance(args);
                 if (instance instanceof Callable) {
                     return ((Callable<?>) instance).call();
                 }
             } finally {
                 // Restore the caller's context class loader on success and on failure alike.
                 thread.setContextClassLoader(previous);
             }

             throw new IllegalArgumentException("qualified name '" + qualifiedName + "' should represent a class implementing Callable");
         }

         /** Decides whether a class is defined by the realm loader instead of being delegated. */
         public interface ReloadClassPredicate {
             /**
              * @param qualifiedName qualified name of the class about to be loaded
              * @return {@code true} if the realm loader should define the class itself
              */
             boolean acceptReloadOf(String qualifiedName);
         }
     }
    123