Home | History | Annotate | Download | only in tflite
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #include "tensorflow/contrib/lite/toco/tflite/import.h"
     16 
     17 #include "flatbuffers/flexbuffers.h"
     18 #include <gmock/gmock.h>
     19 #include <gtest/gtest.h>
     20 #include "tensorflow/contrib/lite/schema/schema_generated.h"
     21 #include "tensorflow/contrib/lite/version.h"
     22 
     23 namespace toco {
     24 
     25 namespace tflite {
     26 namespace {
     27 
     28 using ::testing::ElementsAre;
     29 
     30 class ImportTest : public ::testing::Test {
     31  protected:
     32   template <typename T>
     33   flatbuffers::Offset<flatbuffers::Vector<unsigned char>> CreateDataVector(
     34       const std::vector<T>& data) {
     35     return builder_.CreateVector(reinterpret_cast<const uint8_t*>(data.data()),
     36                                  sizeof(T) * data.size());
     37   }
     38   // This is a very simplistic model. We are not interested in testing all the
     39   // details here, since tf.mini's testing framework will be exercising all the
     40   // conversions multiple times, and the conversion of operators is tested by
     41   // separate unittests.
     42   void BuildTestModel() {
     43     // The tensors
     44     auto q = ::tflite::CreateQuantizationParameters(
     45         builder_,
     46         /*min=*/builder_.CreateVector<float>({0.1f}),
     47         /*max=*/builder_.CreateVector<float>({0.2f}),
     48         /*scale=*/builder_.CreateVector<float>({0.3f}),
     49         /*zero_point=*/builder_.CreateVector<int64_t>({100ll}));
     50     auto buf0 = ::tflite::CreateBuffer(builder_, CreateDataVector<float>({}));
     51     auto buf1 =
     52         ::tflite::CreateBuffer(builder_, CreateDataVector<float>({1.0f, 2.0f}));
     53     auto buf2 =
     54         ::tflite::CreateBuffer(builder_, CreateDataVector<float>({3.0f}));
     55     auto buffers = builder_.CreateVector(
     56         std::vector<flatbuffers::Offset<::tflite::Buffer>>({buf0, buf1, buf2}));
     57     auto t1 = ::tflite::CreateTensor(builder_,
     58                                      builder_.CreateVector<int>({1, 2, 3, 4}),
     59                                      ::tflite::TensorType_FLOAT32, 1,
     60                                      builder_.CreateString("tensor_one"), q);
     61     auto t2 =
     62         ::tflite::CreateTensor(builder_, builder_.CreateVector<int>({2, 1}),
     63                                ::tflite::TensorType_FLOAT32, 2,
     64                                builder_.CreateString("tensor_two"), q);
     65     auto tensors = builder_.CreateVector(
     66         std::vector<flatbuffers::Offset<::tflite::Tensor>>({t1, t2}));
     67 
     68     // The operator codes.
     69     auto c1 =
     70         ::tflite::CreateOperatorCode(builder_, ::tflite::BuiltinOperator_CUSTOM,
     71                                      builder_.CreateString("custom_op_one"));
     72     auto c2 = ::tflite::CreateOperatorCode(
     73         builder_, ::tflite::BuiltinOperator_CONV_2D, 0);
     74     auto opcodes = builder_.CreateVector(
     75         std::vector<flatbuffers::Offset<::tflite::OperatorCode>>({c1, c2}));
     76 
     77     auto subgraph = ::tflite::CreateSubGraph(builder_, tensors, 0, 0, 0);
     78     std::vector<flatbuffers::Offset<::tflite::SubGraph>> subgraph_vector(
     79         {subgraph});
     80     auto subgraphs = builder_.CreateVector(subgraph_vector);
     81     auto s = builder_.CreateString("");
     82     builder_.Finish(::tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION,
     83                                           opcodes, subgraphs, s, buffers));
     84 
     85     input_model_ = ::tflite::GetModel(builder_.GetBufferPointer());
     86   }
     87   string InputModelAsString() {
     88     return string(reinterpret_cast<char*>(builder_.GetBufferPointer()),
     89                   builder_.GetSize());
     90   }
     91   flatbuffers::FlatBufferBuilder builder_;
     92   // const uint8_t* buffer_ = nullptr;
     93   const ::tflite::Model* input_model_ = nullptr;
     94 };
     95 
     96 TEST_F(ImportTest, LoadTensorsTable) {
     97   BuildTestModel();
     98 
     99   details::TensorsTable tensors;
    100   details::LoadTensorsTable(*input_model_, &tensors);
    101   EXPECT_THAT(tensors, ElementsAre("tensor_one", "tensor_two"));
    102 }
    103 
    104 TEST_F(ImportTest, LoadOperatorsTable) {
    105   BuildTestModel();
    106 
    107   details::OperatorsTable operators;
    108   details::LoadOperatorsTable(*input_model_, &operators);
    109   EXPECT_THAT(operators, ElementsAre("custom_op_one", "CONV_2D"));
    110 }
    111 
    112 TEST_F(ImportTest, Tensors) {
    113   BuildTestModel();
    114 
    115   auto model = Import(ModelFlags(), InputModelAsString());
    116 
    117   ASSERT_GT(model->HasArray("tensor_one"), 0);
    118   Array& a1 = model->GetArray("tensor_one");
    119   EXPECT_EQ(ArrayDataType::kFloat, a1.data_type);
    120   EXPECT_THAT(a1.GetBuffer<ArrayDataType::kFloat>().data,
    121               ElementsAre(1.0f, 2.0f));
    122   ASSERT_TRUE(a1.has_shape());
    123   EXPECT_THAT(a1.shape().dims(), ElementsAre(1, 2, 3, 4));
    124 
    125   const auto& mm = a1.minmax;
    126   ASSERT_TRUE(mm.get());
    127   EXPECT_FLOAT_EQ(0.1, mm->min);
    128   EXPECT_FLOAT_EQ(0.2, mm->max);
    129 
    130   const auto& q = a1.quantization_params;
    131   ASSERT_TRUE(q.get());
    132   EXPECT_FLOAT_EQ(0.3, q->scale);
    133   EXPECT_EQ(100, q->zero_point);
    134 }
    135 
    136 // TODO(ahentz): still need tests for Operators and IOTensors.
    137 
    138 }  // namespace
    139 }  // namespace tflite
    140 
    141 }  // namespace toco
    142