Home | History | Annotate | Download | only in renderscript
      1 /*
      2  * Copyright (C) 2012 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 package android.renderscript;
     18 
     19 import java.lang.reflect.Method;
     20 import java.util.ArrayList;
     21 
     22 /**
     23  * ScriptGroup creates a group of kernels that are executed
     24  * together with one execution call as if they were a single kernel.
     25  * The kernels may be connected internally or to an external allocation.
     26  * The intermediate results for internal connections are not observable
     27  * after the execution of the script.
     28  * <p>
     29  * External connections are grouped into inputs and outputs.
     30  * All outputs are produced by a script kernel and placed into a
     31  * user-supplied allocation. Inputs provide the input of a kernel.
     32  * Inputs bound to script globals are set directly upon the script.
     33  * <p>
     34  * A ScriptGroup must contain at least one kernel. A ScriptGroup
     35  * must contain only a single directed acyclic graph (DAG) of
     36  * script kernels and connections. Attempting to create a
     37  * ScriptGroup with multiple DAGs or attempting to create
     38  * a cycle within a ScriptGroup will throw an exception.
     39  * <p>
     40  * Currently, all kernels in a ScriptGroup must be from separate
     41  * Script objects. Attempting to use multiple kernels from the same
     42  * Script object will result in an {@link android.renderscript.RSInvalidStateException}.
     43  *
     44  **/
     45 public final class ScriptGroup extends BaseObj {
     46     IO mOutputs[];
     47     IO mInputs[];
     48 
     49     static class IO {
     50         Script.KernelID mKID;
     51         Allocation mAllocation;
     52 
     53         IO(Script.KernelID s) {
     54             mKID = s;
     55         }
     56     }
     57 
     58     static class ConnectLine {
     59         ConnectLine(Type t, Script.KernelID from, Script.KernelID to) {
     60             mFrom = from;
     61             mToK = to;
     62             mAllocationType = t;
     63         }
     64 
     65         ConnectLine(Type t, Script.KernelID from, Script.FieldID to) {
     66             mFrom = from;
     67             mToF = to;
     68             mAllocationType = t;
     69         }
     70 
     71         Script.FieldID mToF;
     72         Script.KernelID mToK;
     73         Script.KernelID mFrom;
     74         Type mAllocationType;
     75     }
     76 
     77     static class Node {
     78         Script mScript;
     79         ArrayList<Script.KernelID> mKernels = new ArrayList<Script.KernelID>();
     80         ArrayList<ConnectLine> mInputs = new ArrayList<ConnectLine>();
     81         ArrayList<ConnectLine> mOutputs = new ArrayList<ConnectLine>();
     82         int dagNumber;
     83 
     84         Node mNext;
     85 
     86         Node(Script s) {
     87             mScript = s;
     88         }
     89     }
     90 
     91 
     92     ScriptGroup(int id, RenderScript rs) {
     93         super(id, rs);
     94     }
     95 
     96     /**
     97      * Sets an input of the ScriptGroup. This specifies an
     98      * Allocation to be used for kernels that require an input
     99      * Allocation provided from outside of the ScriptGroup.
    100      *
    101      * @param s The ID of the kernel where the allocation should be
    102      *          connected.
    103      * @param a The allocation to connect.
    104      */
    105     public void setInput(Script.KernelID s, Allocation a) {
    106         for (int ct=0; ct < mInputs.length; ct++) {
    107             if (mInputs[ct].mKID == s) {
    108                 mInputs[ct].mAllocation = a;
    109                 mRS.nScriptGroupSetInput(getID(mRS), s.getID(mRS), mRS.safeID(a));
    110                 return;
    111             }
    112         }
    113         throw new RSIllegalArgumentException("Script not found");
    114     }
    115 
    116     /**
    117      * Sets an output of the ScriptGroup. This specifies an
    118      * Allocation to be used for the kernels that require an output
    119      * Allocation visible after the ScriptGroup is executed.
    120      *
    121      * @param s The ID of the kernel where the allocation should be
    122      *          connected.
    123      * @param a The allocation to connect.
    124      */
    125     public void setOutput(Script.KernelID s, Allocation a) {
    126         for (int ct=0; ct < mOutputs.length; ct++) {
    127             if (mOutputs[ct].mKID == s) {
    128                 mOutputs[ct].mAllocation = a;
    129                 mRS.nScriptGroupSetOutput(getID(mRS), s.getID(mRS), mRS.safeID(a));
    130                 return;
    131             }
    132         }
    133         throw new RSIllegalArgumentException("Script not found");
    134     }
    135 
    136     /**
    137      * Execute the ScriptGroup.  This will run all the kernels in
    138      * the ScriptGroup.  No internal connection results will be visible
    139      * after execution of the ScriptGroup.
    140      */
    141     public void execute() {
    142         mRS.nScriptGroupExecute(getID(mRS));
    143     }
    144 
    145 
    146     /**
    147      * Helper class to build a ScriptGroup. A ScriptGroup is
    148      * created in two steps.
    149      * <p>
    150      * First, all kernels to be used by the ScriptGroup should be added.
    151      * <p>
    152      * Second, add connections between kernels. There are two types
    153      * of connections: kernel to kernel and kernel to field.
    154      * Kernel to kernel allows a kernel's output to be passed to
    155      * another kernel as input. Kernel to field allows the output of
    156      * one kernel to be bound as a script global. Kernel to kernel is
    157      * higher performance and should be used where possible.
    158      * <p>
    159      * A ScriptGroup must contain a single directed acyclic graph (DAG); it
    160      * cannot contain cycles. Currently, all kernels used in a ScriptGroup
    161      * must come from different Script objects. Additionally, all kernels
    162      * in a ScriptGroup must have at least one input, output, or internal
    163      * connection.
    164      * <p>
    165      * Once all connections are made, a call to {@link #create} will
    166      * return the ScriptGroup object.
    167      *
    168      */
    169     public static final class Builder {
    170         private RenderScript mRS;
    171         private ArrayList<Node> mNodes = new ArrayList<Node>();
    172         private ArrayList<ConnectLine> mLines = new ArrayList<ConnectLine>();
    173         private int mKernelCount;
    174 
    175         /**
    176          * Create a Builder for generating a ScriptGroup.
    177          *
    178          *
    179          * @param rs The RenderScript context.
    180          */
    181         public Builder(RenderScript rs) {
    182             mRS = rs;
    183         }
    184 
    185         // do a DFS from original node, looking for original node
    186         // any cycle that could be created must contain original node
    187         private void validateCycle(Node target, Node original) {
    188             for (int ct = 0; ct < target.mOutputs.size(); ct++) {
    189                 final ConnectLine cl = target.mOutputs.get(ct);
    190                 if (cl.mToK != null) {
    191                     Node tn = findNode(cl.mToK.mScript);
    192                     if (tn.equals(original)) {
    193                         throw new RSInvalidStateException("Loops in group not allowed.");
    194                     }
    195                     validateCycle(tn, original);
    196                 }
    197                 if (cl.mToF != null) {
    198                     Node tn = findNode(cl.mToF.mScript);
    199                     if (tn.equals(original)) {
    200                         throw new RSInvalidStateException("Loops in group not allowed.");
    201                     }
    202                     validateCycle(tn, original);
    203                 }
    204             }
    205         }
    206 
    207         private void mergeDAGs(int valueUsed, int valueKilled) {
    208             for (int ct=0; ct < mNodes.size(); ct++) {
    209                 if (mNodes.get(ct).dagNumber == valueKilled)
    210                     mNodes.get(ct).dagNumber = valueUsed;
    211             }
    212         }
    213 
    214         private void validateDAGRecurse(Node n, int dagNumber) {
    215             // combine DAGs if this node has been seen already
    216             if (n.dagNumber != 0 && n.dagNumber != dagNumber) {
    217                 mergeDAGs(n.dagNumber, dagNumber);
    218                 return;
    219             }
    220 
    221             n.dagNumber = dagNumber;
    222             for (int ct=0; ct < n.mOutputs.size(); ct++) {
    223                 final ConnectLine cl = n.mOutputs.get(ct);
    224                 if (cl.mToK != null) {
    225                     Node tn = findNode(cl.mToK.mScript);
    226                     validateDAGRecurse(tn, dagNumber);
    227                 }
    228                 if (cl.mToF != null) {
    229                     Node tn = findNode(cl.mToF.mScript);
    230                     validateDAGRecurse(tn, dagNumber);
    231                 }
    232             }
    233         }
    234 
    235         private void validateDAG() {
    236             for (int ct=0; ct < mNodes.size(); ct++) {
    237                 Node n = mNodes.get(ct);
    238                 if (n.mInputs.size() == 0) {
    239                     if (n.mOutputs.size() == 0 && mNodes.size() > 1) {
    240                         throw new RSInvalidStateException("Groups cannot contain unconnected scripts");
    241                     }
    242                     validateDAGRecurse(n, ct+1);
    243                 }
    244             }
    245             int dagNumber = mNodes.get(0).dagNumber;
    246             for (int ct=0; ct < mNodes.size(); ct++) {
    247                 if (mNodes.get(ct).dagNumber != dagNumber) {
    248                     throw new RSInvalidStateException("Multiple DAGs in group not allowed.");
    249                 }
    250             }
    251         }
    252 
    253         private Node findNode(Script s) {
    254             for (int ct=0; ct < mNodes.size(); ct++) {
    255                 if (s == mNodes.get(ct).mScript) {
    256                     return mNodes.get(ct);
    257                 }
    258             }
    259             return null;
    260         }
    261 
    262         private Node findNode(Script.KernelID k) {
    263             for (int ct=0; ct < mNodes.size(); ct++) {
    264                 Node n = mNodes.get(ct);
    265                 for (int ct2=0; ct2 < n.mKernels.size(); ct2++) {
    266                     if (k == n.mKernels.get(ct2)) {
    267                         return n;
    268                     }
    269                 }
    270             }
    271             return null;
    272         }
    273 
    274         /**
    275          * Adds a Kernel to the group.
    276          *
    277          *
    278          * @param k The kernel to add.
    279          *
    280          * @return Builder Returns this.
    281          */
    282         public Builder addKernel(Script.KernelID k) {
    283             if (mLines.size() != 0) {
    284                 throw new RSInvalidStateException(
    285                     "Kernels may not be added once connections exist.");
    286             }
    287 
    288             //android.util.Log.v("RSR", "addKernel 1 k=" + k);
    289             if (findNode(k) != null) {
    290                 return this;
    291             }
    292             //android.util.Log.v("RSR", "addKernel 2 ");
    293             mKernelCount++;
    294             Node n = findNode(k.mScript);
    295             if (n == null) {
    296                 //android.util.Log.v("RSR", "addKernel 3 ");
    297                 n = new Node(k.mScript);
    298                 mNodes.add(n);
    299             }
    300             n.mKernels.add(k);
    301             return this;
    302         }
    303 
    304         /**
    305          * Adds a connection to the group.
    306          *
    307          *
    308          * @param t The type of the connection. This is used to
    309          *          determine the kernel launch sizes on the source side
    310          *          of this connection.
    311          * @param from The source for the connection.
    312          * @param to The destination of the connection.
    313          *
    314          * @return Builder Returns this
    315          */
    316         public Builder addConnection(Type t, Script.KernelID from, Script.FieldID to) {
    317             //android.util.Log.v("RSR", "addConnection " + t +", " + from + ", " + to);
    318 
    319             Node nf = findNode(from);
    320             if (nf == null) {
    321                 throw new RSInvalidStateException("From script not found.");
    322             }
    323 
    324             Node nt = findNode(to.mScript);
    325             if (nt == null) {
    326                 throw new RSInvalidStateException("To script not found.");
    327             }
    328 
    329             ConnectLine cl = new ConnectLine(t, from, to);
    330             mLines.add(new ConnectLine(t, from, to));
    331 
    332             nf.mOutputs.add(cl);
    333             nt.mInputs.add(cl);
    334 
    335             validateCycle(nf, nf);
    336             return this;
    337         }
    338 
    339         /**
    340          * Adds a connection to the group.
    341          *
    342          *
    343          * @param t The type of the connection. This is used to
    344          *          determine the kernel launch sizes for both sides of
    345          *          this connection.
    346          * @param from The source for the connection.
    347          * @param to The destination of the connection.
    348          *
    349          * @return Builder Returns this
    350          */
    351         public Builder addConnection(Type t, Script.KernelID from, Script.KernelID to) {
    352             //android.util.Log.v("RSR", "addConnection " + t +", " + from + ", " + to);
    353 
    354             Node nf = findNode(from);
    355             if (nf == null) {
    356                 throw new RSInvalidStateException("From script not found.");
    357             }
    358 
    359             Node nt = findNode(to);
    360             if (nt == null) {
    361                 throw new RSInvalidStateException("To script not found.");
    362             }
    363 
    364             ConnectLine cl = new ConnectLine(t, from, to);
    365             mLines.add(new ConnectLine(t, from, to));
    366 
    367             nf.mOutputs.add(cl);
    368             nt.mInputs.add(cl);
    369 
    370             validateCycle(nf, nf);
    371             return this;
    372         }
    373 
    374 
    375 
    376         /**
    377          * Creates the Script group.
    378          *
    379          *
    380          * @return ScriptGroup The new ScriptGroup
    381          */
    382         public ScriptGroup create() {
    383 
    384             if (mNodes.size() == 0) {
    385                 throw new RSInvalidStateException("Empty script groups are not allowed");
    386             }
    387 
    388             // reset DAG numbers in case we're building a second group
    389             for (int ct=0; ct < mNodes.size(); ct++) {
    390                 mNodes.get(ct).dagNumber = 0;
    391             }
    392             validateDAG();
    393 
    394             ArrayList<IO> inputs = new ArrayList<IO>();
    395             ArrayList<IO> outputs = new ArrayList<IO>();
    396 
    397             int[] kernels = new int[mKernelCount];
    398             int idx = 0;
    399             for (int ct=0; ct < mNodes.size(); ct++) {
    400                 Node n = mNodes.get(ct);
    401                 for (int ct2=0; ct2 < n.mKernels.size(); ct2++) {
    402                     final Script.KernelID kid = n.mKernels.get(ct2);
    403                     kernels[idx++] = kid.getID(mRS);
    404 
    405                     boolean hasInput = false;
    406                     boolean hasOutput = false;
    407                     for (int ct3=0; ct3 < n.mInputs.size(); ct3++) {
    408                         if (n.mInputs.get(ct3).mToK == kid) {
    409                             hasInput = true;
    410                         }
    411                     }
    412                     for (int ct3=0; ct3 < n.mOutputs.size(); ct3++) {
    413                         if (n.mOutputs.get(ct3).mFrom == kid) {
    414                             hasOutput = true;
    415                         }
    416                     }
    417                     if (!hasInput) {
    418                         inputs.add(new IO(kid));
    419                     }
    420                     if (!hasOutput) {
    421                         outputs.add(new IO(kid));
    422                     }
    423 
    424                 }
    425             }
    426             if (idx != mKernelCount) {
    427                 throw new RSRuntimeException("Count mismatch, should not happen.");
    428             }
    429 
    430             int[] src = new int[mLines.size()];
    431             int[] dstk = new int[mLines.size()];
    432             int[] dstf = new int[mLines.size()];
    433             int[] types = new int[mLines.size()];
    434 
    435             for (int ct=0; ct < mLines.size(); ct++) {
    436                 ConnectLine cl = mLines.get(ct);
    437                 src[ct] = cl.mFrom.getID(mRS);
    438                 if (cl.mToK != null) {
    439                     dstk[ct] = cl.mToK.getID(mRS);
    440                 }
    441                 if (cl.mToF != null) {
    442                     dstf[ct] = cl.mToF.getID(mRS);
    443                 }
    444                 types[ct] = cl.mAllocationType.getID(mRS);
    445             }
    446 
    447             int id = mRS.nScriptGroupCreate(kernels, src, dstk, dstf, types);
    448             if (id == 0) {
    449                 throw new RSRuntimeException("Object creation error, should not happen.");
    450             }
    451 
    452             ScriptGroup sg = new ScriptGroup(id, mRS);
    453             sg.mOutputs = new IO[outputs.size()];
    454             for (int ct=0; ct < outputs.size(); ct++) {
    455                 sg.mOutputs[ct] = outputs.get(ct);
    456             }
    457 
    458             sg.mInputs = new IO[inputs.size()];
    459             for (int ct=0; ct < inputs.size(); ct++) {
    460                 sg.mInputs[ct] = inputs.get(ct);
    461             }
    462 
    463             return sg;
    464         }
    465 
    466     }
    467 
    468 
    469 }
    470 
    471 
    472