1 /* 2 * Copyright (C) 2012 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package android.renderscript; 18 19 import java.lang.reflect.Method; 20 import java.util.ArrayList; 21 22 /** 23 * ScriptGroup creates a group of kernels that are executed 24 * together with one execution call as if they were a single kernel. 25 * The kernels may be connected internally or to an external allocation. 26 * The intermediate results for internal connections are not observable 27 * after the execution of the script. 28 * <p> 29 * External connections are grouped into inputs and outputs. 30 * All outputs are produced by a script kernel and placed into a 31 * user-supplied allocation. Inputs provide the input of a kernel. 32 * Inputs bound to script globals are set directly upon the script. 33 * <p> 34 * A ScriptGroup must contain at least one kernel. A ScriptGroup 35 * must contain only a single directed acyclic graph (DAG) of 36 * script kernels and connections. Attempting to create a 37 * ScriptGroup with multiple DAGs or attempting to create 38 * a cycle within a ScriptGroup will throw an exception. 39 * <p> 40 * Currently, all kernels in a ScriptGroup must be from separate 41 * Script objects. Attempting to use multiple kernels from the same 42 * Script object will result in an {@link android.renderscript.RSInvalidStateException}. 43 * 44 **/ 45 public final class ScriptGroup extends BaseObj { 46 IO mOutputs[]; 47 IO mInputs[]; 48 49 static class IO { 50 Script.KernelID mKID; 51 Allocation mAllocation; 52 53 IO(Script.KernelID s) { 54 mKID = s; 55 } 56 } 57 58 static class ConnectLine { 59 ConnectLine(Type t, Script.KernelID from, Script.KernelID to) { 60 mFrom = from; 61 mToK = to; 62 mAllocationType = t; 63 } 64 65 ConnectLine(Type t, Script.KernelID from, Script.FieldID to) { 66 mFrom = from; 67 mToF = to; 68 mAllocationType = t; 69 } 70 71 Script.FieldID mToF; 72 Script.KernelID mToK; 73 Script.KernelID mFrom; 74 Type mAllocationType; 75 } 76 77 static class Node { 78 Script mScript; 79 ArrayList<Script.KernelID> mKernels = new ArrayList<Script.KernelID>(); 80 ArrayList<ConnectLine> mInputs = new ArrayList<ConnectLine>(); 81 ArrayList<ConnectLine> mOutputs = new ArrayList<ConnectLine>(); 82 int dagNumber; 83 84 Node mNext; 85 86 Node(Script s) { 87 mScript = s; 88 } 89 } 90 91 92 ScriptGroup(int id, RenderScript rs) { 93 super(id, rs); 94 } 95 96 /** 97 * Sets an input of the ScriptGroup. This specifies an 98 * Allocation to be used for kernels that require an input 99 * Allocation provided from outside of the ScriptGroup. 100 * 101 * @param s The ID of the kernel where the allocation should be 102 * connected. 103 * @param a The allocation to connect. 104 */ 105 public void setInput(Script.KernelID s, Allocation a) { 106 for (int ct=0; ct < mInputs.length; ct++) { 107 if (mInputs[ct].mKID == s) { 108 mInputs[ct].mAllocation = a; 109 mRS.nScriptGroupSetInput(getID(mRS), s.getID(mRS), mRS.safeID(a)); 110 return; 111 } 112 } 113 throw new RSIllegalArgumentException("Script not found"); 114 } 115 116 /** 117 * Sets an output of the ScriptGroup. This specifies an 118 * Allocation to be used for the kernels that require an output 119 * Allocation visible after the ScriptGroup is executed. 120 * 121 * @param s The ID of the kernel where the allocation should be 122 * connected. 123 * @param a The allocation to connect. 124 */ 125 public void setOutput(Script.KernelID s, Allocation a) { 126 for (int ct=0; ct < mOutputs.length; ct++) { 127 if (mOutputs[ct].mKID == s) { 128 mOutputs[ct].mAllocation = a; 129 mRS.nScriptGroupSetOutput(getID(mRS), s.getID(mRS), mRS.safeID(a)); 130 return; 131 } 132 } 133 throw new RSIllegalArgumentException("Script not found"); 134 } 135 136 /** 137 * Execute the ScriptGroup. This will run all the kernels in 138 * the ScriptGroup. No internal connection results will be visible 139 * after execution of the ScriptGroup. 140 */ 141 public void execute() { 142 mRS.nScriptGroupExecute(getID(mRS)); 143 } 144 145 146 /** 147 * Helper class to build a ScriptGroup. A ScriptGroup is 148 * created in two steps. 149 * <p> 150 * First, all kernels to be used by the ScriptGroup should be added. 151 * <p> 152 * Second, add connections between kernels. There are two types 153 * of connections: kernel to kernel and kernel to field. 154 * Kernel to kernel allows a kernel's output to be passed to 155 * another kernel as input. Kernel to field allows the output of 156 * one kernel to be bound as a script global. Kernel to kernel is 157 * higher performance and should be used where possible. 158 * <p> 159 * A ScriptGroup must contain a single directed acyclic graph (DAG); it 160 * cannot contain cycles. Currently, all kernels used in a ScriptGroup 161 * must come from different Script objects. Additionally, all kernels 162 * in a ScriptGroup must have at least one input, output, or internal 163 * connection. 164 * <p> 165 * Once all connections are made, a call to {@link #create} will 166 * return the ScriptGroup object. 167 * 168 */ 169 public static final class Builder { 170 private RenderScript mRS; 171 private ArrayList<Node> mNodes = new ArrayList<Node>(); 172 private ArrayList<ConnectLine> mLines = new ArrayList<ConnectLine>(); 173 private int mKernelCount; 174 175 /** 176 * Create a Builder for generating a ScriptGroup. 177 * 178 * 179 * @param rs The RenderScript context. 180 */ 181 public Builder(RenderScript rs) { 182 mRS = rs; 183 } 184 185 // do a DFS from original node, looking for original node 186 // any cycle that could be created must contain original node 187 private void validateCycle(Node target, Node original) { 188 for (int ct = 0; ct < target.mOutputs.size(); ct++) { 189 final ConnectLine cl = target.mOutputs.get(ct); 190 if (cl.mToK != null) { 191 Node tn = findNode(cl.mToK.mScript); 192 if (tn.equals(original)) { 193 throw new RSInvalidStateException("Loops in group not allowed."); 194 } 195 validateCycle(tn, original); 196 } 197 if (cl.mToF != null) { 198 Node tn = findNode(cl.mToF.mScript); 199 if (tn.equals(original)) { 200 throw new RSInvalidStateException("Loops in group not allowed."); 201 } 202 validateCycle(tn, original); 203 } 204 } 205 } 206 207 private void mergeDAGs(int valueUsed, int valueKilled) { 208 for (int ct=0; ct < mNodes.size(); ct++) { 209 if (mNodes.get(ct).dagNumber == valueKilled) 210 mNodes.get(ct).dagNumber = valueUsed; 211 } 212 } 213 214 private void validateDAGRecurse(Node n, int dagNumber) { 215 // combine DAGs if this node has been seen already 216 if (n.dagNumber != 0 && n.dagNumber != dagNumber) { 217 mergeDAGs(n.dagNumber, dagNumber); 218 return; 219 } 220 221 n.dagNumber = dagNumber; 222 for (int ct=0; ct < n.mOutputs.size(); ct++) { 223 final ConnectLine cl = n.mOutputs.get(ct); 224 if (cl.mToK != null) { 225 Node tn = findNode(cl.mToK.mScript); 226 validateDAGRecurse(tn, dagNumber); 227 } 228 if (cl.mToF != null) { 229 Node tn = findNode(cl.mToF.mScript); 230 validateDAGRecurse(tn, dagNumber); 231 } 232 } 233 } 234 235 private void validateDAG() { 236 for (int ct=0; ct < mNodes.size(); ct++) { 237 Node n = mNodes.get(ct); 238 if (n.mInputs.size() == 0) { 239 if (n.mOutputs.size() == 0 && mNodes.size() > 1) { 240 throw new RSInvalidStateException("Groups cannot contain unconnected scripts"); 241 } 242 validateDAGRecurse(n, ct+1); 243 } 244 } 245 int dagNumber = mNodes.get(0).dagNumber; 246 for (int ct=0; ct < mNodes.size(); ct++) { 247 if (mNodes.get(ct).dagNumber != dagNumber) { 248 throw new RSInvalidStateException("Multiple DAGs in group not allowed."); 249 } 250 } 251 } 252 253 private Node findNode(Script s) { 254 for (int ct=0; ct < mNodes.size(); ct++) { 255 if (s == mNodes.get(ct).mScript) { 256 return mNodes.get(ct); 257 } 258 } 259 return null; 260 } 261 262 private Node findNode(Script.KernelID k) { 263 for (int ct=0; ct < mNodes.size(); ct++) { 264 Node n = mNodes.get(ct); 265 for (int ct2=0; ct2 < n.mKernels.size(); ct2++) { 266 if (k == n.mKernels.get(ct2)) { 267 return n; 268 } 269 } 270 } 271 return null; 272 } 273 274 /** 275 * Adds a Kernel to the group. 276 * 277 * 278 * @param k The kernel to add. 279 * 280 * @return Builder Returns this. 281 */ 282 public Builder addKernel(Script.KernelID k) { 283 if (mLines.size() != 0) { 284 throw new RSInvalidStateException( 285 "Kernels may not be added once connections exist."); 286 } 287 288 //android.util.Log.v("RSR", "addKernel 1 k=" + k); 289 if (findNode(k) != null) { 290 return this; 291 } 292 //android.util.Log.v("RSR", "addKernel 2 "); 293 mKernelCount++; 294 Node n = findNode(k.mScript); 295 if (n == null) { 296 //android.util.Log.v("RSR", "addKernel 3 "); 297 n = new Node(k.mScript); 298 mNodes.add(n); 299 } 300 n.mKernels.add(k); 301 return this; 302 } 303 304 /** 305 * Adds a connection to the group. 306 * 307 * 308 * @param t The type of the connection. This is used to 309 * determine the kernel launch sizes on the source side 310 * of this connection. 311 * @param from The source for the connection. 312 * @param to The destination of the connection. 313 * 314 * @return Builder Returns this 315 */ 316 public Builder addConnection(Type t, Script.KernelID from, Script.FieldID to) { 317 //android.util.Log.v("RSR", "addConnection " + t +", " + from + ", " + to); 318 319 Node nf = findNode(from); 320 if (nf == null) { 321 throw new RSInvalidStateException("From script not found."); 322 } 323 324 Node nt = findNode(to.mScript); 325 if (nt == null) { 326 throw new RSInvalidStateException("To script not found."); 327 } 328 329 ConnectLine cl = new ConnectLine(t, from, to); 330 mLines.add(new ConnectLine(t, from, to)); 331 332 nf.mOutputs.add(cl); 333 nt.mInputs.add(cl); 334 335 validateCycle(nf, nf); 336 return this; 337 } 338 339 /** 340 * Adds a connection to the group. 341 * 342 * 343 * @param t The type of the connection. This is used to 344 * determine the kernel launch sizes for both sides of 345 * this connection. 346 * @param from The source for the connection. 347 * @param to The destination of the connection. 348 * 349 * @return Builder Returns this 350 */ 351 public Builder addConnection(Type t, Script.KernelID from, Script.KernelID to) { 352 //android.util.Log.v("RSR", "addConnection " + t +", " + from + ", " + to); 353 354 Node nf = findNode(from); 355 if (nf == null) { 356 throw new RSInvalidStateException("From script not found."); 357 } 358 359 Node nt = findNode(to); 360 if (nt == null) { 361 throw new RSInvalidStateException("To script not found."); 362 } 363 364 ConnectLine cl = new ConnectLine(t, from, to); 365 mLines.add(new ConnectLine(t, from, to)); 366 367 nf.mOutputs.add(cl); 368 nt.mInputs.add(cl); 369 370 validateCycle(nf, nf); 371 return this; 372 } 373 374 375 376 /** 377 * Creates the Script group. 378 * 379 * 380 * @return ScriptGroup The new ScriptGroup 381 */ 382 public ScriptGroup create() { 383 384 if (mNodes.size() == 0) { 385 throw new RSInvalidStateException("Empty script groups are not allowed"); 386 } 387 388 // reset DAG numbers in case we're building a second group 389 for (int ct=0; ct < mNodes.size(); ct++) { 390 mNodes.get(ct).dagNumber = 0; 391 } 392 validateDAG(); 393 394 ArrayList<IO> inputs = new ArrayList<IO>(); 395 ArrayList<IO> outputs = new ArrayList<IO>(); 396 397 int[] kernels = new int[mKernelCount]; 398 int idx = 0; 399 for (int ct=0; ct < mNodes.size(); ct++) { 400 Node n = mNodes.get(ct); 401 for (int ct2=0; ct2 < n.mKernels.size(); ct2++) { 402 final Script.KernelID kid = n.mKernels.get(ct2); 403 kernels[idx++] = kid.getID(mRS); 404 405 boolean hasInput = false; 406 boolean hasOutput = false; 407 for (int ct3=0; ct3 < n.mInputs.size(); ct3++) { 408 if (n.mInputs.get(ct3).mToK == kid) { 409 hasInput = true; 410 } 411 } 412 for (int ct3=0; ct3 < n.mOutputs.size(); ct3++) { 413 if (n.mOutputs.get(ct3).mFrom == kid) { 414 hasOutput = true; 415 } 416 } 417 if (!hasInput) { 418 inputs.add(new IO(kid)); 419 } 420 if (!hasOutput) { 421 outputs.add(new IO(kid)); 422 } 423 424 } 425 } 426 if (idx != mKernelCount) { 427 throw new RSRuntimeException("Count mismatch, should not happen."); 428 } 429 430 int[] src = new int[mLines.size()]; 431 int[] dstk = new int[mLines.size()]; 432 int[] dstf = new int[mLines.size()]; 433 int[] types = new int[mLines.size()]; 434 435 for (int ct=0; ct < mLines.size(); ct++) { 436 ConnectLine cl = mLines.get(ct); 437 src[ct] = cl.mFrom.getID(mRS); 438 if (cl.mToK != null) { 439 dstk[ct] = cl.mToK.getID(mRS); 440 } 441 if (cl.mToF != null) { 442 dstf[ct] = cl.mToF.getID(mRS); 443 } 444 types[ct] = cl.mAllocationType.getID(mRS); 445 } 446 447 int id = mRS.nScriptGroupCreate(kernels, src, dstk, dstf, types); 448 if (id == 0) { 449 throw new RSRuntimeException("Object creation error, should not happen."); 450 } 451 452 ScriptGroup sg = new ScriptGroup(id, mRS); 453 sg.mOutputs = new IO[outputs.size()]; 454 for (int ct=0; ct < outputs.size(); ct++) { 455 sg.mOutputs[ct] = outputs.get(ct); 456 } 457 458 sg.mInputs = new IO[inputs.size()]; 459 for (int ct=0; ct < inputs.size(); ct++) { 460 sg.mInputs[ct] = inputs.get(ct); 461 } 462 463 return sg; 464 } 465 466 } 467 468 469 } 470 471 472