Home | History | Annotate | Download | only in rs
      1 /*
      2  * Copyright (C) 2012 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "rsContext.h"
     18 #include <time.h>
     19 
     20 using namespace android;
     21 using namespace android::renderscript;
     22 
// Constructs an empty script group, forwarding the context to ObjectBase.
// Kernels and links are filled in afterwards by ScriptGroup::create().
ScriptGroup::ScriptGroup(Context *rsc) : ObjectBase(rsc) {
}
     25 
     26 ScriptGroup::~ScriptGroup() {
     27     if (mRSC->mHal.funcs.scriptgroup.destroy) {
     28         mRSC->mHal.funcs.scriptgroup.destroy(mRSC, this);
     29     }
     30 
     31     for (size_t ct=0; ct < mLinks.size(); ct++) {
     32         delete mLinks[ct];
     33     }
     34 }
     35 
     36 ScriptGroup::IO::IO(const ScriptKernelID *kid) {
     37     mKernel = kid;
     38 }
     39 
     40 ScriptGroup::Node::Node(Script *s) {
     41     mScript = s;
     42     mSeen = false;
     43     mOrder = 0;
     44 }
     45 
     46 ScriptGroup::Node * ScriptGroup::findNode(Script *s) const {
     47     //ALOGE("find %p   %i", s, (int)mNodes.size());
     48     for (size_t ct=0; ct < mNodes.size(); ct++) {
     49         Node *n = mNodes[ct];
     50         for (size_t ct2=0; ct2 < n->mKernels.size(); ct2++) {
     51             if (n->mKernels[ct2]->mScript == s) {
     52                 return n;
     53             }
     54         }
     55     }
     56     return NULL;
     57 }
     58 
     59 bool ScriptGroup::calcOrderRecurse(Node *n, int depth) {
     60     n->mSeen = true;
     61     if (n->mOrder < depth) {
     62         n->mOrder = depth;
     63     }
     64     bool ret = true;
     65     for (size_t ct=0; ct < n->mOutputs.size(); ct++) {
     66         const Link *l = n->mOutputs[ct];
     67         Node *nt = NULL;
     68         if (l->mDstField.get()) {
     69             nt = findNode(l->mDstField->mScript);
     70         } else {
     71             nt = findNode(l->mDstKernel->mScript);
     72         }
     73         if (nt->mSeen) {
     74             return false;
     75         }
     76         ret &= calcOrderRecurse(nt, n->mOrder + 1);
     77     }
     78     return ret;
     79 }
     80 
     81 #if !defined(RS_SERVER) && !defined(RS_COMPATIBILITY_LIB)
     82 static int CompareNodeForSort(ScriptGroup::Node *const* lhs,
     83                               ScriptGroup::Node *const* rhs) {
     84     if (lhs[0]->mOrder > rhs[0]->mOrder) {
     85         return 1;
     86     }
     87     return 0;
     88 }
     89 #else
     90 class NodeCompare {
     91 public:
     92     bool operator() (const ScriptGroup::Node* lhs,
     93                      const ScriptGroup::Node* rhs) {
     94         if (lhs->mOrder > rhs->mOrder) {
     95             return true;
     96         }
     97         return false;
     98     }
     99 };
    100 #endif
    101 
    102 bool ScriptGroup::calcOrder() {
    103     // Make nodes
    104     for (size_t ct=0; ct < mKernels.size(); ct++) {
    105         const ScriptKernelID *k = mKernels[ct].get();
    106         //ALOGE(" kernel %i, %p  s=%p", (int)ct, k, mKernels[ct]->mScript);
    107         Node *n = findNode(k->mScript);
    108         //ALOGE("    n = %p", n);
    109         if (n == NULL) {
    110             n = new Node(k->mScript);
    111             mNodes.add(n);
    112         }
    113         n->mKernels.add(k);
    114     }
    115 
    116     // add links
    117     //ALOGE("link count %i", (int)mLinks.size());
    118     for (size_t ct=0; ct < mLinks.size(); ct++) {
    119         Link *l = mLinks[ct];
    120         //ALOGE("link  %i %p", (int)ct, l);
    121         Node *n = findNode(l->mSource->mScript);
    122         //ALOGE("link n %p", n);
    123         n->mOutputs.add(l);
    124 
    125         if (l->mDstKernel.get()) {
    126             //ALOGE("l->mDstKernel.get() %p", l->mDstKernel.get());
    127             n = findNode(l->mDstKernel->mScript);
    128             //ALOGE("  n1 %p", n);
    129             n->mInputs.add(l);
    130         } else {
    131             n = findNode(l->mDstField->mScript);
    132             //ALOGE("  n2 %p", n);
    133             n->mInputs.add(l);
    134         }
    135     }
    136 
    137     //ALOGE("node count %i", (int)mNodes.size());
    138     // Order nodes
    139     bool ret = true;
    140     for (size_t ct=0; ct < mNodes.size(); ct++) {
    141         Node *n = mNodes[ct];
    142         if (n->mInputs.size() == 0) {
    143             for (size_t ct2=0; ct2 < mNodes.size(); ct2++) {
    144                 mNodes[ct2]->mSeen = false;
    145             }
    146             ret &= calcOrderRecurse(n, 0);
    147         }
    148     }
    149 
    150     for (size_t ct=0; ct < mKernels.size(); ct++) {
    151         const ScriptKernelID *k = mKernels[ct].get();
    152         const Node *n = findNode(k->mScript);
    153 
    154         if (k->mHasKernelOutput) {
    155             bool found = false;
    156             for (size_t ct2=0; ct2 < n->mOutputs.size(); ct2++) {
    157                 if (n->mOutputs[ct2]->mSource.get() == k) {
    158                     found = true;
    159                     break;
    160                 }
    161             }
    162             if (!found) {
    163                 //ALOGE("add io out %p", k);
    164                 mOutputs.add(new IO(k));
    165             }
    166         }
    167 
    168         if (k->mHasKernelInput) {
    169             bool found = false;
    170             for (size_t ct2=0; ct2 < n->mInputs.size(); ct2++) {
    171                 if (n->mInputs[ct2]->mDstKernel.get() == k) {
    172                     found = true;
    173                     break;
    174                 }
    175             }
    176             if (!found) {
    177                 //ALOGE("add io in %p", k);
    178                 mInputs.add(new IO(k));
    179             }
    180         }
    181     }
    182 
    183     // sort
    184 #if !defined(RS_SERVER) && !defined(RS_COMPATIBILITY_LIB)
    185     mNodes.sort(&CompareNodeForSort);
    186 #else
    187     std::sort(mNodes.begin(), mNodes.end(), NodeCompare());
    188 #endif
    189 
    190     return ret;
    191 }
    192 
    193 ScriptGroup * ScriptGroup::create(Context *rsc,
    194                            ScriptKernelID ** kernels, size_t kernelsSize,
    195                            ScriptKernelID ** src, size_t srcSize,
    196                            ScriptKernelID ** dstK, size_t dstKSize,
    197                            ScriptFieldID  ** dstF, size_t dstFSize,
    198                            const Type ** type, size_t typeSize) {
    199 
    200     size_t kernelCount = kernelsSize / sizeof(ScriptKernelID *);
    201     size_t linkCount = typeSize / sizeof(Type *);
    202 
    203     //ALOGE("ScriptGroup::create kernels=%i  links=%i", (int)kernelCount, (int)linkCount);
    204 
    205 
    206     // Start by counting unique kernel sources
    207 
    208     ScriptGroup *sg = new ScriptGroup(rsc);
    209 
    210     sg->mKernels.reserve(kernelCount);
    211     for (size_t ct=0; ct < kernelCount; ct++) {
    212         sg->mKernels.add(kernels[ct]);
    213     }
    214 
    215     sg->mLinks.reserve(linkCount);
    216     for (size_t ct=0; ct < linkCount; ct++) {
    217         Link *l = new Link();
    218         l->mType = type[ct];
    219         l->mSource = src[ct];
    220         l->mDstField = dstF[ct];
    221         l->mDstKernel = dstK[ct];
    222         sg->mLinks.add(l);
    223     }
    224 
    225     sg->calcOrder();
    226 
    227     // allocate links
    228     for (size_t ct=0; ct < sg->mNodes.size(); ct++) {
    229         const Node *n = sg->mNodes[ct];
    230         for (size_t ct2=0; ct2 < n->mOutputs.size(); ct2++) {
    231             Link *l = n->mOutputs[ct2];
    232             if (l->mAlloc.get()) {
    233                 continue;
    234             }
    235             const ScriptKernelID *k = l->mSource.get();
    236 
    237             Allocation * alloc = Allocation::createAllocation(rsc,
    238                     l->mType.get(), RS_ALLOCATION_USAGE_SCRIPT);
    239             l->mAlloc = alloc;
    240 
    241             for (size_t ct3=ct2+1; ct3 < n->mOutputs.size(); ct3++) {
    242                 if (n->mOutputs[ct3]->mSource.get() == l->mSource.get()) {
    243                     n->mOutputs[ct3]->mAlloc = alloc;
    244                 }
    245             }
    246         }
    247     }
    248 
    249     if (rsc->mHal.funcs.scriptgroup.init) {
    250         rsc->mHal.funcs.scriptgroup.init(rsc, sg);
    251     }
    252     sg->incUserRef();
    253     return sg;
    254 }
    255 
    256 void ScriptGroup::setInput(Context *rsc, ScriptKernelID *kid, Allocation *a) {
    257     for (size_t ct=0; ct < mInputs.size(); ct++) {
    258         if (mInputs[ct]->mKernel == kid) {
    259             mInputs[ct]->mAlloc = a;
    260 
    261             if (rsc->mHal.funcs.scriptgroup.setInput) {
    262                 rsc->mHal.funcs.scriptgroup.setInput(rsc, this, kid, a);
    263             }
    264             return;
    265         }
    266     }
    267     rsAssert(!"ScriptGroup:setInput kid not found");
    268 }
    269 
    270 void ScriptGroup::setOutput(Context *rsc, ScriptKernelID *kid, Allocation *a) {
    271     for (size_t ct=0; ct < mOutputs.size(); ct++) {
    272         if (mOutputs[ct]->mKernel == kid) {
    273             mOutputs[ct]->mAlloc = a;
    274 
    275             if (rsc->mHal.funcs.scriptgroup.setOutput) {
    276                 rsc->mHal.funcs.scriptgroup.setOutput(rsc, this, kid, a);
    277             }
    278             return;
    279         }
    280     }
    281     rsAssert(!"ScriptGroup:setOutput kid not found");
    282 }
    283 
    284 bool ScriptGroup::validateInputAndOutput(Context *rsc) {
    285     for(size_t i = 0; i < mInputs.size(); i++) {
    286         if (mInputs[i]->mAlloc.get() == NULL) {
    287             rsc->setError(RS_ERROR_BAD_VALUE, "ScriptGroup missing input.");
    288             return false;
    289         }
    290     }
    291 
    292     for(size_t i = 0; i < mOutputs.size(); i++) {
    293         if (mOutputs[i]->mAlloc.get() == NULL) {
    294             rsc->setError(RS_ERROR_BAD_VALUE, "ScriptGroup missing output.");
    295             return false;
    296         }
    297     }
    298 
    299     return true;
    300 }
    301 
    302 void ScriptGroup::execute(Context *rsc) {
    303 
    304     if (!validateInputAndOutput(rsc)) {
    305         return;
    306     }
    307 
    308     //ALOGE("ScriptGroup::execute");
    309     if (rsc->mHal.funcs.scriptgroup.execute) {
    310         rsc->mHal.funcs.scriptgroup.execute(rsc, this);
    311         return;
    312     }
    313 
    314     for (size_t ct=0; ct < mNodes.size(); ct++) {
    315         Node *n = mNodes[ct];
    316         //ALOGE("node %i, order %i, in %i out %i", (int)ct, n->mOrder, (int)n->mInputs.size(), (int)n->mOutputs.size());
    317 
    318         for (size_t ct2=0; ct2 < n->mKernels.size(); ct2++) {
    319             const ScriptKernelID *k = n->mKernels[ct2];
    320             Allocation *ain = NULL;
    321             Allocation *aout = NULL;
    322 
    323             for (size_t ct3=0; ct3 < n->mInputs.size(); ct3++) {
    324                 if (n->mInputs[ct3]->mDstKernel.get() == k) {
    325                     ain = n->mInputs[ct3]->mAlloc.get();
    326                     //ALOGE(" link in %p", ain);
    327                 }
    328             }
    329             for (size_t ct3=0; ct3 < mInputs.size(); ct3++) {
    330                 if (mInputs[ct3]->mKernel == k) {
    331                     ain = mInputs[ct3]->mAlloc.get();
    332                     //ALOGE(" io in %p", ain);
    333                 }
    334             }
    335 
    336             for (size_t ct3=0; ct3 < n->mOutputs.size(); ct3++) {
    337                 if (n->mOutputs[ct3]->mSource.get() == k) {
    338                     aout = n->mOutputs[ct3]->mAlloc.get();
    339                     //ALOGE(" link out %p", aout);
    340                 }
    341             }
    342             for (size_t ct3=0; ct3 < mOutputs.size(); ct3++) {
    343                 if (mOutputs[ct3]->mKernel == k) {
    344                     aout = mOutputs[ct3]->mAlloc.get();
    345                     //ALOGE(" io out %p", aout);
    346                 }
    347             }
    348 
    349             n->mScript->runForEach(rsc, k->mSlot, ain, aout, NULL, 0);
    350         }
    351 
    352     }
    353 
    354 }
    355 
// Serialization is not implemented for script groups: nothing is written
// to the stream.
void ScriptGroup::serialize(Context *rsc, OStream *stream) const {
}
    358 
// Identifies this object as a script group (RS_A3D_CLASS_ID_SCRIPT_GROUP).
RsA3DClassID ScriptGroup::getClassId() const {
    return RS_A3D_CLASS_ID_SCRIPT_GROUP;
}
    362 
// Creates an empty link; ScriptGroup::create() assigns its type, source and
// destination afterwards.
ScriptGroup::Link::Link() {
}
    365 
// Nothing to release explicitly here; the body is intentionally empty.
ScriptGroup::Link::~Link() {
}
    368 
    369 namespace android {
    370 namespace renderscript {
    371 
    372 
    373 RsScriptGroup rsi_ScriptGroupCreate(Context *rsc,
    374                            RsScriptKernelID * kernels, size_t kernelsSize,
    375                            RsScriptKernelID * src, size_t srcSize,
    376                            RsScriptKernelID * dstK, size_t dstKSize,
    377                            RsScriptFieldID * dstF, size_t dstFSize,
    378                            const RsType * type, size_t typeSize) {
    379 
    380 
    381     return ScriptGroup::create(rsc,
    382                                (ScriptKernelID **) kernels, kernelsSize,
    383                                (ScriptKernelID **) src, srcSize,
    384                                (ScriptKernelID **) dstK, dstKSize,
    385                                (ScriptFieldID  **) dstF, dstFSize,
    386                                (const Type **) type, typeSize);
    387 }
    388 
    389 
    390 void rsi_ScriptGroupSetInput(Context *rsc, RsScriptGroup sg, RsScriptKernelID kid,
    391         RsAllocation alloc) {
    392     //ALOGE("rsi_ScriptGroupSetInput");
    393     ScriptGroup *s = (ScriptGroup *)sg;
    394     s->setInput(rsc, (ScriptKernelID *)kid, (Allocation *)alloc);
    395 }
    396 
    397 void rsi_ScriptGroupSetOutput(Context *rsc, RsScriptGroup sg, RsScriptKernelID kid,
    398         RsAllocation alloc) {
    399     //ALOGE("rsi_ScriptGroupSetOutput");
    400     ScriptGroup *s = (ScriptGroup *)sg;
    401     s->setOutput(rsc, (ScriptKernelID *)kid, (Allocation *)alloc);
    402 }
    403 
    404 void rsi_ScriptGroupExecute(Context *rsc, RsScriptGroup sg) {
    405     //ALOGE("rsi_ScriptGroupExecute");
    406     ScriptGroup *s = (ScriptGroup *)sg;
    407     s->execute(rsc);
    408 }
    409 
    410 }
    411 }
    412 
    413