// Home | History | Annotate | Download | only in testing
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #ifndef TENSORFLOW_CONTRIB_LITE_TESTING_TEST_RUNNER_H_
     16 #define TENSORFLOW_CONTRIB_LITE_TESTING_TEST_RUNNER_H_
     17 
     18 #include <memory>
     19 #include <string>
     20 #include <vector>
     21 #include "tensorflow/contrib/lite/string.h"
     22 
     23 namespace tflite {
     24 namespace testing {
     25 
     26 // This is the base class for processing test data. Each one of the virtual
     27 // methods must be implemented to forward the data to the appropriate executor
     28 // (e.g. TF Lite's interpreter, or the NNAPI).
     29 class TestRunner {
     30  public:
     31   TestRunner() {}
     32   virtual ~TestRunner() {}
     33 
     34   // Load the given model, as a path relative to SetModelBaseDir().
     35   virtual void LoadModel(const string& bin_file_path) = 0;
     36 
     37   // Return the list of input tensors in the loaded model.
     38   virtual const std::vector<int>& GetInputs() = 0;
     39 
     40   // Return the list of output tensors in the loaded model.
     41   virtual const std::vector<int>& GetOutputs() = 0;
     42 
     43   // Prepare for a run by resize the given tensor. The given 'id' is
     44   // guaranteed to be one of the ids returned by GetInputs().
     45   virtual void ReshapeTensor(int id, const string& csv_values) = 0;
     46 
     47   // Reserve memory for all tensors.
     48   virtual void AllocateTensors() = 0;
     49 
     50   // Set the given tensor to some initial state, usually zero. This is
     51   // used to reset persistent buffers in a model.
     52   virtual void ResetTensor(int id) = 0;
     53 
     54   // Define the contents of the given input tensor. The given 'id' is
     55   // guaranteed to be one of the ids returned by GetInputs().
     56   virtual void SetInput(int id, const string& csv_values) = 0;
     57 
     58   // Define what should be expected for an output tensor after Invoke() runs.
     59   // The given 'id' is guaranteed to be one of the ids returned by
     60   // GetOutputs().
     61   virtual void SetExpectation(int id, const string& csv_values) = 0;
     62 
     63   // Run the model.
     64   virtual void Invoke() = 0;
     65 
     66   // Verify that the contents of all outputs conform to the existing
     67   // expectations. Return true if there are no expectations or they are all
     68   // satisfied.
     69   virtual bool CheckResults() = 0;
     70 
     71   // Read contents of tensor into csv format.
     72   // The given 'id' is guaranteed to be one of the ids returned by GetOutputs().
     73   virtual string ReadOutput(int id) = 0;
     74 
     75   // Set the base path for loading models.
     76   void SetModelBaseDir(const string& path) {
     77     model_base_dir_ = path;
     78     if (path[path.length() - 1] != '/') {
     79       model_base_dir_ += "/";
     80     }
     81   }
     82 
     83   // Return the full path of a model.
     84   string GetFullPath(const string& path) { return model_base_dir_ + path; }
     85 
     86   // Give an id to the next invocation to make error reporting more meaningful.
     87   void SetInvocationId(const string& id) { invocation_id_ = id; }
     88   const string& GetInvocationId() const { return invocation_id_; }
     89 
     90   // Invalidate the test runner, preventing it from executing any further.
     91   void Invalidate(const string& error_message) {
     92     error_message_ = error_message;
     93   }
     94   bool IsValid() const { return error_message_.empty(); }
     95   const string& GetErrorMessage() const { return error_message_; }
     96 
     97   // Handle the overall success of this test runner. This will be true if all
     98   // invocations were successful.
     99   void SetOverallSuccess(bool value) { overall_success_ = value; }
    100   bool GetOverallSuccess() const { return overall_success_; }
    101 
    102  protected:
    103   // A helper to check of the given number of values is consistent with the
    104   // number of bytes in a tensor of type T. When incompatibles sizes are found,
    105   // the test runner is invalidated and false is returned.
    106   template <typename T>
    107   bool CheckSizes(size_t tensor_bytes, size_t num_values) {
    108     size_t num_tensor_elements = tensor_bytes / sizeof(T);
    109     if (num_tensor_elements != num_values) {
    110       Invalidate("Expected '" + std::to_string(num_tensor_elements) +
    111                  "' elements for a tensor, but only got '" +
    112                  std::to_string(num_values) + "'");
    113       return false;
    114     }
    115     return true;
    116   }
    117 
    118  private:
    119   string model_base_dir_;
    120   string invocation_id_;
    121   bool overall_success_ = true;
    122 
    123   string error_message_;
    124 };
    125 
    126 }  // namespace testing
    127 }  // namespace tflite
    128 #endif  // TENSORFLOW_CONTRIB_LITE_TESTING_TEST_RUNNER_H_
    129