Home | History | Annotate | Download | only in operations
      1 /*
      2  * Copyright (C) 2018 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "Multinomial.h"
     18 
     19 #include "HalInterfaces.h"
     20 #include "NeuralNetworksWrapper.h"
     21 #include "gmock/gmock-matchers.h"
     22 #include "gtest/gtest.h"
     23 #include "philox_random.h"
     24 #include "simple_philox.h"
     25 
     26 #include "unsupported/Eigen/CXX11/Tensor"
     27 
     28 namespace android {
     29 namespace nn {
     30 namespace wrapper {
     31 
     32 using ::testing::FloatNear;
     33 
     34 constexpr int kFixedRandomSeed1 = 37;
     35 constexpr int kFixedRandomSeed2 = 42;
     36 
     37 class MultinomialOpModel {
     38    public:
     39     MultinomialOpModel(uint32_t batch_size, uint32_t class_size, uint32_t sample_size)
     40         : batch_size_(batch_size), class_size_(class_size), sample_size_(sample_size) {
     41         std::vector<uint32_t> inputs;
     42         OperandType logitsType(Type::TENSOR_FLOAT32, {batch_size_, class_size_});
     43         inputs.push_back(model_.addOperand(&logitsType));
     44         OperandType samplesType(Type::INT32, {});
     45         inputs.push_back(model_.addOperand(&samplesType));
     46         OperandType seedsType(Type::TENSOR_INT32, {2});
     47         inputs.push_back(model_.addOperand(&seedsType));
     48 
     49         std::vector<uint32_t> outputs;
     50         OperandType outputType(Type::TENSOR_INT32, {batch_size_, sample_size_});
     51         outputs.push_back(model_.addOperand(&outputType));
     52 
     53         model_.addOperation(ANEURALNETWORKS_RANDOM_MULTINOMIAL, inputs, outputs);
     54         model_.identifyInputsAndOutputs(inputs, outputs);
     55         model_.finish();
     56     }
     57 
     58     void Invoke() {
     59         ASSERT_TRUE(model_.isValid());
     60 
     61         Compilation compilation(&model_);
     62         compilation.finish();
     63         Execution execution(&compilation);
     64 
     65         tensorflow::random::PhiloxRandom rng(kFixedRandomSeed1);
     66         tensorflow::random::SimplePhilox srng(&rng);
     67         const int sample_count = batch_size_ * class_size_;
     68         for (int i = 0; i < sample_count; ++i) {
     69             input_.push_back(srng.RandDouble());
     70         }
     71         ASSERT_EQ(execution.setInput(Multinomial::kInputTensor, input_.data(),
     72                                      sizeof(float) * input_.size()),
     73                   Result::NO_ERROR);
     74         ASSERT_EQ(execution.setInput(Multinomial::kSampleCountParam, &sample_size_,
     75                                      sizeof(sample_size_)),
     76                   Result::NO_ERROR);
     77 
     78         std::vector<uint32_t> seeds{kFixedRandomSeed1, kFixedRandomSeed2};
     79         ASSERT_EQ(execution.setInput(Multinomial::kRandomSeedsTensor, seeds.data(),
     80                                      sizeof(uint32_t) * seeds.size()),
     81                   Result::NO_ERROR);
     82 
     83         output_.insert(output_.end(), batch_size_ * sample_size_, 0);
     84         ASSERT_EQ(execution.setOutput(Multinomial::kOutputTensor, output_.data(),
     85                                       sizeof(uint32_t) * output_.size()),
     86                   Result::NO_ERROR);
     87 
     88         ASSERT_EQ(execution.compute(), Result::NO_ERROR);
     89     }
     90 
     91     const std::vector<float>& GetInput() const { return input_; }
     92     const std::vector<uint32_t>& GetOutput() const { return output_; }
     93 
     94    private:
     95     Model model_;
     96 
     97     const uint32_t batch_size_;
     98     const uint32_t class_size_;
     99     const uint32_t sample_size_;
    100 
    101     std::vector<float> input_;
    102     std::vector<uint32_t> output_;
    103 };
    104 
    105 TEST(MultinomialOpTest, ProbabilityDeltaWithinTolerance) {
    106     constexpr int kBatchSize = 8;
    107     constexpr int kNumClasses = 10000;
    108     constexpr int kNumSamples = 128;
    109     constexpr float kMaxProbabilityDelta = 0.025;
    110 
    111     MultinomialOpModel multinomial(kBatchSize, kNumClasses, kNumSamples);
    112     multinomial.Invoke();
    113 
    114     std::vector<uint32_t> output = multinomial.GetOutput();
    115     std::vector<int> class_counts;
    116     class_counts.resize(kNumClasses);
    117     for (auto index : output) {
    118         class_counts[index]++;
    119     }
    120 
    121     std::vector<float> input = multinomial.GetInput();
    122     for (int b = 0; b < kBatchSize; ++b) {
    123         float probability_sum = 0;
    124         const int batch_index = kBatchSize * b;
    125         for (int i = 0; i < kNumClasses; ++i) {
    126             probability_sum += expf(input[batch_index + i]);
    127         }
    128         for (int i = 0; i < kNumClasses; ++i) {
    129             float probability =
    130                     static_cast<float>(class_counts[i]) / static_cast<float>(kNumSamples);
    131             float probability_expected = expf(input[batch_index + i]) / probability_sum;
    132             EXPECT_THAT(probability, FloatNear(probability_expected, kMaxProbabilityDelta));
    133         }
    134     }
    135 }
    136 
    137 }  // namespace wrapper
    138 }  // namespace nn
    139 }  // namespace android
    140