Home | History | Annotate | Download | only in renderscript
      1 /*
      2  * Copyright (C) 2012 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 package android.renderscript;
     18 
     19 import java.lang.reflect.Method;
     20 import java.util.ArrayList;
     21 
     22 /**
     23  * ScriptGroup creates a group of scripts which are executed
     24  * together by a single execution call as if they were
     25  * all part of a single script.  The scripts may be connected
     26  * internally or to an external allocation. For the internal
     27  * connections the intermediate results are not observable after
     28  * the execution of the group.
     29  * <p>
     30  * The external connections are grouped into inputs and outputs.
     31  * All outputs are produced by a script kernel and placed into a
     32  * user supplied allocation. Inputs are similar but supply the
     33  * input of a kernel. Inputs bound to a script are set directly
     34  * upon the script.
     35  * <p>
     36  * A ScriptGroup must contain at least one kernel. A ScriptGroup
     37  * must contain only a single directed acyclic graph (DAG) of
     38  * script kernels and connections. Attempting to create a
     39  * ScriptGroup with multiple DAGs or attempting to create
     40  * a cycle within a ScriptGroup will throw an exception.
     41  *
     42  **/
     43 public final class ScriptGroup extends BaseObj {
     44     IO mOutputs[];
     45     IO mInputs[];
     46 
     47     static class IO {
     48         Script.KernelID mKID;
     49         Allocation mAllocation;
     50 
     51         IO(Script.KernelID s) {
     52             mKID = s;
     53         }
     54     }
     55 
     56     static class ConnectLine {
     57         ConnectLine(Type t, Script.KernelID from, Script.KernelID to) {
     58             mFrom = from;
     59             mToK = to;
     60             mAllocationType = t;
     61         }
     62 
     63         ConnectLine(Type t, Script.KernelID from, Script.FieldID to) {
     64             mFrom = from;
     65             mToF = to;
     66             mAllocationType = t;
     67         }
     68 
     69         Script.FieldID mToF;
     70         Script.KernelID mToK;
     71         Script.KernelID mFrom;
     72         Type mAllocationType;
     73     }
     74 
     75     static class Node {
     76         Script mScript;
     77         ArrayList<Script.KernelID> mKernels = new ArrayList<Script.KernelID>();
     78         ArrayList<ConnectLine> mInputs = new ArrayList<ConnectLine>();
     79         ArrayList<ConnectLine> mOutputs = new ArrayList<ConnectLine>();
     80         int dagNumber;
     81 
     82         Node mNext;
     83 
     84         Node(Script s) {
     85             mScript = s;
     86         }
     87     }
     88 
     89 
     90     ScriptGroup(int id, RenderScript rs) {
     91         super(id, rs);
     92     }
     93 
     94     /**
     95      * Sets an input of the ScriptGroup. This specifies an
     96      * Allocation to be used for the kernels which require a kernel
     97      * input and that input is provided external to the group.
     98      *
     99      * @param s The ID of the kernel where the allocation should be
    100      *          connected.
    101      * @param a The allocation to connect.
    102      */
    103     public void setInput(Script.KernelID s, Allocation a) {
    104         for (int ct=0; ct < mInputs.length; ct++) {
    105             if (mInputs[ct].mKID == s) {
    106                 mInputs[ct].mAllocation = a;
    107                 mRS.nScriptGroupSetInput(getID(mRS), s.getID(mRS), mRS.safeID(a));
    108                 return;
    109             }
    110         }
    111         throw new RSIllegalArgumentException("Script not found");
    112     }
    113 
    114     /**
    115      * Sets an output of the ScriptGroup. This specifies an
    116      * Allocation to be used for the kernels which require a kernel
    117      * output and that output is provided external to the group.
    118      *
    119      * @param s The ID of the kernel where the allocation should be
    120      *          connected.
    121      * @param a The allocation to connect.
    122      */
    123     public void setOutput(Script.KernelID s, Allocation a) {
    124         for (int ct=0; ct < mOutputs.length; ct++) {
    125             if (mOutputs[ct].mKID == s) {
    126                 mOutputs[ct].mAllocation = a;
    127                 mRS.nScriptGroupSetOutput(getID(mRS), s.getID(mRS), mRS.safeID(a));
    128                 return;
    129             }
    130         }
    131         throw new RSIllegalArgumentException("Script not found");
    132     }
    133 
    134     /**
    135      * Execute the ScriptGroup.  This will run all the kernels in
    136      * the script.  The state of the connecting lines will not be
    137      * observable after this operation.
    138      */
    139     public void execute() {
    140         mRS.nScriptGroupExecute(getID(mRS));
    141     }
    142 
    143 
    144     /**
    145      * Create a ScriptGroup. There are two steps to creating a
    146      * ScriptGoup.
    147      * <p>
    148      * First all the Kernels to be used by the group should be
    149      * added.  Once this is done the kernels should be connected.
    150      * Kernels cannot be added once a connection has been made.
    151      * <p>
    152      * Second, add connections. There are two forms of connections.
    153      * Kernel to Kernel and Kernel to Field. Kernel to Kernel is
    154      * higher performance and should be used where possible. The
    155      * line of connections cannot form a loop. If a loop is detected
    156      * an exception is thrown.
    157      * <p>
    158      * Once all the connections are made a call to create will
    159      * return the ScriptGroup object.
    160      *
    161      */
    162     public static final class Builder {
    163         private RenderScript mRS;
    164         private ArrayList<Node> mNodes = new ArrayList<Node>();
    165         private ArrayList<ConnectLine> mLines = new ArrayList<ConnectLine>();
    166         private int mKernelCount;
    167 
    168         /**
    169          * Create a builder for generating a ScriptGroup.
    170          *
    171          *
    172          * @param rs The Renderscript context.
    173          */
    174         public Builder(RenderScript rs) {
    175             mRS = rs;
    176         }
    177 
    178         // do a DFS from original node, looking for original node
    179         // any cycle that could be created must contain original node
    180         private void validateCycle(Node target, Node original) {
    181             for (int ct = 0; ct < target.mOutputs.size(); ct++) {
    182                 final ConnectLine cl = target.mOutputs.get(ct);
    183                 if (cl.mToK != null) {
    184                     Node tn = findNode(cl.mToK.mScript);
    185                     if (tn.equals(original)) {
    186                         throw new RSInvalidStateException("Loops in group not allowed.");
    187                     }
    188                     validateCycle(tn, original);
    189                 }
    190                 if (cl.mToF != null) {
    191                     Node tn = findNode(cl.mToF.mScript);
    192                     if (tn.equals(original)) {
    193                         throw new RSInvalidStateException("Loops in group not allowed.");
    194                     }
    195                     validateCycle(tn, original);
    196                 }
    197             }
    198         }
    199 
    200         private void mergeDAGs(int valueUsed, int valueKilled) {
    201             for (int ct=0; ct < mNodes.size(); ct++) {
    202                 if (mNodes.get(ct).dagNumber == valueKilled)
    203                     mNodes.get(ct).dagNumber = valueUsed;
    204             }
    205         }
    206 
    207         private void validateDAGRecurse(Node n, int dagNumber) {
    208             // combine DAGs if this node has been seen already
    209             if (n.dagNumber != 0 && n.dagNumber != dagNumber) {
    210                 mergeDAGs(n.dagNumber, dagNumber);
    211                 return;
    212             }
    213 
    214             n.dagNumber = dagNumber;
    215             for (int ct=0; ct < n.mOutputs.size(); ct++) {
    216                 final ConnectLine cl = n.mOutputs.get(ct);
    217                 if (cl.mToK != null) {
    218                     Node tn = findNode(cl.mToK.mScript);
    219                     validateDAGRecurse(tn, dagNumber);
    220                 }
    221                 if (cl.mToF != null) {
    222                     Node tn = findNode(cl.mToF.mScript);
    223                     validateDAGRecurse(tn, dagNumber);
    224                 }
    225             }
    226         }
    227 
    228         private void validateDAG() {
    229             for (int ct=0; ct < mNodes.size(); ct++) {
    230                 Node n = mNodes.get(ct);
    231                 if (n.mInputs.size() == 0) {
    232                     if (n.mOutputs.size() == 0 && mNodes.size() > 1) {
    233                         throw new RSInvalidStateException("Groups cannot contain unconnected scripts");
    234                     }
    235                     validateDAGRecurse(n, ct+1);
    236                 }
    237             }
    238             int dagNumber = mNodes.get(0).dagNumber;
    239             for (int ct=0; ct < mNodes.size(); ct++) {
    240                 if (mNodes.get(ct).dagNumber != dagNumber) {
    241                     throw new RSInvalidStateException("Multiple DAGs in group not allowed.");
    242                 }
    243             }
    244         }
    245 
    246         private Node findNode(Script s) {
    247             for (int ct=0; ct < mNodes.size(); ct++) {
    248                 if (s == mNodes.get(ct).mScript) {
    249                     return mNodes.get(ct);
    250                 }
    251             }
    252             return null;
    253         }
    254 
    255         private Node findNode(Script.KernelID k) {
    256             for (int ct=0; ct < mNodes.size(); ct++) {
    257                 Node n = mNodes.get(ct);
    258                 for (int ct2=0; ct2 < n.mKernels.size(); ct2++) {
    259                     if (k == n.mKernels.get(ct2)) {
    260                         return n;
    261                     }
    262                 }
    263             }
    264             return null;
    265         }
    266 
    267         /**
    268          * Adds a Kernel to the group.
    269          *
    270          *
    271          * @param k The kernel to add.
    272          *
    273          * @return Builder Returns this.
    274          */
    275         public Builder addKernel(Script.KernelID k) {
    276             if (mLines.size() != 0) {
    277                 throw new RSInvalidStateException(
    278                     "Kernels may not be added once connections exist.");
    279             }
    280 
    281             //android.util.Log.v("RSR", "addKernel 1 k=" + k);
    282             if (findNode(k) != null) {
    283                 return this;
    284             }
    285             //android.util.Log.v("RSR", "addKernel 2 ");
    286             mKernelCount++;
    287             Node n = findNode(k.mScript);
    288             if (n == null) {
    289                 //android.util.Log.v("RSR", "addKernel 3 ");
    290                 n = new Node(k.mScript);
    291                 mNodes.add(n);
    292             }
    293             n.mKernels.add(k);
    294             return this;
    295         }
    296 
    297         /**
    298          * Adds a connection to the group.
    299          *
    300          *
    301          * @param t The type of the connection. This is used to
    302          *          determine the kernel launch sizes on the source side
    303          *          of this connection.
    304          * @param from The source for the connection.
    305          * @param to The destination of the connection.
    306          *
    307          * @return Builder Returns this
    308          */
    309         public Builder addConnection(Type t, Script.KernelID from, Script.FieldID to) {
    310             //android.util.Log.v("RSR", "addConnection " + t +", " + from + ", " + to);
    311 
    312             Node nf = findNode(from);
    313             if (nf == null) {
    314                 throw new RSInvalidStateException("From script not found.");
    315             }
    316 
    317             Node nt = findNode(to.mScript);
    318             if (nt == null) {
    319                 throw new RSInvalidStateException("To script not found.");
    320             }
    321 
    322             ConnectLine cl = new ConnectLine(t, from, to);
    323             mLines.add(new ConnectLine(t, from, to));
    324 
    325             nf.mOutputs.add(cl);
    326             nt.mInputs.add(cl);
    327 
    328             validateCycle(nf, nf);
    329             return this;
    330         }
    331 
    332         /**
    333          * Adds a connection to the group.
    334          *
    335          *
    336          * @param t The type of the connection. This is used to
    337          *          determine the kernel launch sizes for both sides of
    338          *          this connection.
    339          * @param from The source for the connection.
    340          * @param to The destination of the connection.
    341          *
    342          * @return Builder Returns this
    343          */
    344         public Builder addConnection(Type t, Script.KernelID from, Script.KernelID to) {
    345             //android.util.Log.v("RSR", "addConnection " + t +", " + from + ", " + to);
    346 
    347             Node nf = findNode(from);
    348             if (nf == null) {
    349                 throw new RSInvalidStateException("From script not found.");
    350             }
    351 
    352             Node nt = findNode(to);
    353             if (nt == null) {
    354                 throw new RSInvalidStateException("To script not found.");
    355             }
    356 
    357             ConnectLine cl = new ConnectLine(t, from, to);
    358             mLines.add(new ConnectLine(t, from, to));
    359 
    360             nf.mOutputs.add(cl);
    361             nt.mInputs.add(cl);
    362 
    363             validateCycle(nf, nf);
    364             return this;
    365         }
    366 
    367 
    368 
    369         /**
    370          * Creates the Script group.
    371          *
    372          *
    373          * @return ScriptGroup The new ScriptGroup
    374          */
    375         public ScriptGroup create() {
    376 
    377             if (mNodes.size() == 0) {
    378                 throw new RSInvalidStateException("Empty script groups are not allowed");
    379             }
    380 
    381             // reset DAG numbers in case we're building a second group
    382             for (int ct=0; ct < mNodes.size(); ct++) {
    383                 mNodes.get(ct).dagNumber = 0;
    384             }
    385             validateDAG();
    386 
    387             ArrayList<IO> inputs = new ArrayList<IO>();
    388             ArrayList<IO> outputs = new ArrayList<IO>();
    389 
    390             int[] kernels = new int[mKernelCount];
    391             int idx = 0;
    392             for (int ct=0; ct < mNodes.size(); ct++) {
    393                 Node n = mNodes.get(ct);
    394                 for (int ct2=0; ct2 < n.mKernels.size(); ct2++) {
    395                     final Script.KernelID kid = n.mKernels.get(ct2);
    396                     kernels[idx++] = kid.getID(mRS);
    397 
    398                     boolean hasInput = false;
    399                     boolean hasOutput = false;
    400                     for (int ct3=0; ct3 < n.mInputs.size(); ct3++) {
    401                         if (n.mInputs.get(ct3).mToK == kid) {
    402                             hasInput = true;
    403                         }
    404                     }
    405                     for (int ct3=0; ct3 < n.mOutputs.size(); ct3++) {
    406                         if (n.mOutputs.get(ct3).mFrom == kid) {
    407                             hasOutput = true;
    408                         }
    409                     }
    410                     if (!hasInput) {
    411                         inputs.add(new IO(kid));
    412                     }
    413                     if (!hasOutput) {
    414                         outputs.add(new IO(kid));
    415                     }
    416 
    417                 }
    418             }
    419             if (idx != mKernelCount) {
    420                 throw new RSRuntimeException("Count mismatch, should not happen.");
    421             }
    422 
    423             int[] src = new int[mLines.size()];
    424             int[] dstk = new int[mLines.size()];
    425             int[] dstf = new int[mLines.size()];
    426             int[] types = new int[mLines.size()];
    427 
    428             for (int ct=0; ct < mLines.size(); ct++) {
    429                 ConnectLine cl = mLines.get(ct);
    430                 src[ct] = cl.mFrom.getID(mRS);
    431                 if (cl.mToK != null) {
    432                     dstk[ct] = cl.mToK.getID(mRS);
    433                 }
    434                 if (cl.mToF != null) {
    435                     dstf[ct] = cl.mToF.getID(mRS);
    436                 }
    437                 types[ct] = cl.mAllocationType.getID(mRS);
    438             }
    439 
    440             int id = mRS.nScriptGroupCreate(kernels, src, dstk, dstf, types);
    441             if (id == 0) {
    442                 throw new RSRuntimeException("Object creation error, should not happen.");
    443             }
    444 
    445             ScriptGroup sg = new ScriptGroup(id, mRS);
    446             sg.mOutputs = new IO[outputs.size()];
    447             for (int ct=0; ct < outputs.size(); ct++) {
    448                 sg.mOutputs[ct] = outputs.get(ct);
    449             }
    450 
    451             sg.mInputs = new IO[inputs.size()];
    452             for (int ct=0; ct < inputs.size(); ct++) {
    453                 sg.mInputs[ct] = inputs.get(ct);
    454             }
    455 
    456             return sg;
    457         }
    458 
    459     }
    460 
    461 
    462 }
    463 
    464 
    465