1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include <vector> 17 18 #include <gtest/gtest.h> 19 #include "tensorflow/contrib/lite/interpreter.h" 20 #include "tensorflow/contrib/lite/kernels/register.h" 21 #include "tensorflow/contrib/lite/kernels/test_util.h" 22 #include "tensorflow/contrib/lite/model.h" 23 #include "tensorflow/contrib/lite/string_util.h" 24 25 namespace tflite { 26 27 namespace ops { 28 namespace custom { 29 TfLiteRegistration* Register_PREDICT(); 30 31 namespace { 32 33 using ::testing::ElementsAreArray; 34 35 class PredictOpModel : public SingleOpModel { 36 public: 37 PredictOpModel(std::initializer_list<int> input_signature_shape, 38 std::initializer_list<int> key_shape, 39 std::initializer_list<int> labelweight_shape, int num_output, 40 float threshold) { 41 input_signature_ = AddInput(TensorType_INT32); 42 model_key_ = AddInput(TensorType_INT32); 43 model_label_ = AddInput(TensorType_INT32); 44 model_weight_ = AddInput(TensorType_FLOAT32); 45 output_label_ = AddOutput(TensorType_INT32); 46 output_weight_ = AddOutput(TensorType_FLOAT32); 47 48 std::vector<uint8_t> predict_option; 49 writeInt32(num_output, &predict_option); 50 writeFloat32(threshold, &predict_option); 51 SetCustomOp("Predict", predict_option, Register_PREDICT); 52 BuildInterpreter({{input_signature_shape, key_shape, labelweight_shape, 53 labelweight_shape}}); 54 } 55 56 void SetInputSignature(std::initializer_list<int> data) { 57 PopulateTensor<int>(input_signature_, data); 58 } 59 60 void SetModelKey(std::initializer_list<int> data) { 61 PopulateTensor<int>(model_key_, data); 62 } 63 64 void SetModelLabel(std::initializer_list<int> data) { 65 PopulateTensor<int>(model_label_, data); 66 } 67 68 void SetModelWeight(std::initializer_list<float> data) { 69 PopulateTensor<float>(model_weight_, data); 70 } 71 72 std::vector<int> GetLabel() { return ExtractVector<int>(output_label_); } 73 std::vector<float> GetWeight() { 74 return ExtractVector<float>(output_weight_); 75 } 76 77 void writeFloat32(float value, std::vector<uint8_t>* data) { 78 union { 79 float v; 80 uint8_t r[4]; 81 } float_to_raw; 82 float_to_raw.v = value; 83 for (unsigned char i : float_to_raw.r) { 84 data->push_back(i); 85 } 86 } 87 88 void writeInt32(int32_t value, std::vector<uint8_t>* data) { 89 union { 90 int32_t v; 91 uint8_t r[4]; 92 } int32_to_raw; 93 int32_to_raw.v = value; 94 for (unsigned char i : int32_to_raw.r) { 95 data->push_back(i); 96 } 97 } 98 99 private: 100 int input_signature_; 101 int model_key_; 102 int model_label_; 103 int model_weight_; 104 int output_label_; 105 int output_weight_; 106 }; 107 108 TEST(PredictOpTest, AllLabelsAreValid) { 109 PredictOpModel m({4}, {5}, {5, 2}, 2, 0.0001); 110 m.SetInputSignature({1, 3, 7, 9}); 111 m.SetModelKey({1, 2, 4, 6, 7}); 112 m.SetModelLabel({11, 12, 11, 12, 11, 12, 11, 12, 11, 12}); 113 m.SetModelWeight({0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2}); 114 m.Invoke(); 115 EXPECT_THAT(m.GetLabel(), ElementsAreArray({12, 11})); 116 EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0.1, 0.05}))); 117 } 118 119 TEST(PredictOpTest, MoreLabelsThanRequired) { 120 PredictOpModel m({4}, {5}, {5, 2}, 1, 0.0001); 121 m.SetInputSignature({1, 3, 7, 9}); 122 m.SetModelKey({1, 2, 4, 6, 7}); 123 m.SetModelLabel({11, 12, 11, 12, 11, 12, 11, 12, 11, 12}); 124 m.SetModelWeight({0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2}); 125 m.Invoke(); 126 EXPECT_THAT(m.GetLabel(), ElementsAreArray({12})); 127 EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0.1}))); 128 } 129 130 TEST(PredictOpTest, OneLabelDoesNotPassThreshold) { 131 PredictOpModel m({4}, {5}, {5, 2}, 2, 0.07); 132 m.SetInputSignature({1, 3, 7, 9}); 133 m.SetModelKey({1, 2, 4, 6, 7}); 134 m.SetModelLabel({11, 12, 11, 12, 11, 12, 11, 12, 11, 12}); 135 m.SetModelWeight({0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2}); 136 m.Invoke(); 137 EXPECT_THAT(m.GetLabel(), ElementsAreArray({12, -1})); 138 EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0.1, 0}))); 139 } 140 141 TEST(PredictOpTest, NoneLabelPassThreshold) { 142 PredictOpModel m({4}, {5}, {5, 2}, 2, 0.6); 143 m.SetInputSignature({1, 3, 7, 9}); 144 m.SetModelKey({1, 2, 4, 6, 7}); 145 m.SetModelLabel({11, 12, 11, 12, 11, 12, 11, 12, 11, 12}); 146 m.SetModelWeight({0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2}); 147 m.Invoke(); 148 EXPECT_THAT(m.GetLabel(), ElementsAreArray({-1, -1})); 149 EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0, 0}))); 150 } 151 152 TEST(PredictOpTest, OnlyOneLabelGenerated) { 153 PredictOpModel m({4}, {5}, {5, 2}, 2, 0.0001); 154 m.SetInputSignature({1, 3, 7, 9}); 155 m.SetModelKey({1, 2, 4, 6, 7}); 156 m.SetModelLabel({11, 0, 11, 0, 11, 0, 11, 0, 11, 0}); 157 m.SetModelWeight({0.1, 0, 0.1, 0, 0.1, 0, 0.1, 0, 0.1, 0}); 158 m.Invoke(); 159 EXPECT_THAT(m.GetLabel(), ElementsAreArray({11, -1})); 160 EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0.05, 0}))); 161 } 162 163 TEST(PredictOpTest, NoLabelGenerated) { 164 PredictOpModel m({4}, {5}, {5, 2}, 2, 0.0001); 165 m.SetInputSignature({5, 3, 7, 9}); 166 m.SetModelKey({1, 2, 4, 6, 7}); 167 m.SetModelLabel({11, 0, 11, 0, 11, 0, 11, 0, 0, 0}); 168 m.SetModelWeight({0.1, 0, 0.1, 0, 0.1, 0, 0.1, 0, 0, 0}); 169 m.Invoke(); 170 EXPECT_THAT(m.GetLabel(), ElementsAreArray({-1, -1})); 171 EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0, 0}))); 172 } 173 174 } // namespace 175 } // namespace custom 176 } // namespace ops 177 } // namespace tflite 178 179 int main(int argc, char** argv) { 180 // On Linux, add: tflite::LogToStderr(); 181 ::testing::InitGoogleTest(&argc, argv); 182 return RUN_ALL_TESTS(); 183 } 184