1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include "tensorflow/contrib/lite/toco/tflite/import.h" 16 17 #include "flatbuffers/flexbuffers.h" 18 #include <gmock/gmock.h> 19 #include <gtest/gtest.h> 20 #include "tensorflow/contrib/lite/schema/schema_generated.h" 21 #include "tensorflow/contrib/lite/version.h" 22 23 namespace toco { 24 25 namespace tflite { 26 namespace { 27 28 using ::testing::ElementsAre; 29 30 class ImportTest : public ::testing::Test { 31 protected: 32 template <typename T> 33 flatbuffers::Offset<flatbuffers::Vector<unsigned char>> CreateDataVector( 34 const std::vector<T>& data) { 35 return builder_.CreateVector(reinterpret_cast<const uint8_t*>(data.data()), 36 sizeof(T) * data.size()); 37 } 38 // This is a very simplistic model. We are not interested in testing all the 39 // details here, since tf.mini's testing framework will be exercising all the 40 // conversions multiple times, and the conversion of operators is tested by 41 // separate unittests. 42 void BuildTestModel() { 43 // The tensors 44 auto q = ::tflite::CreateQuantizationParameters( 45 builder_, 46 /*min=*/builder_.CreateVector<float>({0.1f}), 47 /*max=*/builder_.CreateVector<float>({0.2f}), 48 /*scale=*/builder_.CreateVector<float>({0.3f}), 49 /*zero_point=*/builder_.CreateVector<int64_t>({100ll})); 50 auto buf0 = ::tflite::CreateBuffer(builder_, CreateDataVector<float>({})); 51 auto buf1 = 52 ::tflite::CreateBuffer(builder_, CreateDataVector<float>({1.0f, 2.0f})); 53 auto buf2 = 54 ::tflite::CreateBuffer(builder_, CreateDataVector<float>({3.0f})); 55 auto buffers = builder_.CreateVector( 56 std::vector<flatbuffers::Offset<::tflite::Buffer>>({buf0, buf1, buf2})); 57 auto t1 = ::tflite::CreateTensor(builder_, 58 builder_.CreateVector<int>({1, 2, 3, 4}), 59 ::tflite::TensorType_FLOAT32, 1, 60 builder_.CreateString("tensor_one"), q); 61 auto t2 = 62 ::tflite::CreateTensor(builder_, builder_.CreateVector<int>({2, 1}), 63 ::tflite::TensorType_FLOAT32, 2, 64 builder_.CreateString("tensor_two"), q); 65 auto tensors = builder_.CreateVector( 66 std::vector<flatbuffers::Offset<::tflite::Tensor>>({t1, t2})); 67 68 // The operator codes. 69 auto c1 = 70 ::tflite::CreateOperatorCode(builder_, ::tflite::BuiltinOperator_CUSTOM, 71 builder_.CreateString("custom_op_one")); 72 auto c2 = ::tflite::CreateOperatorCode( 73 builder_, ::tflite::BuiltinOperator_CONV_2D, 0); 74 auto opcodes = builder_.CreateVector( 75 std::vector<flatbuffers::Offset<::tflite::OperatorCode>>({c1, c2})); 76 77 auto subgraph = ::tflite::CreateSubGraph(builder_, tensors, 0, 0, 0); 78 std::vector<flatbuffers::Offset<::tflite::SubGraph>> subgraph_vector( 79 {subgraph}); 80 auto subgraphs = builder_.CreateVector(subgraph_vector); 81 auto s = builder_.CreateString(""); 82 builder_.Finish(::tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION, 83 opcodes, subgraphs, s, buffers)); 84 85 input_model_ = ::tflite::GetModel(builder_.GetBufferPointer()); 86 } 87 string InputModelAsString() { 88 return string(reinterpret_cast<char*>(builder_.GetBufferPointer()), 89 builder_.GetSize()); 90 } 91 flatbuffers::FlatBufferBuilder builder_; 92 // const uint8_t* buffer_ = nullptr; 93 const ::tflite::Model* input_model_ = nullptr; 94 }; 95 96 TEST_F(ImportTest, LoadTensorsTable) { 97 BuildTestModel(); 98 99 details::TensorsTable tensors; 100 details::LoadTensorsTable(*input_model_, &tensors); 101 EXPECT_THAT(tensors, ElementsAre("tensor_one", "tensor_two")); 102 } 103 104 TEST_F(ImportTest, LoadOperatorsTable) { 105 BuildTestModel(); 106 107 details::OperatorsTable operators; 108 details::LoadOperatorsTable(*input_model_, &operators); 109 EXPECT_THAT(operators, ElementsAre("custom_op_one", "CONV_2D")); 110 } 111 112 TEST_F(ImportTest, Tensors) { 113 BuildTestModel(); 114 115 auto model = Import(ModelFlags(), InputModelAsString()); 116 117 ASSERT_GT(model->HasArray("tensor_one"), 0); 118 Array& a1 = model->GetArray("tensor_one"); 119 EXPECT_EQ(ArrayDataType::kFloat, a1.data_type); 120 EXPECT_THAT(a1.GetBuffer<ArrayDataType::kFloat>().data, 121 ElementsAre(1.0f, 2.0f)); 122 ASSERT_TRUE(a1.has_shape()); 123 EXPECT_THAT(a1.shape().dims(), ElementsAre(1, 2, 3, 4)); 124 125 const auto& mm = a1.minmax; 126 ASSERT_TRUE(mm.get()); 127 EXPECT_FLOAT_EQ(0.1, mm->min); 128 EXPECT_FLOAT_EQ(0.2, mm->max); 129 130 const auto& q = a1.quantization_params; 131 ASSERT_TRUE(q.get()); 132 EXPECT_FLOAT_EQ(0.3, q->scale); 133 EXPECT_EQ(100, q->zero_point); 134 } 135 136 // TODO(ahentz): still need tests for Operators and IOTensors. 137 138 } // namespace 139 } // namespace tflite 140 141 } // namespace toco 142