Home | History | Annotate | Download | only in graph
      1 package org.testng.internal.thread.graph;
      2 
      3 import org.testng.TestNGException;
      4 import org.testng.collections.Lists;
      5 import org.testng.internal.DynamicGraph;
      6 import org.testng.internal.DynamicGraph.Status;
      7 
      8 import java.io.BufferedWriter;
      9 import java.io.File;
     10 import java.io.FileWriter;
     11 import java.io.IOException;
     12 import java.util.List;
     13 import java.util.concurrent.BlockingQueue;
     14 import java.util.concurrent.ThreadFactory;
     15 import java.util.concurrent.ThreadPoolExecutor;
     16 import java.util.concurrent.TimeUnit;
     17 
     18 /**
     19  * An Executor that launches tasks per batches. It takes a {@code DynamicGraph}
     20  * of tasks to be run and a {@code IThreadWorkerFactory} to initialize/create
     21  * {@code Runnable} wrappers around those tasks
     22  */
     23 public class GraphThreadPoolExecutor<T> extends ThreadPoolExecutor {
     24   private static final boolean DEBUG = false;
     25   /** Set to true if you want to generate GraphViz graphs */
     26   private static final boolean DOT_FILES = false;
     27 
     28   private DynamicGraph<T> m_graph;
     29   private List<Runnable> m_activeRunnables = Lists.newArrayList();
     30   private IThreadWorkerFactory<T> m_factory;
     31   private List<String> m_dotFiles = Lists.newArrayList();
     32   private int m_threadCount;
     33 
     34   public GraphThreadPoolExecutor(DynamicGraph<T> graph, IThreadWorkerFactory<T> factory, int corePoolSize,
     35       int maximumPoolSize, long keepAliveTime, TimeUnit unit, BlockingQueue<Runnable> workQueue) {
     36     super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue
     37         /* , new TestNGThreadPoolFactory() */);
     38     ppp("Initializing executor with " + corePoolSize + " threads and following graph " + graph);
     39     m_threadCount = maximumPoolSize;
     40     m_graph = graph;
     41     m_factory = factory;
     42 
     43     if (m_graph.getFreeNodes().isEmpty()) {
     44       throw new TestNGException("The graph of methods contains a cycle:" + graph.getEdges());
     45     }
     46   }
     47 
     48   public void run() {
     49     synchronized(m_graph) {
     50       if (DOT_FILES) {
     51         m_dotFiles.add(m_graph.toDot());
     52       }
     53       List<T> freeNodes = m_graph.getFreeNodes();
     54       runNodes(freeNodes);
     55     }
     56   }
     57 
     58   /**
     59    * Create one worker per node and execute them.
     60    */
     61   private void runNodes(List<T> freeNodes) {
     62     List<IWorker<T>> runnables = m_factory.createWorkers(freeNodes);
     63     for (IWorker<T> r : runnables) {
     64       m_activeRunnables.add(r);
     65       ppp("Added to active runnable");
     66       setStatus(r, Status.RUNNING);
     67       ppp("Executing: " + r);
     68       try {
     69         execute(r);
     70 //        if (m_threadCount > 1) execute(r);
     71 //        else r.run();
     72       }
     73       catch(Exception ex) {
     74         ex.printStackTrace();
     75       }
     76     }
     77   }
     78 
     79   private void setStatus(IWorker<T> worker, Status status) {
     80     ppp("Set status:" + worker + " status:" + status);
     81     if (status == Status.FINISHED) {
     82       m_activeRunnables.remove(worker);
     83     }
     84     synchronized(m_graph) {
     85       for (T m : worker.getTasks()) {
     86         m_graph.setStatus(m, status);
     87       }
     88     }
     89   }
     90 
     91   @Override
     92   public void afterExecute(Runnable r, Throwable t) {
     93     ppp("Finished runnable:" + r);
     94     setStatus((IWorker<T>) r, Status.FINISHED);
     95     synchronized(m_graph) {
     96       ppp("Node count:" + m_graph.getNodeCount() + " and "
     97           + m_graph.getNodeCountWithStatus(Status.FINISHED) + " finished");
     98       if (m_graph.getNodeCount() == m_graph.getNodeCountWithStatus(Status.FINISHED)) {
     99         ppp("Shutting down executor " + this);
    100         if (DOT_FILES) {
    101           generateFiles(m_dotFiles);
    102         }
    103         shutdown();
    104       } else {
    105         if (DOT_FILES) {
    106           m_dotFiles.add(m_graph.toDot());
    107         }
    108         List<T> freeNodes = m_graph.getFreeNodes();
    109         runNodes(freeNodes);
    110       }
    111     }
    112 //    if (m_activeRunnables.isEmpty() && m_index < m_runnables.getSize()) {
    113 //      runNodes(m_index++);
    114 //    }
    115   }
    116 
    117   private void generateFiles(List<String> files) {
    118     try {
    119       File dir = File.createTempFile("TestNG-", "");
    120       dir.delete();
    121       dir.mkdir();
    122       for (int i = 0; i < files.size(); i++) {
    123         File f = new File(dir, "" + (i < 10 ? "0" : "") + i + ".dot");
    124         BufferedWriter bw = new BufferedWriter(new FileWriter(f));
    125         bw.append(files.get(i));
    126         bw.close();
    127       }
    128       if (DOT_FILES) {
    129         System.out.println("Created graph files in " + dir);
    130       }
    131     } catch(IOException ex) {
    132       ex.printStackTrace();
    133     }
    134   }
    135 
    136   private void ppp(String string) {
    137     if (DEBUG) {
    138       System.out.println("============ [GraphThreadPoolExecutor] " + Thread.currentThread().getId() + " "
    139           + string);
    140     }
    141   }
    142 
    143 }
    144 
    145 class TestNGThreadPoolFactory implements ThreadFactory {
    146   private int m_count = 0;
    147 
    148   @Override
    149   public Thread newThread(Runnable r) {
    150     Thread result = new Thread(r);
    151     result.setName("TestNG-" + m_count++);
    152     return result;
    153   }
    154 }
    155