Home | History | Annotate | Download | only in testing
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #include "tensorflow/contrib/lite/testing/tf_driver.h"
     16 
     17 #include <fstream>
     18 #include <iostream>
     19 
     20 #include "tensorflow/contrib/lite/testing/join.h"
     21 #include "tensorflow/contrib/lite/testing/split.h"
     22 #include "tensorflow/core/lib/gtl/array_slice.h"
     23 
     24 namespace tflite {
     25 namespace testing {
     26 
     27 namespace {
     28 
     29 tensorflow::Tensor CreateTensor(const tensorflow::DataType type,
     30                                 const std::vector<int64_t>& dim) {
     31   tensorflow::TensorShape shape{gtl::ArraySlice<int64>{
     32       reinterpret_cast<const int64*>(dim.data()), dim.size()}};
     33   return {type, shape};
     34 }
     35 
     36 template <typename T>
     37 void FillTensorWithData(tensorflow::Tensor* tensor, const string& csv_values) {
     38   auto data = tensor->flat<T>();
     39 
     40   const auto& values = testing::Split<T>(csv_values, ",");
     41   for (int i = 0; i < values.size(); i++) {
     42     data(i) = values[i];
     43   }
     44 }
     45 
     46 template <typename T>
     47 void FillTensorWithZeros(tensorflow::Tensor* tensor) {
     48   auto data = tensor->flat<T>();
     49   for (int i = 0; i < tensor->NumElements(); i++) {
     50     data(i) = 0;
     51   }
     52 }
     53 
     54 template <typename T>
     55 string TensorDataToCsvString(const tensorflow::Tensor& tensor) {
     56   const auto& data = tensor.flat<T>();
     57   return Join(data.data(), data.size(), ",");
     58 }
     59 
     60 }  // namespace
     61 
     62 TfDriver::TfDriver(const std::vector<string>& input_layer,
     63                    const std::vector<string>& input_layer_type,
     64                    const std::vector<string>& input_layer_shape,
     65                    const std::vector<string>& output_layer)
     66     : input_names_(input_layer), output_names_(output_layer) {
     67   CHECK_EQ(input_layer.size(), input_layer_type.size());
     68   CHECK_EQ(input_layer.size(), input_layer_shape.size());
     69 
     70   input_ids_.resize(input_layer.size());
     71   input_tensors_.reserve(input_layer.size());
     72   input_types_.resize(input_layer.size());
     73   input_shapes_.resize(input_layer.size());
     74   for (int i = 0; i < input_layer.size(); i++) {
     75     input_ids_[i] = i;
     76     input_tensors_[input_layer[i]] = {};
     77     CHECK(DataTypeFromString(input_layer_type[i], &input_types_[i]));
     78     input_shapes_[i] = Split<int64_t>(input_layer_shape[i], ",");
     79   }
     80 
     81   output_ids_.resize(output_layer.size());
     82   output_tensors_.reserve(output_layer.size());
     83   for (int i = 0; i < output_layer.size(); i++) {
     84     output_ids_[i] = i;
     85   }
     86 }
     87 
     88 void TfDriver::LoadModel(const string& bin_file_path) {
     89   if (!IsValid()) return;
     90   std::cout << std::endl << "Loading model: " << bin_file_path << std::endl;
     91   std::ifstream model(bin_file_path);
     92   if (model.fail()) {
     93     Invalidate("Failed to find the model");
     94     return;
     95   }
     96 
     97   tensorflow::GraphDef graphdef;
     98   if (!graphdef.ParseFromIstream(&model)) {
     99     Invalidate("Failed to parse tensorflow graphdef");
    100     return;
    101   }
    102 
    103   tensorflow::SessionOptions options;
    104   session_.reset(tensorflow::NewSession(options));
    105   auto status = session_->Create(graphdef);
    106   if (!status.ok()) {
    107     Invalidate("Failed to create session");
    108   }
    109 }
    110 
    111 void TfDriver::SetInput(int id, const string& csv_values) {
    112   if (!IsValid()) return;
    113 
    114   auto tensor = CreateTensor(input_types_[id], input_shapes_[id]);
    115   switch (input_types_[id]) {
    116     case tensorflow::DT_FLOAT: {
    117       FillTensorWithData<float>(&tensor, csv_values);
    118       break;
    119     }
    120     case tensorflow::DT_INT32: {
    121       FillTensorWithData<int32_t>(&tensor, csv_values);
    122       break;
    123     }
    124     default:
    125       fprintf(stderr, "Unsupported type %d in SetInput\n", input_types_[id]);
    126       Invalidate("Unsupported tensor data type");
    127       return;
    128   }
    129   input_tensors_[input_names_[id]] = tensor;
    130 }
    131 
    132 void TfDriver::ResetTensor(int id) {
    133   if (!IsValid()) return;
    134   auto tensor = input_tensors_[input_names_[id]];
    135   switch (input_types_[id]) {
    136     case tensorflow::DT_FLOAT: {
    137       FillTensorWithZeros<float>(&tensor);
    138       break;
    139     }
    140     case tensorflow::DT_INT32: {
    141       FillTensorWithZeros<int32_t>(&tensor);
    142       break;
    143     }
    144     default:
    145       fprintf(stderr, "Unsupported type %d in ResetTensor\n", input_types_[id]);
    146       Invalidate("Unsupported tensor data type");
    147       return;
    148   }
    149 }
    150 
    151 void TfDriver::ReshapeTensor(int id, const string& csv_values) {
    152   input_shapes_[id] = Split<int64_t>(csv_values, ",");
    153   input_tensors_[input_names_[id]] =
    154       CreateTensor(input_types_[id], input_shapes_[id]);
    155   ResetTensor(id);
    156 }
    157 
    158 string TfDriver::ReadOutput(int id) {
    159   if (!IsValid()) return "";
    160   switch (output_tensors_[id].dtype()) {
    161     case tensorflow::DT_FLOAT:
    162       return TensorDataToCsvString<float>(output_tensors_[id]);
    163     case tensorflow::DT_INT32:
    164       return TensorDataToCsvString<int32_t>(output_tensors_[id]);
    165     default:
    166       fprintf(stderr, "Unsupported type %d in ResetTensor\n", input_types_[id]);
    167       Invalidate("Unsupported tensor data type");
    168       return "";
    169   }
    170 }
    171 
    172 void TfDriver::Invoke() {
    173   if (!IsValid()) return;
    174   auto status = session_->Run({input_tensors_.begin(), input_tensors_.end()},
    175                               output_names_, {}, &output_tensors_);
    176   if (!status.ok()) {
    177     Invalidate("Failed to invoke interpreter");
    178   }
    179 }
    180 
    181 }  // namespace testing
    182 }  // namespace tflite
    183