Home | History | Annotate | Download | only in functional
      1 /*
      2  * Copyright (C) 2018 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
#include "Callbacks.h"

#include <thread>

#include <android-base/logging.h>
     19 
     20 namespace android {
     21 namespace hardware {
     22 namespace neuralnetworks {
     23 namespace V1_2 {
     24 namespace implementation {
     25 
     26 CallbackBase::CallbackBase() : mNotified(false) {}
     27 
     28 CallbackBase::~CallbackBase() {
     29     // Note that we cannot call CallbackBase::join_thread from here:
     30     // CallbackBase is intended to be reference counted, and it is possible that
     31     // the reference count drops to zero in the bound thread, causing the
     32     // bound thread to call this destructor. If a thread tries to join
     33     // itself, it throws an exception, producing a message like the
     34     // following:
     35     //
     36     //     terminating with uncaught exception of type std::__1::system_error:
     37     //     thread::join failed: Resource deadlock would occur
     38 }
     39 
     40 void CallbackBase::wait() {
     41     std::unique_lock<std::mutex> lock(mMutex);
     42     mCondition.wait(lock, [this]{return mNotified;});
     43     join_thread_locked();
     44 }
     45 
     46 bool CallbackBase::on_finish(std::function<bool(void)> post_work) {
     47     std::lock_guard<std::mutex> lock(mMutex);
     48     if (mPostWork != nullptr) {
     49         LOG(ERROR) << "CallbackBase::on_finish -- a post-work function has already been bound to "
     50                    "this callback object";
     51         return false;
     52     }
     53     if (post_work == nullptr) {
     54         LOG(ERROR) << "CallbackBase::on_finish -- the new post-work function is invalid";
     55         return false;
     56     }
     57     mPostWork = std::move(post_work);
     58     return true;
     59 }
     60 
     61 bool CallbackBase::bind_thread(std::thread&& asyncThread) {
     62     std::lock_guard<std::mutex> lock(mMutex);
     63     if (mThread.joinable()) {
     64         LOG(ERROR) << "CallbackBase::bind_thread -- a thread has already been bound to this "
     65                    "callback object";
     66         return false;
     67     }
     68     if (!asyncThread.joinable()) {
     69         LOG(ERROR) << "CallbackBase::bind_thread -- the new thread is not joinable";
     70         return false;
     71     }
     72     mThread = std::move(asyncThread);
     73     return true;
     74 }
     75 
     76 void CallbackBase::join_thread() {
     77     std::lock_guard<std::mutex> lock(mMutex);
     78     join_thread_locked();
     79 }
     80 
     81 void CallbackBase::notify() {
     82     {
     83         std::lock_guard<std::mutex> lock(mMutex);
     84         mNotified = true;
     85         if (mPostWork != nullptr) {
     86             bool success = mPostWork();
     87             if (!success) {
     88                 LOG(ERROR) << "CallbackBase::notify -- post work failed";
     89             }
     90         }
     91     }
     92     mCondition.notify_all();
     93 }
     94 
     95 void CallbackBase::join_thread_locked() {
     96     if (mThread.joinable()) {
     97         mThread.join();
     98     }
     99 }
    100 
    101 PreparedModelCallback::PreparedModelCallback() :
    102         mErrorStatus(ErrorStatus::GENERAL_FAILURE), mPreparedModel(nullptr) {}
    103 
    104 PreparedModelCallback::~PreparedModelCallback() {}
    105 
    106 Return<void> PreparedModelCallback::notify(ErrorStatus errorStatus,
    107                                            const sp<V1_0::IPreparedModel>& preparedModel) {
    108     mErrorStatus = errorStatus;
    109     mPreparedModel = preparedModel;
    110     CallbackBase::notify();
    111     return Void();
    112 }
    113 
    114 Return<void> PreparedModelCallback::notify_1_2(ErrorStatus errorStatus,
    115                                                const sp<V1_2::IPreparedModel>& preparedModel) {
    116     mErrorStatus = errorStatus;
    117     mPreparedModel = preparedModel;
    118     CallbackBase::notify();
    119     return Void();
    120 }
    121 
    122 ErrorStatus PreparedModelCallback::getStatus() {
    123     wait();
    124     return mErrorStatus;
    125 }
    126 
    127 sp<V1_0::IPreparedModel> PreparedModelCallback::getPreparedModel() {
    128     wait();
    129     return mPreparedModel;
    130 }
    131 
    132 ExecutionCallback::ExecutionCallback() : mErrorStatus(ErrorStatus::GENERAL_FAILURE) {}
    133 
    134 ExecutionCallback::~ExecutionCallback() {}
    135 
    136 Return<void> ExecutionCallback::notify(ErrorStatus errorStatus) {
    137     mErrorStatus = errorStatus;
    138     mOutputShapes = {};
    139     mTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
    140     CallbackBase::notify();
    141     return Void();
    142 }
    143 
    144 Return<void> ExecutionCallback::notify_1_2(ErrorStatus errorStatus,
    145                                            const hidl_vec<OutputShape>& outputShapes,
    146                                            const Timing& timing) {
    147     mErrorStatus = errorStatus;
    148     mOutputShapes = outputShapes;
    149     mTiming = timing;
    150     CallbackBase::notify();
    151     return Void();
    152 }
    153 
    154 ErrorStatus ExecutionCallback::getStatus() {
    155     wait();
    156     return mErrorStatus;
    157 }
    158 
    159 const std::vector<OutputShape>& ExecutionCallback::getOutputShapes() {
    160     wait();
    161     return mOutputShapes;
    162 }
    163 
    164 Timing ExecutionCallback::getTiming() {
    165     wait();
    166     return mTiming;
    167 }
    168 
    169 }  // namespace implementation
    170 }  // namespace V1_2
    171 }  // namespace neuralnetworks
    172 }  // namespace hardware
    173 }  // namespace android
    174