1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include "tensorflow/contrib/lite/testing/tf_driver.h" 16 17 #include <fstream> 18 #include <iostream> 19 20 #include "tensorflow/contrib/lite/testing/join.h" 21 #include "tensorflow/contrib/lite/testing/split.h" 22 #include "tensorflow/core/lib/gtl/array_slice.h" 23 24 namespace tflite { 25 namespace testing { 26 27 namespace { 28 29 tensorflow::Tensor CreateTensor(const tensorflow::DataType type, 30 const std::vector<int64_t>& dim) { 31 tensorflow::TensorShape shape{gtl::ArraySlice<int64>{ 32 reinterpret_cast<const int64*>(dim.data()), dim.size()}}; 33 return {type, shape}; 34 } 35 36 template <typename T> 37 void FillTensorWithData(tensorflow::Tensor* tensor, const string& csv_values) { 38 auto data = tensor->flat<T>(); 39 40 const auto& values = testing::Split<T>(csv_values, ","); 41 for (int i = 0; i < values.size(); i++) { 42 data(i) = values[i]; 43 } 44 } 45 46 template <typename T> 47 void FillTensorWithZeros(tensorflow::Tensor* tensor) { 48 auto data = tensor->flat<T>(); 49 for (int i = 0; i < tensor->NumElements(); i++) { 50 data(i) = 0; 51 } 52 } 53 54 template <typename T> 55 string TensorDataToCsvString(const tensorflow::Tensor& tensor) { 56 const auto& data = tensor.flat<T>(); 57 return Join(data.data(), data.size(), ","); 58 } 59 60 } // namespace 61 62 TfDriver::TfDriver(const std::vector<string>& input_layer, 63 const std::vector<string>& input_layer_type, 64 const std::vector<string>& input_layer_shape, 65 const std::vector<string>& output_layer) 66 : input_names_(input_layer), output_names_(output_layer) { 67 CHECK_EQ(input_layer.size(), input_layer_type.size()); 68 CHECK_EQ(input_layer.size(), input_layer_shape.size()); 69 70 input_ids_.resize(input_layer.size()); 71 input_tensors_.reserve(input_layer.size()); 72 input_types_.resize(input_layer.size()); 73 input_shapes_.resize(input_layer.size()); 74 for (int i = 0; i < input_layer.size(); i++) { 75 input_ids_[i] = i; 76 input_tensors_[input_layer[i]] = {}; 77 CHECK(DataTypeFromString(input_layer_type[i], &input_types_[i])); 78 input_shapes_[i] = Split<int64_t>(input_layer_shape[i], ","); 79 } 80 81 output_ids_.resize(output_layer.size()); 82 output_tensors_.reserve(output_layer.size()); 83 for (int i = 0; i < output_layer.size(); i++) { 84 output_ids_[i] = i; 85 } 86 } 87 88 void TfDriver::LoadModel(const string& bin_file_path) { 89 if (!IsValid()) return; 90 std::cout << std::endl << "Loading model: " << bin_file_path << std::endl; 91 std::ifstream model(bin_file_path); 92 if (model.fail()) { 93 Invalidate("Failed to find the model"); 94 return; 95 } 96 97 tensorflow::GraphDef graphdef; 98 if (!graphdef.ParseFromIstream(&model)) { 99 Invalidate("Failed to parse tensorflow graphdef"); 100 return; 101 } 102 103 tensorflow::SessionOptions options; 104 session_.reset(tensorflow::NewSession(options)); 105 auto status = session_->Create(graphdef); 106 if (!status.ok()) { 107 Invalidate("Failed to create session"); 108 } 109 } 110 111 void TfDriver::SetInput(int id, const string& csv_values) { 112 if (!IsValid()) return; 113 114 auto tensor = CreateTensor(input_types_[id], input_shapes_[id]); 115 switch (input_types_[id]) { 116 case tensorflow::DT_FLOAT: { 117 FillTensorWithData<float>(&tensor, csv_values); 118 break; 119 } 120 case tensorflow::DT_INT32: { 121 FillTensorWithData<int32_t>(&tensor, csv_values); 122 break; 123 } 124 default: 125 fprintf(stderr, "Unsupported type %d in SetInput\n", input_types_[id]); 126 Invalidate("Unsupported tensor data type"); 127 return; 128 } 129 input_tensors_[input_names_[id]] = tensor; 130 } 131 132 void TfDriver::ResetTensor(int id) { 133 if (!IsValid()) return; 134 auto tensor = input_tensors_[input_names_[id]]; 135 switch (input_types_[id]) { 136 case tensorflow::DT_FLOAT: { 137 FillTensorWithZeros<float>(&tensor); 138 break; 139 } 140 case tensorflow::DT_INT32: { 141 FillTensorWithZeros<int32_t>(&tensor); 142 break; 143 } 144 default: 145 fprintf(stderr, "Unsupported type %d in ResetTensor\n", input_types_[id]); 146 Invalidate("Unsupported tensor data type"); 147 return; 148 } 149 } 150 151 void TfDriver::ReshapeTensor(int id, const string& csv_values) { 152 input_shapes_[id] = Split<int64_t>(csv_values, ","); 153 input_tensors_[input_names_[id]] = 154 CreateTensor(input_types_[id], input_shapes_[id]); 155 ResetTensor(id); 156 } 157 158 string TfDriver::ReadOutput(int id) { 159 if (!IsValid()) return ""; 160 switch (output_tensors_[id].dtype()) { 161 case tensorflow::DT_FLOAT: 162 return TensorDataToCsvString<float>(output_tensors_[id]); 163 case tensorflow::DT_INT32: 164 return TensorDataToCsvString<int32_t>(output_tensors_[id]); 165 default: 166 fprintf(stderr, "Unsupported type %d in ResetTensor\n", input_types_[id]); 167 Invalidate("Unsupported tensor data type"); 168 return ""; 169 } 170 } 171 172 void TfDriver::Invoke() { 173 if (!IsValid()) return; 174 auto status = session_->Run({input_tensors_.begin(), input_tensors_.end()}, 175 output_names_, {}, &output_tensors_); 176 if (!status.ok()) { 177 Invalidate("Failed to invoke interpreter"); 178 } 179 } 180 181 } // namespace testing 182 } // namespace tflite 183