Home | History | Annotate | Download | only in ops
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
#include <cstdint>
#include <cstring>
#include <initializer_list>
#include <type_traits>
#include <vector>

#include <gtest/gtest.h>
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/string_util.h"
     24 
     25 namespace tflite {
     26 
     27 namespace ops {
     28 namespace custom {
     29 TfLiteRegistration* Register_PREDICT();
     30 
     31 namespace {
     32 
     33 using ::testing::ElementsAreArray;
     34 
     35 class PredictOpModel : public SingleOpModel {
     36  public:
     37   PredictOpModel(std::initializer_list<int> input_signature_shape,
     38                  std::initializer_list<int> key_shape,
     39                  std::initializer_list<int> labelweight_shape, int num_output,
     40                  float threshold) {
     41     input_signature_ = AddInput(TensorType_INT32);
     42     model_key_ = AddInput(TensorType_INT32);
     43     model_label_ = AddInput(TensorType_INT32);
     44     model_weight_ = AddInput(TensorType_FLOAT32);
     45     output_label_ = AddOutput(TensorType_INT32);
     46     output_weight_ = AddOutput(TensorType_FLOAT32);
     47 
     48     std::vector<uint8_t> predict_option;
     49     writeInt32(num_output, &predict_option);
     50     writeFloat32(threshold, &predict_option);
     51     SetCustomOp("Predict", predict_option, Register_PREDICT);
     52     BuildInterpreter({{input_signature_shape, key_shape, labelweight_shape,
     53                        labelweight_shape}});
     54   }
     55 
     56   void SetInputSignature(std::initializer_list<int> data) {
     57     PopulateTensor<int>(input_signature_, data);
     58   }
     59 
     60   void SetModelKey(std::initializer_list<int> data) {
     61     PopulateTensor<int>(model_key_, data);
     62   }
     63 
     64   void SetModelLabel(std::initializer_list<int> data) {
     65     PopulateTensor<int>(model_label_, data);
     66   }
     67 
     68   void SetModelWeight(std::initializer_list<float> data) {
     69     PopulateTensor<float>(model_weight_, data);
     70   }
     71 
     72   std::vector<int> GetLabel() { return ExtractVector<int>(output_label_); }
     73   std::vector<float> GetWeight() {
     74     return ExtractVector<float>(output_weight_);
     75   }
     76 
     77   void writeFloat32(float value, std::vector<uint8_t>* data) {
     78     union {
     79       float v;
     80       uint8_t r[4];
     81     } float_to_raw;
     82     float_to_raw.v = value;
     83     for (unsigned char i : float_to_raw.r) {
     84       data->push_back(i);
     85     }
     86   }
     87 
     88   void writeInt32(int32_t value, std::vector<uint8_t>* data) {
     89     union {
     90       int32_t v;
     91       uint8_t r[4];
     92     } int32_to_raw;
     93     int32_to_raw.v = value;
     94     for (unsigned char i : int32_to_raw.r) {
     95       data->push_back(i);
     96     }
     97   }
     98 
     99  private:
    100   int input_signature_;
    101   int model_key_;
    102   int model_label_;
    103   int model_weight_;
    104   int output_label_;
    105   int output_weight_;
    106 };
    107 
    108 TEST(PredictOpTest, AllLabelsAreValid) {
    109   PredictOpModel m({4}, {5}, {5, 2}, 2, 0.0001);
    110   m.SetInputSignature({1, 3, 7, 9});
    111   m.SetModelKey({1, 2, 4, 6, 7});
    112   m.SetModelLabel({11, 12, 11, 12, 11, 12, 11, 12, 11, 12});
    113   m.SetModelWeight({0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2});
    114   m.Invoke();
    115   EXPECT_THAT(m.GetLabel(), ElementsAreArray({12, 11}));
    116   EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0.1, 0.05})));
    117 }
    118 
    119 TEST(PredictOpTest, MoreLabelsThanRequired) {
    120   PredictOpModel m({4}, {5}, {5, 2}, 1, 0.0001);
    121   m.SetInputSignature({1, 3, 7, 9});
    122   m.SetModelKey({1, 2, 4, 6, 7});
    123   m.SetModelLabel({11, 12, 11, 12, 11, 12, 11, 12, 11, 12});
    124   m.SetModelWeight({0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2});
    125   m.Invoke();
    126   EXPECT_THAT(m.GetLabel(), ElementsAreArray({12}));
    127   EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0.1})));
    128 }
    129 
    130 TEST(PredictOpTest, OneLabelDoesNotPassThreshold) {
    131   PredictOpModel m({4}, {5}, {5, 2}, 2, 0.07);
    132   m.SetInputSignature({1, 3, 7, 9});
    133   m.SetModelKey({1, 2, 4, 6, 7});
    134   m.SetModelLabel({11, 12, 11, 12, 11, 12, 11, 12, 11, 12});
    135   m.SetModelWeight({0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2});
    136   m.Invoke();
    137   EXPECT_THAT(m.GetLabel(), ElementsAreArray({12, -1}));
    138   EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0.1, 0})));
    139 }
    140 
    141 TEST(PredictOpTest, NoneLabelPassThreshold) {
    142   PredictOpModel m({4}, {5}, {5, 2}, 2, 0.6);
    143   m.SetInputSignature({1, 3, 7, 9});
    144   m.SetModelKey({1, 2, 4, 6, 7});
    145   m.SetModelLabel({11, 12, 11, 12, 11, 12, 11, 12, 11, 12});
    146   m.SetModelWeight({0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2});
    147   m.Invoke();
    148   EXPECT_THAT(m.GetLabel(), ElementsAreArray({-1, -1}));
    149   EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0, 0})));
    150 }
    151 
    152 TEST(PredictOpTest, OnlyOneLabelGenerated) {
    153   PredictOpModel m({4}, {5}, {5, 2}, 2, 0.0001);
    154   m.SetInputSignature({1, 3, 7, 9});
    155   m.SetModelKey({1, 2, 4, 6, 7});
    156   m.SetModelLabel({11, 0, 11, 0, 11, 0, 11, 0, 11, 0});
    157   m.SetModelWeight({0.1, 0, 0.1, 0, 0.1, 0, 0.1, 0, 0.1, 0});
    158   m.Invoke();
    159   EXPECT_THAT(m.GetLabel(), ElementsAreArray({11, -1}));
    160   EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0.05, 0})));
    161 }
    162 
    163 TEST(PredictOpTest, NoLabelGenerated) {
    164   PredictOpModel m({4}, {5}, {5, 2}, 2, 0.0001);
    165   m.SetInputSignature({5, 3, 7, 9});
    166   m.SetModelKey({1, 2, 4, 6, 7});
    167   m.SetModelLabel({11, 0, 11, 0, 11, 0, 11, 0, 0, 0});
    168   m.SetModelWeight({0.1, 0, 0.1, 0, 0.1, 0, 0.1, 0, 0, 0});
    169   m.Invoke();
    170   EXPECT_THAT(m.GetLabel(), ElementsAreArray({-1, -1}));
    171   EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0, 0})));
    172 }
    173 
    174 }  // namespace
    175 }  // namespace custom
    176 }  // namespace ops
    177 }  // namespace tflite
    178 
    179 int main(int argc, char** argv) {
    180   // On Linux, add: tflite::LogToStderr();
    181   ::testing::InitGoogleTest(&argc, argv);
    182   return RUN_ALL_TESTS();
    183 }
    184