Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      7     http://www.apache.org/licenses/LICENSE-2.0
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     18 #include <vector>
     20 #include <gmock/gmock.h>
     21 #include <gtest/gtest.h>
     23 #include "tensorflow/contrib/lite/interpreter.h"
     24 #include "tensorflow/contrib/lite/kernels/register.h"
     25 #include "tensorflow/contrib/lite/model.h"
     26 #include "tensorflow/contrib/lite/string_util.h"
     27 #include "tensorflow/contrib/lite/testing/util.h"
     28 #include "tensorflow/core/platform/logging.h"
     30 namespace tflite {
// Returns a list of gmock float matchers intended for checking that the
// elements of a float vector match `values` to within a given tolerance.
//   values:        the expected element values.
//   max_abs_error: the tolerance; defaults to 1e-5.
// Only the declaration is visible here; the definition lives elsewhere.
std::vector<::testing::Matcher<float>> ArrayFloatNear(
    const std::vector<float>& values, float max_abs_error = 1e-5);
     37 template <typename T>
     38 inline std::vector<T> Quantize(const std::vector<float>& data, float scale,
     39                                int32_t zero_point) {
     40   std::vector<T> q;
     41   for (float f : data) {
     42     q.push_back(std::max(
     43         std::numeric_limits<T>::min(),
     44         std::min(std::numeric_limits<T>::max(),
     45                  static_cast<T>(std::round(zero_point + (f / scale))))));
     46   }
     47   return q;
     48 }
     50 template <typename T>
     51 inline std::vector<float> Dequantize(const std::vector<T>& data, float scale,
     52                                      int32_t zero_point) {
     53   std::vector<float> f;
     54   for (T q : data) {
     55     f.push_back(scale * (q - zero_point));
     56   }
     57   return f;
     58 }
// SingleOpModel: a test model that contains a single operator. All operator
// inputs and outputs are external to the model, so the tests can directly
// access them.
// Typical usage:
//    SingleOpModel m;
//    int a = m.AddInput({TensorType_FLOAT32, a_shape});
//    int b = m.AddInput({TensorType_FLOAT32, b_shape});
//    int c = m.AddOutput({TensorType_FLOAT32, {}});
//    m.SetBuiltinOp(...);
//    m.BuildInterpreter({m.GetShape(a), m.GetShape(b)});
//    m.PopulateTensor(a, {...});
//    m.PopulateTensor(b, {...});
//    m.Invoke();
//    EXPECT_THAT(m.ExtractVector<float>(c),
//                ::testing::ElementsAreArray(ArrayFloatNear({...})));
//
// A helper struct to construct test tensors. This is particularly useful for
// quantized tensors, which must have their scale and zero_point defined before
// the actual data is known. This mimics what happens in practice: quantization
// parameters are calculated during training.
//
// The struct is an aggregate and is brace-initialized positionally (e.g.
// TensorData{type}), so the order of the fields below is part of its
// interface.
struct TensorData {
  // Element type of the tensor.
  TensorType type;
  // Dimensions of the tensor.
  std::vector<int> shape;
  // NOTE(review): presumably the real-valued range [min, max] from which
  // quantization parameters can be derived -- confirm against test_util.cc.
  float min;
  float max;
  // Quantization scale.
  float scale;
  // Quantization zero point.
  int32_t zero_point;
};
     88 class SingleOpResolver : public OpResolver {
     89  public:
     90   SingleOpResolver(const BuiltinOperator op, TfLiteRegistration* registration)
     91       : op_(op), registration_(registration) {}
     92   TfLiteRegistration* FindOp(BuiltinOperator op) const override {
     93     if (op == op_) {
     94       return registration_;
     95     }
     96     return nullptr;
     97   }
     98   TfLiteRegistration* FindOp(const char* op) const override { return nullptr; }
    100  private:
    101   const BuiltinOperator op_;
    102   TfLiteRegistration* registration_;
    103 };
    105 class SingleOpModel {
    106  public:
    107   SingleOpModel() {}
    108   ~SingleOpModel() {}
    110   // Copying or assignment is disallowed to simplify ownership semantics.
    111   SingleOpModel(const SingleOpModel&) = delete;
    112   SingleOpModel& operator=(const SingleOpModel&) = delete;
    114   // Add a TensorType input tensor and return its index.
    115   int AddInput(TensorType type) { return AddInput(TensorData{type}); }
    116   int AddInput(const TensorData& t);
    118   // Add a Tensor containing const data and return the tensor id.
    119   int AddConstInput(TensorType type, std::initializer_list<int> data,
    120                     std::initializer_list<int> shape);
    122   // Add a null input tensor (optional input) and return kOptionalTensor.
    123   int AddNullInput();
    125   // Add a TensorType output tensor and return its index.
    126   int AddOutput(TensorType type) { return AddOutput(TensorData{type}); }
    127   int AddOutput(const TensorData& t);
    129   template <typename T>
    130   void QuantizeAndPopulate(int index, std::initializer_list<float> data) {
    131     TfLiteTensor* t = interpreter_->tensor(index);
    132     auto q = Quantize<T>(data, t->params.scale, t->params.zero_point);
    133     PopulateTensor(index, 0, q.data(), q.data() + q.size());
    134   }
    136   const std::vector<int>& GetShape(int id) { return tensor_data_.at(id).shape; }
    138   float GetScale(int id) { return tensor_data_.at(id).scale; }
    139   int32_t GetZeroPoint(int id) { return tensor_data_.at(id).zero_point; }
    141   // Define the operator in this model.
    142   void SetBuiltinOp(BuiltinOperator type, BuiltinOptions builtin_options_type,
    143                     flatbuffers::Offset<void> builtin_options);
    144   void SetCustomOp(const string& name,
    145                    const std::vector<uint8_t>& custom_option,
    146                    const std::function<TfLiteRegistration*()>& registeration);
    148   // Build the interpreter for this model. Also, resize and allocate all
    149   // tensors given the shapes of the inputs.
    150   void BuildInterpreter(std::vector<std::vector<int>> input_shapes);
    152   void Invoke();
    154   void PopulateStringTensor(int index, const std::vector<string>& content) {
    155     auto tensor = interpreter_->tensor(index);
    156     DynamicBuffer buf;
    157     for (const string& s : content) {
    158       buf.AddString(s.data(), s.length());
    159     }
    160     buf.WriteToTensor(tensor);
    161   }
    163   // Populate the tensor given its index.
    164   template <typename T>
    165   void PopulateTensor(int index, std::initializer_list<T> data) {
    166     T* v = interpreter_->typed_tensor<T>(index);
    167     CHECK(v) << "No tensor with index '" << index << "'.";
    168     for (T f : data) {
    169       *v = f;
    170       ++v;
    171     }
    172   }
    174   // Partially populate the tensor, starting at the given offset.
    175   template <typename T>
    176   void PopulateTensor(int index, int offset, T* begin, T* end) {
    177     T* v = interpreter_->typed_tensor<T>(index);
    178     memcpy(v + offset, begin, (end - begin) * sizeof(T));
    179   }
    181   // Return a vector with the flattened contents of a tensor.
    182   template <typename T>
    183   std::vector<T> ExtractVector(int index) {
    184     T* v = interpreter_->typed_tensor<T>(index);
    185     CHECK(v);
    186     return std::vector<T>(v, v + GetTensorSize(index));
    187   }
    189   std::vector<int> GetTensorShape(int index) {
    190     std::vector<int> result;
    191     TfLiteTensor* t = interpreter_->tensor(index);
    192     for (int i = 0; i < t->dims->size; ++i) {
    193       result.push_back(t->dims->data[i]);
    194     }
    195     return result;
    196   }
    198   void SetResolver(std::unique_ptr<OpResolver> resolver) {
    199     resolver_ = std::move(resolver);
    200   }
    202  protected:
    203   int32_t GetTensorSize(int index) const;
    205   flatbuffers::FlatBufferBuilder builder_;
    206   std::unique_ptr<tflite::Interpreter> interpreter_;
    207   std::unique_ptr<OpResolver> resolver_;
    209  private:
    210   int AddTensor(TensorData t, std::initializer_list<int> data);
    212   std::map<int, TensorData> tensor_data_;
    213   std::vector<int32_t> inputs_;
    214   std::vector<int32_t> outputs_;
    215   std::vector<flatbuffers::Offset<Tensor>> tensors_;
    216   std::vector<flatbuffers::Offset<OperatorCode>> opcodes_;
    217   std::vector<flatbuffers::Offset<Operator>> operators_;
    218   std::vector<flatbuffers::Offset<Buffer>> buffers_;
    219   std::map<string, std::function<TfLiteRegistration*()>> custom_registrations_;
    220 };
    222 // Base class for single op unit tests.
    223 // The tests are parameterized to test multiple kernels for a single op.
    224 // The parameters are strings like "optimized" and "reference" to have better
    225 // readability in test reports.
    226 //
    227 // To use this class:
    228 // * Define a constant map from strings to TfLiteRegistration.
    229 // * Implement a test class that inherits SingleOpTest.
    230 // * Instantiate the test cases with SingleOpTest::GetKernelTags helper
    231 //   function.
    232 // * Call GetRegistration to get the TfLiteRegistration to be used before
    233 //   building the interpreter.
    234 class SingleOpTest : public ::testing::TestWithParam<string> {
    235  public:
    236   static std::vector<string> GetKernelTags(
    237       const std::map<string, TfLiteRegistration*>& kernel_map) {
    238     std::vector<string> tags;
    239     for (auto it : kernel_map) {
    240       tags.push_back(it.first);
    241     }
    242     return tags;
    243   }
    245  protected:
    246   virtual const std::map<string, TfLiteRegistration*>& GetKernelMap() = 0;
    247   TfLiteRegistration* GetRegistration() {
    248     return GetKernelMap().at(GetParam());
    249   }
    250 };
// Explicit specialization of SingleOpModel::ExtractVector for string tensors.
// Strings have a special implementation that is in test_util.cc; declaring
// the specialization here keeps the generic template from being instantiated
// for `string`.
template <>
std::vector<string> SingleOpModel::ExtractVector(int index);
    255 }  // namespace tflite