1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include "tensorflow/contrib/lite/testing/tflite_driver.h" 16 17 #include <iostream> 18 19 #include "tensorflow/contrib/lite/testing/split.h" 20 21 namespace tflite { 22 namespace testing { 23 24 namespace { 25 26 // Returns the value in the given position in a tensor. 27 template <typename T> 28 T Value(const TfLitePtrUnion& data, int index); 29 template <> 30 float Value(const TfLitePtrUnion& data, int index) { 31 return data.f[index]; 32 } 33 template <> 34 int32_t Value(const TfLitePtrUnion& data, int index) { 35 return data.i32[index]; 36 } 37 template <> 38 int64_t Value(const TfLitePtrUnion& data, int index) { 39 return data.i64[index]; 40 } 41 template <> 42 uint8_t Value(const TfLitePtrUnion& data, int index) { 43 return data.uint8[index]; 44 } 45 46 template <typename T> 47 void SetTensorData(const std::vector<T>& values, TfLitePtrUnion* data) { 48 T* input_ptr = reinterpret_cast<T*>(data->raw); 49 for (const T& v : values) { 50 *input_ptr = v; 51 ++input_ptr; 52 } 53 } 54 55 } // namespace 56 57 class TfLiteDriver::Expectation { 58 public: 59 Expectation() { data_.raw = nullptr; } 60 ~Expectation() { delete[] data_.raw; } 61 template <typename T> 62 void SetData(const string& csv_values) { 63 const auto& values = testing::Split<T>(csv_values, ","); 64 data_.raw = new char[values.size() * sizeof(T)]; 65 SetTensorData(values, &data_); 66 } 67 68 bool Check(bool verbose, const TfLiteTensor& tensor) { 69 switch (tensor.type) { 70 case kTfLiteFloat32: 71 return TypedCheck<float>(verbose, tensor); 72 case kTfLiteInt32: 73 return TypedCheck<int32_t>(verbose, tensor); 74 case kTfLiteInt64: 75 return TypedCheck<int64_t>(verbose, tensor); 76 case kTfLiteUInt8: 77 return TypedCheck<uint8_t>(verbose, tensor); 78 default: 79 fprintf(stderr, "Unsupported type %d in Check\n", tensor.type); 80 return false; 81 } 82 } 83 84 private: 85 template <typename T> 86 bool TypedCheck(bool verbose, const TfLiteTensor& tensor) { 87 // TODO(ahentz): must find a way to configure the tolerance. 88 constexpr double kRelativeThreshold = 1e-2f; 89 constexpr double kAbsoluteThreshold = 1e-4f; 90 91 int tensor_size = tensor.bytes / sizeof(T); 92 93 bool good_output = true; 94 for (int i = 0; i < tensor_size; ++i) { 95 float computed = Value<T>(tensor.data, i); 96 float reference = Value<T>(data_, i); 97 float diff = std::abs(computed - reference); 98 bool error_is_large = false; 99 // For very small numbers, try absolute error, otherwise go with 100 // relative. 101 if (std::abs(reference) < kRelativeThreshold) { 102 error_is_large = (diff > kAbsoluteThreshold); 103 } else { 104 error_is_large = (diff > kRelativeThreshold * std::abs(reference)); 105 } 106 if (error_is_large) { 107 good_output = false; 108 if (verbose) { 109 std::cerr << " index " << i << ": got " << computed 110 << ", but expected " << reference << std::endl; 111 } 112 } 113 } 114 return good_output; 115 } 116 117 TfLitePtrUnion data_; 118 }; 119 120 TfLiteDriver::TfLiteDriver(bool use_nnapi) : use_nnapi_(use_nnapi) {} 121 TfLiteDriver::~TfLiteDriver() {} 122 123 void TfLiteDriver::AllocateTensors() { 124 if (must_allocate_tensors_) { 125 if (interpreter_->AllocateTensors() != kTfLiteOk) { 126 Invalidate("Failed to allocate tensors"); 127 return; 128 } 129 must_allocate_tensors_ = false; 130 } 131 } 132 133 void TfLiteDriver::LoadModel(const string& bin_file_path) { 134 if (!IsValid()) return; 135 std::cout << std::endl << "Loading model: " << bin_file_path << std::endl; 136 137 model_ = FlatBufferModel::BuildFromFile(GetFullPath(bin_file_path).c_str()); 138 if (!model_) { 139 Invalidate("Failed to mmap model " + bin_file_path); 140 return; 141 } 142 ops::builtin::BuiltinOpResolver builtins; 143 InterpreterBuilder(*model_, builtins)(&interpreter_); 144 if (!interpreter_) { 145 Invalidate("Failed build interpreter"); 146 return; 147 } 148 149 must_allocate_tensors_ = true; 150 } 151 152 void TfLiteDriver::ResetTensor(int id) { 153 if (!IsValid()) return; 154 auto* tensor = interpreter_->tensor(id); 155 memset(tensor->data.raw, 0, tensor->bytes); 156 } 157 158 void TfLiteDriver::ReshapeTensor(int id, const string& csv_values) { 159 if (!IsValid()) return; 160 if (interpreter_->ResizeInputTensor( 161 id, testing::Split<int>(csv_values, ",")) != kTfLiteOk) { 162 Invalidate("Failed to resize input tensor " + std::to_string(id)); 163 return; 164 } 165 must_allocate_tensors_ = true; 166 } 167 168 void TfLiteDriver::SetInput(int id, const string& csv_values) { 169 if (!IsValid()) return; 170 auto* tensor = interpreter_->tensor(id); 171 switch (tensor->type) { 172 case kTfLiteFloat32: { 173 const auto& values = testing::Split<float>(csv_values, ","); 174 if (!CheckSizes<float>(tensor->bytes, values.size())) return; 175 SetTensorData(values, &tensor->data); 176 break; 177 } 178 case kTfLiteInt32: { 179 const auto& values = testing::Split<int32_t>(csv_values, ","); 180 if (!CheckSizes<int32_t>(tensor->bytes, values.size())) return; 181 SetTensorData(values, &tensor->data); 182 break; 183 } 184 case kTfLiteInt64: { 185 const auto& values = testing::Split<int64_t>(csv_values, ","); 186 if (!CheckSizes<int64_t>(tensor->bytes, values.size())) return; 187 SetTensorData(values, &tensor->data); 188 break; 189 } 190 case kTfLiteUInt8: { 191 const auto& values = testing::Split<uint8_t>(csv_values, ","); 192 if (!CheckSizes<uint8_t>(tensor->bytes, values.size())) return; 193 SetTensorData(values, &tensor->data); 194 break; 195 } 196 default: 197 fprintf(stderr, "Unsupported type %d in SetInput\n", tensor->type); 198 Invalidate("Unsupported tensor data type"); 199 return; 200 } 201 } 202 203 void TfLiteDriver::SetExpectation(int id, const string& csv_values) { 204 if (!IsValid()) return; 205 auto* tensor = interpreter_->tensor(id); 206 if (expected_output_.count(id) != 0) { 207 fprintf(stderr, "Overriden expectation for tensor %d\n", id); 208 Invalidate("Overriden expectation"); 209 } 210 expected_output_[id].reset(new Expectation); 211 switch (tensor->type) { 212 case kTfLiteFloat32: 213 expected_output_[id]->SetData<float>(csv_values); 214 break; 215 case kTfLiteInt32: 216 expected_output_[id]->SetData<int32_t>(csv_values); 217 break; 218 case kTfLiteInt64: 219 expected_output_[id]->SetData<int64_t>(csv_values); 220 break; 221 case kTfLiteUInt8: 222 expected_output_[id]->SetData<uint8_t>(csv_values); 223 break; 224 default: 225 fprintf(stderr, "Unsupported type %d in SetExpectation\n", tensor->type); 226 Invalidate("Unsupported tensor data type"); 227 return; 228 } 229 } 230 231 void TfLiteDriver::Invoke() { 232 if (!IsValid()) return; 233 if (interpreter_->Invoke() != kTfLiteOk) { 234 Invalidate("Failed to invoke interpreter"); 235 } 236 } 237 238 bool TfLiteDriver::CheckResults() { 239 if (!IsValid()) return false; 240 bool success = true; 241 for (const auto& p : expected_output_) { 242 int id = p.first; 243 auto* tensor = interpreter_->tensor(id); 244 if (!p.second->Check(/*verbose=*/false, *tensor)) { 245 // Do not invalidate anything here. Instead, simply output the 246 // differences and return false. Invalidating would prevent all 247 // subsequent invocations from running.. 248 std::cerr << "There were errors in invocation '" << GetInvocationId() 249 << "', output tensor '" << id << "':" << std::endl; 250 p.second->Check(/*verbose=*/true, *tensor); 251 success = false; 252 SetOverallSuccess(false); 253 } 254 } 255 expected_output_.clear(); 256 return success; 257 } 258 259 } // namespace testing 260 } // namespace tflite 261