Home | History | Annotate | Download | only in operations
      1 /*
      2  * Copyright (C) 2018 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
#include "Multinomial.h"

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>
#include <vector>

#include "CpuExecutor.h"
#include "CpuOperationUtils.h"
#include "HalInterfaces.h"
#include "Tracing.h"

#include "guarded_philox_random.h"
#include "philox_random.h"
#include "simple_philox.h"

#include "unsupported/Eigen/CXX11/Tensor"
     29 
     30 namespace android {
     31 namespace nn {
     32 
     33 namespace {
     34 
     35 template <typename T>
     36 inline T* GetBuffer(RunTimeOperandInfo* operand) {
     37     return reinterpret_cast<T*>(operand->buffer);
     38 }
     39 
     40 template <typename T>
     41 inline const T* GetBuffer(const RunTimeOperandInfo* operand) {
     42     return reinterpret_cast<const T*>(operand->buffer);
     43 }
     44 
     45 }  // namespace
     46 
     47 Multinomial::Multinomial(const Operation& operation, std::vector<RunTimeOperandInfo>& operands) {
     48     NNTRACE_TRANS("Multinomial::Multinomial");
     49     input_ = GetInput(operation, operands, kInputTensor);
     50     sample_count_ = getScalarData<int>(*GetInput(operation, operands, kSampleCountParam));
     51     random_seeds_ = GetInput(operation, operands, kRandomSeedsTensor);
     52 
     53     output_ = GetOutput(operation, operands, kOutputTensor);
     54 }
     55 
     56 bool Multinomial::Prepare(const Operation& operation, std::vector<RunTimeOperandInfo>& operands,
     57                           Shape* outputShape) {
     58     NNTRACE_TRANS("Multinomial::Prepare");
     59     NN_CHECK_EQ(NumInputsWithValues(operation, operands), 3);
     60     NN_CHECK_EQ(NumOutputs(operation), 1);
     61 
     62     const RunTimeOperandInfo* input = GetInput(operation, operands, Multinomial::kInputTensor);
     63     const Shape& inputShape = input->shape();
     64 
     65     const uint32_t batch_size = SizeOfDimension(input, 0);
     66     const uint32_t sample_count =
     67             getScalarData<int>(*GetInput(operation, operands, kSampleCountParam));
     68 
     69     outputShape->type = OperandType::TENSOR_INT32;
     70     outputShape->dimensions = {batch_size, sample_count};
     71     outputShape->offset = inputShape.offset;
     72     outputShape->scale = inputShape.scale;
     73 
     74     return true;
     75 }
     76 
     77 bool Multinomial::Eval() {
     78     NNTRACE_COMP("Multinomial::Eval");
     79     switch (input_->type) {
     80         case OperandType::TENSOR_FLOAT16: {
     81             std::vector<float> inputDataFloat32(getNumberOfElements(input_->shape()));
     82             convertFloat16ToFloat32(GetBuffer<_Float16>(input_), &inputDataFloat32);
     83             EvalFloat32(inputDataFloat32.data());
     84             break;
     85         }
     86         case OperandType::TENSOR_FLOAT32: {
     87             EvalFloat32(GetBuffer<float>(input_));
     88             break;
     89         }
     90         default: {
     91             LOG(ERROR) << "Unsupported data type: " << static_cast<int>(input_->type);
     92             return false;
     93         }
     94     }
     95     return true;
     96 }
     97 
     98 void Multinomial::EvalFloat32(const float* inputData) {
     99     const int batch_size = SizeOfDimension(input_, 0);
    100     const int class_size = SizeOfDimension(input_, 1);
    101 
    102     tensorflow::GuardedPhiloxRandom random_generator;
    103     int32_t* seeds = GetBuffer<int32_t>(random_seeds_);
    104     random_generator.Init(seeds[0], seeds[1]);
    105 
    106     // PhiloxRandom produces results as 4 32-bit integers.
    107     int sample_count_aligned = (sample_count_ + 3) / 4 * 4;
    108     // The CPU operation uses 64-bit double values, so two results per sample.
    109     sample_count_aligned *= 2;
    110     auto random_generator_reserved =
    111             random_generator.ReserveRandomOutputs(batch_size * sample_count_aligned, 256);
    112     tensorflow::random::SimplePhilox simple_philox(&random_generator_reserved);
    113 
    114     for (uint64_t b = 0; b < batch_size; ++b) {
    115         const float* input_ptr_batch = inputData + b * class_size;
    116         float max = std::numeric_limits<float>::lowest();
    117         for (uint64_t j = 0; j < class_size; ++j) {
    118             if (Eigen::numext::isfinite(input_ptr_batch[j])) {
    119                 max = std::max(max, input_ptr_batch[j]);
    120             }
    121         }
    122         const double batch_max = static_cast<double>(max);
    123         double total = 0;
    124         std::vector<double> cdf;
    125         cdf.resize(class_size);
    126         for (uint64_t j = 0; j < class_size; ++j) {
    127             if (Eigen::numext::isfinite(static_cast<float>(input_ptr_batch[j]))) {
    128                 total += exp(static_cast<double>(input_ptr_batch[j]) - batch_max);
    129             }
    130             cdf[j] = total;
    131         }
    132 
    133         auto* output_ptr_batch = GetBuffer<int32_t>(output_) + b * sample_count_;
    134         for (uint64_t j = 0; j < sample_count_; ++j) {
    135             const double target = simple_philox.RandDouble() * total;
    136             auto found_iter = std::upper_bound(cdf.begin(), cdf.end(), target);
    137             output_ptr_batch[j] = std::distance(cdf.begin(), found_iter);
    138         }
    139     }
    140 }
    141 
    142 }  // namespace nn
    143 }  // namespace android
    144