Home | History | Annotate | Download | only in testing
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 // Parses tflite example input data.
     16 // Format is ASCII
     17 // TODO(aselle): Switch to protobuf, but the android team requested a simple
     18 // ASCII file.
     19 #include "tensorflow/contrib/lite/testing/parse_testdata.h"
     20 
     21 #include <cinttypes>
     22 #include <cmath>
     23 #include <cstdint>
     24 #include <cstdio>
     25 #include <fstream>
     26 #include <iostream>
     27 #include <streambuf>
     28 
     29 #include "tensorflow/contrib/lite/error_reporter.h"
     30 #include "tensorflow/contrib/lite/testing/message.h"
     31 #include "tensorflow/contrib/lite/testing/split.h"
     32 
     33 namespace tflite {
     34 namespace testing {
     35 namespace {
     36 
     37 // Fatal error if parse error occurs
     38 #define PARSE_CHECK_EQ(filename, current_line, x, y)                         \
     39   if ((x) != (y)) {                                                          \
     40     fprintf(stderr, "Parse Error @ %s:%d\n  File %s\n  Line %d, %s != %s\n", \
     41             __FILE__, __LINE__, filename, current_line + 1, #x, #y);         \
     42     return kTfLiteError;                                                     \
     43   }
     44 
     45 // Breakup a "," delimited line into a std::vector<std::string>.
     46 // This is extremely inefficient, and just used for testing code.
     47 // TODO(aselle): replace with absl when we use it.
     48 std::vector<std::string> ParseLine(const std::string& line) {
     49   size_t pos = 0;
     50   std::vector<std::string> elements;
     51   while (true) {
     52     size_t end = line.find(',', pos);
     53     if (end == std::string::npos) {
     54       elements.push_back(line.substr(pos));
     55       break;
     56     } else {
     57       elements.push_back(line.substr(pos, end - pos));
     58     }
     59     pos = end + 1;
     60   }
     61   return elements;
     62 }
     63 
     64 }  // namespace
     65 
      66 // Given a `filename`, produce a vector of Examples corresponding
      67 // to test cases that can be applied to a tflite model.
TfLiteStatus ParseExamples(const char* filename,
                           std::vector<Example>* examples) {
  std::ifstream fp(filename);
  if (!fp.good()) {
    fprintf(stderr, "Could not read '%s'\n", filename);
    return kTfLiteError;
  }
  // Slurp the whole file into memory.
  std::string str((std::istreambuf_iterator<char>(fp)),
                  std::istreambuf_iterator<char>());
  size_t pos = 0;

  // Newlines and commas delimit the file: split it into a table of strings,
  // one row per line and one cell per comma-separated field.
  std::vector<std::vector<std::string>> csv;
  while (true) {
    size_t end = str.find('\n', pos);

    if (end == std::string::npos) {
      csv.emplace_back(ParseLine(str.substr(pos)));
      break;
    }
    csv.emplace_back(ParseLine(str.substr(pos, end - pos)));
    pos = end + 1;
  }

  // NOTE(review): csv row/cell accesses below are not bounds-checked; a
  // truncated or malformed file will index past the end. Presumably the test
  // data is trusted/generated — confirm before using on arbitrary input.
  int current_line = 0;
  // Header row: "test_cases,<count>".
  PARSE_CHECK_EQ(filename, current_line, csv[0][0], "test_cases");
  int example_count = std::stoi(csv[0][1]);
  current_line++;

  // Parses one tensor spec occupying three consecutive rows:
  //   dtype,...
  //   shape,<d0>,<d1>,...          (empty first field => 0-d tensor)
  //   <label>,<v0>,<v1>,...        (exactly prod(shape) float values)
  // Advances the shared `current_line` cursor past the rows it consumes.
  // PARSE_CHECK_EQ returns kTfLiteError *from this lambda* on mismatch.
  auto parse_tensor = [&filename, &current_line,
                       &csv](FloatTensor* tensor_ptr) {
    PARSE_CHECK_EQ(filename, current_line, csv[current_line][0], "dtype");
    current_line++;
    // parse shape
    PARSE_CHECK_EQ(filename, current_line, csv[current_line][0], "shape");
    size_t elements = 1;
    FloatTensor& tensor = *tensor_ptr;

    for (size_t i = 1; i < csv[current_line].size(); i++) {
      const auto& shape_part_to_parse = csv[current_line][i];
      if (shape_part_to_parse.empty()) {
        // Case of a 0-dimensional shape
        break;
      }
      int shape_part = std::stoi(shape_part_to_parse);
      elements *= shape_part;
      tensor.shape.push_back(shape_part);
    }
    current_line++;
    // parse data: the row must hold exactly `elements` values after the label.
    PARSE_CHECK_EQ(filename, current_line, csv[current_line].size() - 1,
                   elements);
    for (size_t i = 1; i < csv[current_line].size(); i++) {
      tensor.flat_data.push_back(std::stof(csv[current_line][i]));
    }
    current_line++;

    return kTfLiteOk;
  };

  // Each example is an "inputs,<n>" row followed by n tensor specs, then an
  // "outputs,<m>" row followed by m tensor specs.
  for (int example_idx = 0; example_idx < example_count; example_idx++) {
    Example example;
    PARSE_CHECK_EQ(filename, current_line, csv[current_line][0], "inputs");
    int inputs = std::stoi(csv[current_line][1]);
    current_line++;
    // parse dtype
    for (int input_index = 0; input_index < inputs; input_index++) {
      example.inputs.push_back(FloatTensor());
      TF_LITE_ENSURE_STATUS(parse_tensor(&example.inputs.back()));
    }

    PARSE_CHECK_EQ(filename, current_line, csv[current_line][0], "outputs");
    int outputs = std::stoi(csv[current_line][1]);
    current_line++;
    for (int input_index = 0; input_index < outputs; input_index++) {
      example.outputs.push_back(FloatTensor());
      TF_LITE_ENSURE_STATUS(parse_tensor(&example.outputs.back()));
    }
    examples->emplace_back(example);
  }
  return kTfLiteOk;
}
    150 
    151 TfLiteStatus FeedExample(tflite::Interpreter* interpreter,
    152                          const Example& example) {
    153   // Resize inputs to match example & allocate.
    154   for (size_t i = 0; i < interpreter->inputs().size(); i++) {
    155     int input_index = interpreter->inputs()[i];
    156 
    157     TF_LITE_ENSURE_STATUS(
    158         interpreter->ResizeInputTensor(input_index, example.inputs[i].shape));
    159   }
    160   TF_LITE_ENSURE_STATUS(interpreter->AllocateTensors());
    161   // Copy data into tensors.
    162   for (size_t i = 0; i < interpreter->inputs().size(); i++) {
    163     int input_index = interpreter->inputs()[i];
    164     if (float* data = interpreter->typed_tensor<float>(input_index)) {
    165       for (size_t idx = 0; idx < example.inputs[i].flat_data.size(); idx++) {
    166         data[idx] = example.inputs[i].flat_data[idx];
    167       }
    168     } else if (int32_t* data =
    169                    interpreter->typed_tensor<int32_t>(input_index)) {
    170       for (size_t idx = 0; idx < example.inputs[i].flat_data.size(); idx++) {
    171         data[idx] = example.inputs[i].flat_data[idx];
    172       }
    173     } else if (int64_t* data =
    174                    interpreter->typed_tensor<int64_t>(input_index)) {
    175       for (size_t idx = 0; idx < example.inputs[i].flat_data.size(); idx++) {
    176         data[idx] = example.inputs[i].flat_data[idx];
    177       }
    178     } else {
    179       fprintf(stderr, "input[%zu] was not float or int data\n", i);
    180       return kTfLiteError;
    181     }
    182   }
    183   return kTfLiteOk;
    184 }
    185 
    186 TfLiteStatus CheckOutputs(tflite::Interpreter* interpreter,
    187                           const Example& example) {
    188   constexpr double kRelativeThreshold = 1e-2f;
    189   constexpr double kAbsoluteThreshold = 1e-4f;
    190 
    191   ErrorReporter* context = DefaultErrorReporter();
    192   int model_outputs = interpreter->outputs().size();
    193   TF_LITE_ENSURE_EQ(context, model_outputs, example.outputs.size());
    194   for (size_t i = 0; i < interpreter->outputs().size(); i++) {
    195     int output_index = interpreter->outputs()[i];
    196     if (const float* data = interpreter->typed_tensor<float>(output_index)) {
    197       for (size_t idx = 0; idx < example.outputs[i].flat_data.size(); idx++) {
    198         float computed = data[idx];
    199         float reference = example.outputs[0].flat_data[idx];
    200         float diff = std::abs(computed - reference);
    201         bool error_is_large = false;
    202         // For very small numbers, try absolute error, otherwise go with
    203         // relative.
    204         if (std::abs(reference) < kRelativeThreshold) {
    205           error_is_large = (diff > kAbsoluteThreshold);
    206         } else {
    207           error_is_large = (diff > kRelativeThreshold * std::abs(reference));
    208         }
    209         if (error_is_large) {
    210           fprintf(stdout, "output[%zu][%zu] did not match %f vs reference %f\n",
    211                   i, idx, data[idx], reference);
    212           return kTfLiteError;
    213         }
    214       }
    215       fprintf(stderr, "\n");
    216     } else if (const int32_t* data =
    217                    interpreter->typed_tensor<int32_t>(output_index)) {
    218       for (size_t idx = 0; idx < example.outputs[i].flat_data.size(); idx++) {
    219         int32_t computed = data[idx];
    220         int32_t reference = example.outputs[0].flat_data[idx];
    221         if (std::abs(computed - reference) > 0) {
    222           fprintf(stderr, "output[%zu][%zu] did not match %d vs reference %d\n",
    223                   i, idx, computed, reference);
    224           return kTfLiteError;
    225         }
    226       }
    227       fprintf(stderr, "\n");
    228     } else if (const int64_t* data =
    229                    interpreter->typed_tensor<int64_t>(output_index)) {
    230       for (size_t idx = 0; idx < example.outputs[i].flat_data.size(); idx++) {
    231         int64_t computed = data[idx];
    232         int64_t reference = example.outputs[0].flat_data[idx];
    233         if (std::abs(computed - reference) > 0) {
    234           fprintf(stderr,
    235                   "output[%zu][%zu] did not match %" PRId64
    236                   " vs reference %" PRId64 "\n",
    237                   i, idx, computed, reference);
    238           return kTfLiteError;
    239         }
    240       }
    241       fprintf(stderr, "\n");
    242     } else {
    243       fprintf(stderr, "output[%zu] was not float or int data\n", i);
    244       return kTfLiteError;
    245     }
    246   }
    247   return kTfLiteOk;
    248 }
    249 
    250 // Process an 'invoke' message, triggering execution of the test runner, as
    251 // well as verification of outputs. An 'invoke' message looks like:
    252 //   invoke {
    253 //     id: xyz
    254 //     input: 1,2,1,1,1,2,3,4
    255 //     output: 4,5,6
    256 //   }
    257 class Invoke : public Message {
    258  public:
    259   explicit Invoke(TestRunner* test_runner) : test_runner_(test_runner) {
    260     expected_inputs_ = test_runner->GetInputs();
    261     expected_outputs_ = test_runner->GetOutputs();
    262   }
    263 
    264   void SetField(const std::string& name, const std::string& value) override {
    265     if (name == "id") {
    266       test_runner_->SetInvocationId(value);
    267     } else if (name == "input") {
    268       if (expected_inputs_.empty()) {
    269         return test_runner_->Invalidate("Too many inputs");
    270       }
    271       test_runner_->SetInput(*expected_inputs_.begin(), value);
    272       expected_inputs_.erase(expected_inputs_.begin());
    273     } else if (name == "output") {
    274       if (expected_outputs_.empty()) {
    275         return test_runner_->Invalidate("Too many outputs");
    276       }
    277       test_runner_->SetExpectation(*expected_outputs_.begin(), value);
    278       expected_outputs_.erase(expected_outputs_.begin());
    279     }
    280   }
    281   void Finish() override {
    282     test_runner_->Invoke();
    283     test_runner_->CheckResults();
    284   }
    285 
    286  private:
    287   std::vector<int> expected_inputs_;
    288   std::vector<int> expected_outputs_;
    289 
    290   TestRunner* test_runner_;
    291 };
    292 
    293 // Process an 'reshape' message, triggering resizing of the input tensors via
    294 // the test runner. A 'reshape' message looks like:
    295 //   reshape {
    296 //     input: 1,2,1,1,1,2,3,4
    297 //   }
    298 class Reshape : public Message {
    299  public:
    300   explicit Reshape(TestRunner* test_runner) : test_runner_(test_runner) {
    301     expected_inputs_ = test_runner->GetInputs();
    302   }
    303 
    304   void SetField(const std::string& name, const std::string& value) override {
    305     if (name == "input") {
    306       if (expected_inputs_.empty()) {
    307         return test_runner_->Invalidate("Too many inputs to reshape");
    308       }
    309       test_runner_->ReshapeTensor(*expected_inputs_.begin(), value);
    310       expected_inputs_.erase(expected_inputs_.begin());
    311     }
    312   }
    313 
    314  private:
    315   std::vector<int> expected_inputs_;
    316   TestRunner* test_runner_;
    317 };
    318 
    319 // This is the top-level message in a test file.
    320 class TestData : public Message {
    321  public:
    322   explicit TestData(TestRunner* test_runner) : test_runner_(test_runner) {}
    323 
    324   void SetField(const std::string& name, const std::string& value) override {
    325     if (name == "load_model") {
    326       test_runner_->LoadModel(value);
    327     } else if (name == "init_state") {
    328       test_runner_->AllocateTensors();
    329       for (int id : Split<int>(value, ",")) {
    330         test_runner_->ResetTensor(id);
    331       }
    332     }
    333   }
    334   Message* AddChild(const std::string& s) override {
    335     if (s == "invoke") {
    336       test_runner_->AllocateTensors();
    337       return Store(new Invoke(test_runner_));
    338     } else if (s == "reshape") {
    339       return Store(new Reshape(test_runner_));
    340     }
    341     return nullptr;
    342   }
    343 
    344  private:
    345   TestRunner* test_runner_;
    346 };
    347 
    348 bool ParseAndRunTests(std::istream* input, TestRunner* test_runner) {
    349   TestData test_data(test_runner);
    350   Message::Read(input, &test_data);
    351   return test_runner->IsValid() && test_runner->GetOverallSuccess();
    352 }
    353 
    354 }  // namespace testing
    355 }  // namespace tflite
    356