Home | History | Annotate | Download | only in renderscript
      1 /*
      2  * Copyright (C) 2012 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 package android.renderscript;
     18 
     19 import java.util.ArrayList;
     20 
     21 /**
     22  * ScriptGroup creates a group of kernels that are executed
     23  * together with one execution call as if they were a single kernel.
     24  * The kernels may be connected internally or to an external allocation.
     25  * The intermediate results for internal connections are not observable
     26  * after the execution of the script.
     27  * <p>
     28  * External connections are grouped into inputs and outputs.
     29  * All outputs are produced by a script kernel and placed into a
     30  * user-supplied allocation. Inputs provide the input of a kernel.
     31  * Inputs bound to script globals are set directly upon the script.
     32  * <p>
     33  * A ScriptGroup must contain at least one kernel. A ScriptGroup
     34  * must contain only a single directed acyclic graph (DAG) of
     35  * script kernels and connections. Attempting to create a
     36  * ScriptGroup with multiple DAGs or attempting to create
     37  * a cycle within a ScriptGroup will throw an exception.
     38  * <p>
     39  * Currently, all kernels in a ScriptGroup must be from separate
     40  * Script objects. Attempting to use multiple kernels from the same
     41  * Script object will result in an {@link android.renderscript.RSInvalidStateException}.
     42  *
     43  **/
     44 public final class ScriptGroup extends BaseObj {
     45     IO mOutputs[];
     46     IO mInputs[];
     47 
     48     static class IO {
     49         Script.KernelID mKID;
     50         Allocation mAllocation;
     51 
     52         IO(Script.KernelID s) {
     53             mKID = s;
     54         }
     55     }
     56 
     57     static class ConnectLine {
     58         ConnectLine(Type t, Script.KernelID from, Script.KernelID to) {
     59             mFrom = from;
     60             mToK = to;
     61             mAllocationType = t;
     62         }
     63 
     64         ConnectLine(Type t, Script.KernelID from, Script.FieldID to) {
     65             mFrom = from;
     66             mToF = to;
     67             mAllocationType = t;
     68         }
     69 
     70         Script.FieldID mToF;
     71         Script.KernelID mToK;
     72         Script.KernelID mFrom;
     73         Type mAllocationType;
     74     }
     75 
     76     static class Node {
     77         Script mScript;
     78         ArrayList<Script.KernelID> mKernels = new ArrayList<Script.KernelID>();
     79         ArrayList<ConnectLine> mInputs = new ArrayList<ConnectLine>();
     80         ArrayList<ConnectLine> mOutputs = new ArrayList<ConnectLine>();
     81         int dagNumber;
     82 
     83         Node mNext;
     84 
     85         Node(Script s) {
     86             mScript = s;
     87         }
     88     }
     89 
     90 
     91     ScriptGroup(long id, RenderScript rs) {
     92         super(id, rs);
     93     }
     94 
     95     /**
     96      * Sets an input of the ScriptGroup. This specifies an
     97      * Allocation to be used for kernels that require an input
     98      * Allocation provided from outside of the ScriptGroup.
     99      *
    100      * @param s The ID of the kernel where the allocation should be
    101      *          connected.
    102      * @param a The allocation to connect.
    103      */
    104     public void setInput(Script.KernelID s, Allocation a) {
    105         for (int ct=0; ct < mInputs.length; ct++) {
    106             if (mInputs[ct].mKID == s) {
    107                 mInputs[ct].mAllocation = a;
    108                 mRS.nScriptGroupSetInput(getID(mRS), s.getID(mRS), mRS.safeID(a));
    109                 return;
    110             }
    111         }
    112         throw new RSIllegalArgumentException("Script not found");
    113     }
    114 
    115     /**
    116      * Sets an output of the ScriptGroup. This specifies an
    117      * Allocation to be used for the kernels that require an output
    118      * Allocation visible after the ScriptGroup is executed.
    119      *
    120      * @param s The ID of the kernel where the allocation should be
    121      *          connected.
    122      * @param a The allocation to connect.
    123      */
    124     public void setOutput(Script.KernelID s, Allocation a) {
    125         for (int ct=0; ct < mOutputs.length; ct++) {
    126             if (mOutputs[ct].mKID == s) {
    127                 mOutputs[ct].mAllocation = a;
    128                 mRS.nScriptGroupSetOutput(getID(mRS), s.getID(mRS), mRS.safeID(a));
    129                 return;
    130             }
    131         }
    132         throw new RSIllegalArgumentException("Script not found");
    133     }
    134 
    135     /**
    136      * Execute the ScriptGroup.  This will run all the kernels in
    137      * the ScriptGroup.  No internal connection results will be visible
    138      * after execution of the ScriptGroup.
    139      */
    140     public void execute() {
    141         mRS.nScriptGroupExecute(getID(mRS));
    142     }
    143 
    144 
    145     /**
    146      * Helper class to build a ScriptGroup. A ScriptGroup is
    147      * created in two steps.
    148      * <p>
    149      * First, all kernels to be used by the ScriptGroup should be added.
    150      * <p>
    151      * Second, add connections between kernels. There are two types
    152      * of connections: kernel to kernel and kernel to field.
    153      * Kernel to kernel allows a kernel's output to be passed to
    154      * another kernel as input. Kernel to field allows the output of
    155      * one kernel to be bound as a script global. Kernel to kernel is
    156      * higher performance and should be used where possible.
    157      * <p>
    158      * A ScriptGroup must contain a single directed acyclic graph (DAG); it
    159      * cannot contain cycles. Currently, all kernels used in a ScriptGroup
    160      * must come from different Script objects. Additionally, all kernels
    161      * in a ScriptGroup must have at least one input, output, or internal
    162      * connection.
    163      * <p>
    164      * Once all connections are made, a call to {@link #create} will
    165      * return the ScriptGroup object.
    166      *
    167      */
    168     public static final class Builder {
    169         private RenderScript mRS;
    170         private ArrayList<Node> mNodes = new ArrayList<Node>();
    171         private ArrayList<ConnectLine> mLines = new ArrayList<ConnectLine>();
    172         private int mKernelCount;
    173 
    174         /**
    175          * Create a Builder for generating a ScriptGroup.
    176          *
    177          *
    178          * @param rs The RenderScript context.
    179          */
    180         public Builder(RenderScript rs) {
    181             mRS = rs;
    182         }
    183 
    184         // do a DFS from original node, looking for original node
    185         // any cycle that could be created must contain original node
    186         private void validateCycle(Node target, Node original) {
    187             for (int ct = 0; ct < target.mOutputs.size(); ct++) {
    188                 final ConnectLine cl = target.mOutputs.get(ct);
    189                 if (cl.mToK != null) {
    190                     Node tn = findNode(cl.mToK.mScript);
    191                     if (tn.equals(original)) {
    192                         throw new RSInvalidStateException("Loops in group not allowed.");
    193                     }
    194                     validateCycle(tn, original);
    195                 }
    196                 if (cl.mToF != null) {
    197                     Node tn = findNode(cl.mToF.mScript);
    198                     if (tn.equals(original)) {
    199                         throw new RSInvalidStateException("Loops in group not allowed.");
    200                     }
    201                     validateCycle(tn, original);
    202                 }
    203             }
    204         }
    205 
    206         private void mergeDAGs(int valueUsed, int valueKilled) {
    207             for (int ct=0; ct < mNodes.size(); ct++) {
    208                 if (mNodes.get(ct).dagNumber == valueKilled)
    209                     mNodes.get(ct).dagNumber = valueUsed;
    210             }
    211         }
    212 
    213         private void validateDAGRecurse(Node n, int dagNumber) {
    214             // combine DAGs if this node has been seen already
    215             if (n.dagNumber != 0 && n.dagNumber != dagNumber) {
    216                 mergeDAGs(n.dagNumber, dagNumber);
    217                 return;
    218             }
    219 
    220             n.dagNumber = dagNumber;
    221             for (int ct=0; ct < n.mOutputs.size(); ct++) {
    222                 final ConnectLine cl = n.mOutputs.get(ct);
    223                 if (cl.mToK != null) {
    224                     Node tn = findNode(cl.mToK.mScript);
    225                     validateDAGRecurse(tn, dagNumber);
    226                 }
    227                 if (cl.mToF != null) {
    228                     Node tn = findNode(cl.mToF.mScript);
    229                     validateDAGRecurse(tn, dagNumber);
    230                 }
    231             }
    232         }
    233 
    234         private void validateDAG() {
    235             for (int ct=0; ct < mNodes.size(); ct++) {
    236                 Node n = mNodes.get(ct);
    237                 if (n.mInputs.size() == 0) {
    238                     if (n.mOutputs.size() == 0 && mNodes.size() > 1) {
    239                         throw new RSInvalidStateException("Groups cannot contain unconnected scripts");
    240                     }
    241                     validateDAGRecurse(n, ct+1);
    242                 }
    243             }
    244             int dagNumber = mNodes.get(0).dagNumber;
    245             for (int ct=0; ct < mNodes.size(); ct++) {
    246                 if (mNodes.get(ct).dagNumber != dagNumber) {
    247                     throw new RSInvalidStateException("Multiple DAGs in group not allowed.");
    248                 }
    249             }
    250         }
    251 
    252         private Node findNode(Script s) {
    253             for (int ct=0; ct < mNodes.size(); ct++) {
    254                 if (s == mNodes.get(ct).mScript) {
    255                     return mNodes.get(ct);
    256                 }
    257             }
    258             return null;
    259         }
    260 
    261         private Node findNode(Script.KernelID k) {
    262             for (int ct=0; ct < mNodes.size(); ct++) {
    263                 Node n = mNodes.get(ct);
    264                 for (int ct2=0; ct2 < n.mKernels.size(); ct2++) {
    265                     if (k == n.mKernels.get(ct2)) {
    266                         return n;
    267                     }
    268                 }
    269             }
    270             return null;
    271         }
    272 
    273         /**
    274          * Adds a Kernel to the group.
    275          *
    276          *
    277          * @param k The kernel to add.
    278          *
    279          * @return Builder Returns this.
    280          */
    281         public Builder addKernel(Script.KernelID k) {
    282             if (mLines.size() != 0) {
    283                 throw new RSInvalidStateException(
    284                     "Kernels may not be added once connections exist.");
    285             }
    286 
    287             //android.util.Log.v("RSR", "addKernel 1 k=" + k);
    288             if (findNode(k) != null) {
    289                 return this;
    290             }
    291             //android.util.Log.v("RSR", "addKernel 2 ");
    292             mKernelCount++;
    293             Node n = findNode(k.mScript);
    294             if (n == null) {
    295                 //android.util.Log.v("RSR", "addKernel 3 ");
    296                 n = new Node(k.mScript);
    297                 mNodes.add(n);
    298             }
    299             n.mKernels.add(k);
    300             return this;
    301         }
    302 
    303         /**
    304          * Adds a connection to the group.
    305          *
    306          *
    307          * @param t The type of the connection. This is used to
    308          *          determine the kernel launch sizes on the source side
    309          *          of this connection.
    310          * @param from The source for the connection.
    311          * @param to The destination of the connection.
    312          *
    313          * @return Builder Returns this
    314          */
    315         public Builder addConnection(Type t, Script.KernelID from, Script.FieldID to) {
    316             //android.util.Log.v("RSR", "addConnection " + t +", " + from + ", " + to);
    317 
    318             Node nf = findNode(from);
    319             if (nf == null) {
    320                 throw new RSInvalidStateException("From script not found.");
    321             }
    322 
    323             Node nt = findNode(to.mScript);
    324             if (nt == null) {
    325                 throw new RSInvalidStateException("To script not found.");
    326             }
    327 
    328             ConnectLine cl = new ConnectLine(t, from, to);
    329             mLines.add(new ConnectLine(t, from, to));
    330 
    331             nf.mOutputs.add(cl);
    332             nt.mInputs.add(cl);
    333 
    334             validateCycle(nf, nf);
    335             return this;
    336         }
    337 
    338         /**
    339          * Adds a connection to the group.
    340          *
    341          *
    342          * @param t The type of the connection. This is used to
    343          *          determine the kernel launch sizes for both sides of
    344          *          this connection.
    345          * @param from The source for the connection.
    346          * @param to The destination of the connection.
    347          *
    348          * @return Builder Returns this
    349          */
    350         public Builder addConnection(Type t, Script.KernelID from, Script.KernelID to) {
    351             //android.util.Log.v("RSR", "addConnection " + t +", " + from + ", " + to);
    352 
    353             Node nf = findNode(from);
    354             if (nf == null) {
    355                 throw new RSInvalidStateException("From script not found.");
    356             }
    357 
    358             Node nt = findNode(to);
    359             if (nt == null) {
    360                 throw new RSInvalidStateException("To script not found.");
    361             }
    362 
    363             ConnectLine cl = new ConnectLine(t, from, to);
    364             mLines.add(new ConnectLine(t, from, to));
    365 
    366             nf.mOutputs.add(cl);
    367             nt.mInputs.add(cl);
    368 
    369             validateCycle(nf, nf);
    370             return this;
    371         }
    372 
    373 
    374 
    375         /**
    376          * Creates the Script group.
    377          *
    378          *
    379          * @return ScriptGroup The new ScriptGroup
    380          */
    381         public ScriptGroup create() {
    382 
    383             if (mNodes.size() == 0) {
    384                 throw new RSInvalidStateException("Empty script groups are not allowed");
    385             }
    386 
    387             // reset DAG numbers in case we're building a second group
    388             for (int ct=0; ct < mNodes.size(); ct++) {
    389                 mNodes.get(ct).dagNumber = 0;
    390             }
    391             validateDAG();
    392 
    393             ArrayList<IO> inputs = new ArrayList<IO>();
    394             ArrayList<IO> outputs = new ArrayList<IO>();
    395 
    396             long[] kernels = new long[mKernelCount];
    397             int idx = 0;
    398             for (int ct=0; ct < mNodes.size(); ct++) {
    399                 Node n = mNodes.get(ct);
    400                 for (int ct2=0; ct2 < n.mKernels.size(); ct2++) {
    401                     final Script.KernelID kid = n.mKernels.get(ct2);
    402                     kernels[idx++] = kid.getID(mRS);
    403 
    404                     boolean hasInput = false;
    405                     boolean hasOutput = false;
    406                     for (int ct3=0; ct3 < n.mInputs.size(); ct3++) {
    407                         if (n.mInputs.get(ct3).mToK == kid) {
    408                             hasInput = true;
    409                         }
    410                     }
    411                     for (int ct3=0; ct3 < n.mOutputs.size(); ct3++) {
    412                         if (n.mOutputs.get(ct3).mFrom == kid) {
    413                             hasOutput = true;
    414                         }
    415                     }
    416                     if (!hasInput) {
    417                         inputs.add(new IO(kid));
    418                     }
    419                     if (!hasOutput) {
    420                         outputs.add(new IO(kid));
    421                     }
    422 
    423                 }
    424             }
    425             if (idx != mKernelCount) {
    426                 throw new RSRuntimeException("Count mismatch, should not happen.");
    427             }
    428 
    429             long[] src = new long[mLines.size()];
    430             long[] dstk = new long[mLines.size()];
    431             long[] dstf = new long[mLines.size()];
    432             long[] types = new long[mLines.size()];
    433 
    434             for (int ct=0; ct < mLines.size(); ct++) {
    435                 ConnectLine cl = mLines.get(ct);
    436                 src[ct] = cl.mFrom.getID(mRS);
    437                 if (cl.mToK != null) {
    438                     dstk[ct] = cl.mToK.getID(mRS);
    439                 }
    440                 if (cl.mToF != null) {
    441                     dstf[ct] = cl.mToF.getID(mRS);
    442                 }
    443                 types[ct] = cl.mAllocationType.getID(mRS);
    444             }
    445 
    446             long id = mRS.nScriptGroupCreate(kernels, src, dstk, dstf, types);
    447             if (id == 0) {
    448                 throw new RSRuntimeException("Object creation error, should not happen.");
    449             }
    450 
    451             ScriptGroup sg = new ScriptGroup(id, mRS);
    452             sg.mOutputs = new IO[outputs.size()];
    453             for (int ct=0; ct < outputs.size(); ct++) {
    454                 sg.mOutputs[ct] = outputs.get(ct);
    455             }
    456 
    457             sg.mInputs = new IO[inputs.size()];
    458             for (int ct=0; ct < inputs.size(); ct++) {
    459                 sg.mInputs[ct] = inputs.get(ct);
    460             }
    461 
    462             return sg;
    463         }
    464 
    465     }
    466 
    467 
    468 }
    469 
    470 
    471