Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_TEST_UTIL_H_
     16 #define TENSORFLOW_CONTRIB_LITE_KERNELS_TEST_UTIL_H_
     17 
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <functional>
#include <initializer_list>
#include <limits>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/string_util.h"
#include "tensorflow/contrib/lite/testing/util.h"
#include "tensorflow/core/platform/logging.h"
     29 
     30 namespace tflite {
     31 
// A gmock matcher that checks that the elements of a float vector match the
// given expected values to within a tolerance.
//
// values: the expected element values.
// max_abs_error: the tolerance (by its name, an absolute error bound);
//     defaults to 1e-5.
// Returns a vector of per-float matchers. Only declared in this header.
std::vector<::testing::Matcher<float>> ArrayFloatNear(
    const std::vector<float>& values, float max_abs_error = 1e-5);
     36 
     37 template <typename T>
     38 inline std::vector<T> Quantize(const std::vector<float>& data, float scale,
     39                                int32_t zero_point) {
     40   std::vector<T> q;
     41   for (float f : data) {
     42     q.push_back(std::max(
     43         std::numeric_limits<T>::min(),
     44         std::min(std::numeric_limits<T>::max(),
     45                  static_cast<T>(std::round(zero_point + (f / scale))))));
     46   }
     47   return q;
     48 }
     49 
     50 template <typename T>
     51 inline std::vector<float> Dequantize(const std::vector<T>& data, float scale,
     52                                      int32_t zero_point) {
     53   std::vector<float> f;
     54   for (T q : data) {
     55     f.push_back(scale * (q - zero_point));
     56   }
     57   return f;
     58 }
     59 
// A test model that contains a single operator. All operator inputs and
// outputs are external to the model, so the tests can directly access them.
     62 // Typical usage:
     63 //    SingleOpModel m;
     64 //    int a = m.AddInput({TensorType_FLOAT32, a_shape});
     65 //    int b = m.AddInput({TensorType_FLOAT32, b_shape});
     66 //    int c = m.AddOutput({TensorType_FLOAT32, {}});
     67 //    m.SetBuiltinOp(...);
//    m.BuildInterpreter({m.GetShape(a), m.GetShape(b)});
     69 //    m.PopulateTensor(a, {...});
     70 //    m.PopulateTensor(b, {...});
     71 //    m.Invoke();
     72 //    EXPECT_THAT(m.ExtractVector<float>(c), ArrayFloatNear({...}));
     73 //
     74 
// A helper struct to construct test tensors. This is particularly useful for
// quantized tensors, which must have their scale and zero_point defined before
// the actual data is known. This mimics what happens in practice: quantization
// parameters are calculated during training.
//
// This header builds instances as `TensorData{type}`; with that aggregate
// initialization every field after `type` is value-initialized (empty shape,
// zeros).
struct TensorData {
  // Element type of the tensor.
  TensorType type;
  // Dimensions of the tensor.
  std::vector<int> shape;
  // NOTE(review): min/max are presumably the real-valued range from which
  // quantization parameters can be derived -- confirm against test_util.cc.
  float min;
  float max;
  // Quantization scale.
  float scale;
  // Quantization zero point.
  int32_t zero_point;
};
     87 
     88 class SingleOpResolver : public OpResolver {
     89  public:
     90   SingleOpResolver(const BuiltinOperator op, TfLiteRegistration* registration)
     91       : op_(op), registration_(registration) {}
     92   TfLiteRegistration* FindOp(BuiltinOperator op) const override {
     93     if (op == op_) {
     94       return registration_;
     95     }
     96     return nullptr;
     97   }
     98   TfLiteRegistration* FindOp(const char* op) const override { return nullptr; }
     99 
    100  private:
    101   const BuiltinOperator op_;
    102   TfLiteRegistration* registration_;
    103 };
    104 
    105 class SingleOpModel {
    106  public:
    107   SingleOpModel() {}
    108   ~SingleOpModel() {}
    109 
    110   // Copying or assignment is disallowed to simplify ownership semantics.
    111   SingleOpModel(const SingleOpModel&) = delete;
    112   SingleOpModel& operator=(const SingleOpModel&) = delete;
    113 
    114   // Add a TensorType input tensor and return its index.
    115   int AddInput(TensorType type) { return AddInput(TensorData{type}); }
    116   int AddInput(const TensorData& t);
    117 
    118   // Add a Tensor containing const data and return the tensor id.
    119   int AddConstInput(TensorType type, std::initializer_list<int> data,
    120                     std::initializer_list<int> shape);
    121 
    122   // Add a null input tensor (optional input) and return kOptionalTensor.
    123   int AddNullInput();
    124 
    125   // Add a TensorType output tensor and return its index.
    126   int AddOutput(TensorType type) { return AddOutput(TensorData{type}); }
    127   int AddOutput(const TensorData& t);
    128 
    129   template <typename T>
    130   void QuantizeAndPopulate(int index, std::initializer_list<float> data) {
    131     TfLiteTensor* t = interpreter_->tensor(index);
    132     auto q = Quantize<T>(data, t->params.scale, t->params.zero_point);
    133     PopulateTensor(index, 0, q.data(), q.data() + q.size());
    134   }
    135 
    136   const std::vector<int>& GetShape(int id) { return tensor_data_.at(id).shape; }
    137 
    138   float GetScale(int id) { return tensor_data_.at(id).scale; }
    139   int32_t GetZeroPoint(int id) { return tensor_data_.at(id).zero_point; }
    140 
    141   // Define the operator in this model.
    142   void SetBuiltinOp(BuiltinOperator type, BuiltinOptions builtin_options_type,
    143                     flatbuffers::Offset<void> builtin_options);
    144   void SetCustomOp(const string& name,
    145                    const std::vector<uint8_t>& custom_option,
    146                    const std::function<TfLiteRegistration*()>& registeration);
    147 
    148   // Build the interpreter for this model. Also, resize and allocate all
    149   // tensors given the shapes of the inputs.
    150   void BuildInterpreter(std::vector<std::vector<int>> input_shapes);
    151 
    152   void Invoke();
    153 
    154   void PopulateStringTensor(int index, const std::vector<string>& content) {
    155     auto tensor = interpreter_->tensor(index);
    156     DynamicBuffer buf;
    157     for (const string& s : content) {
    158       buf.AddString(s.data(), s.length());
    159     }
    160     buf.WriteToTensor(tensor);
    161   }
    162 
    163   // Populate the tensor given its index.
    164   template <typename T>
    165   void PopulateTensor(int index, std::initializer_list<T> data) {
    166     T* v = interpreter_->typed_tensor<T>(index);
    167     CHECK(v) << "No tensor with index '" << index << "'.";
    168     for (T f : data) {
    169       *v = f;
    170       ++v;
    171     }
    172   }
    173 
    174   // Partially populate the tensor, starting at the given offset.
    175   template <typename T>
    176   void PopulateTensor(int index, int offset, T* begin, T* end) {
    177     T* v = interpreter_->typed_tensor<T>(index);
    178     memcpy(v + offset, begin, (end - begin) * sizeof(T));
    179   }
    180 
    181   // Return a vector with the flattened contents of a tensor.
    182   template <typename T>
    183   std::vector<T> ExtractVector(int index) {
    184     T* v = interpreter_->typed_tensor<T>(index);
    185     CHECK(v);
    186     return std::vector<T>(v, v + GetTensorSize(index));
    187   }
    188 
    189   std::vector<int> GetTensorShape(int index) {
    190     std::vector<int> result;
    191     TfLiteTensor* t = interpreter_->tensor(index);
    192     for (int i = 0; i < t->dims->size; ++i) {
    193       result.push_back(t->dims->data[i]);
    194     }
    195     return result;
    196   }
    197 
    198   void SetResolver(std::unique_ptr<OpResolver> resolver) {
    199     resolver_ = std::move(resolver);
    200   }
    201 
    202  protected:
    203   int32_t GetTensorSize(int index) const;
    204 
    205   flatbuffers::FlatBufferBuilder builder_;
    206   std::unique_ptr<tflite::Interpreter> interpreter_;
    207   std::unique_ptr<OpResolver> resolver_;
    208 
    209  private:
    210   int AddTensor(TensorData t, std::initializer_list<int> data);
    211 
    212   std::map<int, TensorData> tensor_data_;
    213   std::vector<int32_t> inputs_;
    214   std::vector<int32_t> outputs_;
    215   std::vector<flatbuffers::Offset<Tensor>> tensors_;
    216   std::vector<flatbuffers::Offset<OperatorCode>> opcodes_;
    217   std::vector<flatbuffers::Offset<Operator>> operators_;
    218   std::vector<flatbuffers::Offset<Buffer>> buffers_;
    219   std::map<string, std::function<TfLiteRegistration*()>> custom_registrations_;
    220 };
    221 
    222 // Base class for single op unit tests.
    223 // The tests are parameterized to test multiple kernels for a single op.
    224 // The parameters are strings like "optimized" and "reference" to have better
    225 // readability in test reports.
    226 //
    227 // To use this class:
    228 // * Define a constant map from strings to TfLiteRegistration.
    229 // * Implement a test class that inherits SingleOpTest.
    230 // * Instantiate the test cases with SingleOpTest::GetKernelTags helper
    231 //   function.
    232 // * Call GetRegistration to get the TfLiteRegistration to be used before
    233 //   building the interpreter.
    234 class SingleOpTest : public ::testing::TestWithParam<string> {
    235  public:
    236   static std::vector<string> GetKernelTags(
    237       const std::map<string, TfLiteRegistration*>& kernel_map) {
    238     std::vector<string> tags;
    239     for (auto it : kernel_map) {
    240       tags.push_back(it.first);
    241     }
    242     return tags;
    243   }
    244 
    245  protected:
    246   virtual const std::map<string, TfLiteRegistration*>& GetKernelMap() = 0;
    247   TfLiteRegistration* GetRegistration() {
    248     return GetKernelMap().at(GetParam());
    249   }
    250 };
    251 
// Explicit specialization of SingleOpModel::ExtractVector for string tensors.
// Strings have a special implementation that is in test_util.cc; only the
// declaration appears in this header.
template <>
std::vector<string> SingleOpModel::ExtractVector(int index);
    255 }  // namespace tflite
    256 
    257 #endif  // TENSORFLOW_CONTRIB_LITE_KERNELS_TEST_UTIL_H_
    258