1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 // Unit test for TFLite SVDF op. 16 17 #include <iomanip> 18 #include <vector> 19 20 #include <gmock/gmock.h> 21 #include <gtest/gtest.h> 22 #include "tensorflow/lite/interpreter.h" 23 #include "tensorflow/lite/kernels/register.h" 24 #include "tensorflow/lite/kernels/test_util.h" 25 #include "tensorflow/lite/model.h" 26 27 namespace tflite { 28 namespace { 29 30 using ::testing::ElementsAreArray; 31 32 static float svdf_input[] = { 33 0.12609188, -0.46347019, -0.89598465, 34 0.35867718, 0.36897406, 0.73463392, 35 36 0.14278367, -1.64410412, -0.75222826, 37 -0.57290924, 0.12729003, 0.7567004, 38 39 0.49837467, 0.19278903, 0.26584083, 40 0.17660543, 0.52949083, -0.77931279, 41 42 -0.11186574, 0.13164264, -0.05349274, 43 -0.72674477, -0.5683046, 0.55900657, 44 45 -0.68892461, 0.37783599, 0.18263303, 46 -0.63690937, 0.44483393, -0.71817774, 47 48 -0.81299269, -0.86831826, 1.43940818, 49 -0.95760226, 1.82078898, 0.71135032, 50 51 -1.45006323, -0.82251364, -1.69082689, 52 -1.65087092, -1.89238167, 1.54172635, 53 54 0.03966608, -0.24936394, -0.77526885, 55 2.06740379, -1.51439476, 1.43768692, 56 57 0.11771342, -0.23761693, -0.65898693, 58 0.31088525, -1.55601168, -0.87661445, 59 60 -0.89477462, 1.67204106, -0.53235275, 61 -0.6230064, 0.29819036, 1.06939757, 62 }; 63 64 static float svdf_golden_output_rank_1[] = { 65 0.014899, -0.0517661, -0.143725, -0.00271883, 66 -0.03004015, 0.09565311, 0.1587342, 0.00784263, 67 68 0.068281, -0.162217, -0.152268, 0.00323521, 69 0.01582633, 0.03858774, -0.03001583, -0.02671271, 70 71 -0.0317821, -0.0333089, 0.0609602, 0.0333759, 72 -0.01432795, 0.05524484, 0.1101355, -0.02382665, 73 74 -0.00623099, -0.077701, -0.391193, -0.0136691, 75 -0.02333033, 0.02293761, 0.12338032, 0.04326871, 76 77 0.201551, -0.164607, -0.179462, -0.0592739, 78 0.01064911, -0.17503069, 0.07821996, -0.00224009, 79 80 0.0886511, -0.0875401, -0.269283, 0.0281379, 81 -0.02282338, 0.09741908, 0.32973239, 0.12281385, 82 83 -0.201174, -0.586145, -0.628624, -0.0330412, 84 0.24780814, -0.39304617, -0.22473189, 0.02589256, 85 86 -0.0839096, -0.299329, 0.108746, 0.109808, 87 0.10084175, -0.06416984, 0.28936723, 0.0026358, 88 89 0.419114, -0.237824, -0.422627, 0.175115, 90 -0.2314795, -0.18584411, -0.4228974, -0.12928449, 91 92 0.36726, -0.522303, -0.456502, -0.175475, 93 0.17012937, -0.34447709, 0.38505614, -0.28158101, 94 }; 95 96 static float svdf_golden_output_rank_2[] = { 97 -0.09623547, -0.10193135, 0.11083051, -0.0347917, 98 0.1141196, 0.12965347, -0.12652366, 0.01007236, 99 100 -0.16396809, -0.21247184, 0.11259045, -0.04156673, 101 0.10132131, -0.06143532, -0.00924693, 0.10084561, 102 103 0.01257364, 0.0506071, -0.19287863, -0.07162561, 104 -0.02033747, 0.22673416, 0.15487903, 0.02525555, 105 106 -0.1411963, -0.37054959, 0.01774767, 0.05867489, 107 0.09607603, -0.0141301, -0.08995658, 0.12867066, 108 109 -0.27142537, -0.16955489, 0.18521598, -0.12528358, 110 0.00331409, 0.11167502, 0.02218599, -0.07309391, 111 112 0.09593632, -0.28361851, -0.0773851, 0.17199151, 113 -0.00075242, 0.33691186, -0.1536046, 0.16572715, 114 115 -0.27916506, -0.27626723, 0.42615682, 0.3225764, 116 -0.37472126, -0.55655634, -0.05013514, 0.289112, 117 118 -0.24418658, 0.07540751, -0.1940318, -0.08911639, 119 0.00732617, 0.46737891, 0.26449674, 0.24888524, 120 121 -0.17225097, -0.54660404, -0.38795233, 0.08389944, 122 0.07736043, -0.28260678, 0.15666828, 1.14949894, 123 124 -0.57454878, -0.64704704, 0.73235172, -0.34616736, 125 0.21120001, -0.22927976, 0.02455296, -0.35906726, 126 }; 127 128 // Derived class of SingleOpModel, which is used to test SVDF TFLite op. 129 class BaseSVDFOpModel : public SingleOpModel { 130 public: 131 BaseSVDFOpModel(int batches, int units, int input_size, int memory_size, 132 int rank, 133 TensorType weights_feature_type = TensorType_FLOAT32, 134 TensorType weights_time_type = TensorType_FLOAT32) 135 : batches_(batches), 136 units_(units), 137 input_size_(input_size), 138 memory_size_(memory_size), 139 rank_(rank) { 140 input_ = AddInput(TensorType_FLOAT32); 141 weights_feature_ = AddInput(weights_feature_type); 142 weights_time_ = AddInput(weights_time_type); 143 bias_ = AddNullInput(); 144 const int num_filters = units * rank; 145 activation_state_ = AddInput( 146 TensorData{TensorType_FLOAT32, {batches, memory_size * num_filters}}, 147 /*is_variable=*/true); 148 output_ = AddOutput(TensorType_FLOAT32); 149 SetBuiltinOp( 150 BuiltinOperator_SVDF, BuiltinOptions_SVDFOptions, 151 CreateSVDFOptions(builder_, rank, ActivationFunctionType_NONE).Union()); 152 BuildInterpreter({ 153 {batches_, input_size_}, // input tensor 154 {units_ * rank, input_size_}, // weights_feature tensor 155 {units_ * rank, memory_size_}, // weights_time tensor 156 {units_}, // bias tensor 157 {batches, memory_size * num_filters} // activation_state tensor 158 }); 159 } 160 161 // Populates the weights_feature tensor. 162 void SetWeightsFeature(std::initializer_list<float> f) { 163 PopulateTensor(weights_feature_, f); 164 } 165 166 // Populates the weights_time tensor. 167 void SetWeightsTime(std::initializer_list<float> f) { 168 PopulateTensor(weights_time_, f); 169 } 170 171 // Populates the input tensor. 172 void SetInput(int offset, float* begin, float* end) { 173 PopulateTensor(input_, offset, begin, end); 174 } 175 176 // Extracts the output tensor from the SVDF op. 177 std::vector<float> GetOutput() { return ExtractVector<float>(output_); } 178 179 int input_size() { return input_size_; } 180 int num_units() { return units_; } 181 int num_batches() { return batches_; } 182 183 protected: 184 int input_; 185 int weights_feature_; 186 int weights_time_; 187 int bias_; 188 int activation_state_; 189 int output_; 190 191 int batches_; 192 int units_; 193 int input_size_; 194 int memory_size_; 195 int rank_; 196 }; 197 198 class SVDFOpModel : public BaseSVDFOpModel { 199 public: 200 using BaseSVDFOpModel::BaseSVDFOpModel; 201 }; 202 203 class HybridSVDFOpModel : public BaseSVDFOpModel { 204 public: 205 HybridSVDFOpModel(int batches, int units, int input_size, int memory_size, 206 int rank, TensorType tensor_type) 207 : BaseSVDFOpModel(batches, units, input_size, memory_size, rank, 208 tensor_type, tensor_type) { 209 tensor_type_ = tensor_type; 210 } 211 212 void SetWeights(int weights_idx, const std::vector<float>& f) { 213 if (tensor_type_ == TensorType_UINT8) { 214 SymmetricQuantizeAndPopulate(weights_idx, f); 215 } else { 216 SignedSymmetricQuantizeAndPopulate(weights_idx, f); 217 } 218 } 219 220 void SetWeightsFeature(std::initializer_list<float> f) { 221 SetWeights(weights_feature_, f); 222 } 223 224 void SetWeightsTime(std::initializer_list<float> f) { 225 SetWeights(weights_time_, f); 226 } 227 228 protected: 229 TensorType tensor_type_; 230 }; 231 232 class SVDFOpTest : public ::testing::Test { 233 protected: 234 void VerifyGoldens(float golden_input[], float golden_output[], 235 int golden_size, BaseSVDFOpModel* svdf, 236 float tolerance = 1e-5) { 237 const int svdf_num_batches = svdf->num_batches(); 238 const int svdf_input_size = svdf->input_size(); 239 const int svdf_num_units = svdf->num_units(); 240 const int input_sequence_size = 241 golden_size / sizeof(float) / (svdf_input_size * svdf_num_batches); 242 // Going over each input batch, setting the input tensor, invoking the SVDF 243 // op and checking the output with the expected golden values. 244 for (int i = 0; i < input_sequence_size; i++) { 245 float* batch_start = 246 golden_input + i * svdf_input_size * svdf_num_batches; 247 float* batch_end = batch_start + svdf_input_size * svdf_num_batches; 248 svdf->SetInput(0, batch_start, batch_end); 249 250 svdf->Invoke(); 251 252 const float* golden_start = 253 golden_output + i * svdf_num_units * svdf_num_batches; 254 const float* golden_end = 255 golden_start + svdf_num_units * svdf_num_batches; 256 std::vector<float> expected; 257 expected.insert(expected.end(), golden_start, golden_end); 258 259 EXPECT_THAT(svdf->GetOutput(), 260 ElementsAreArray(ArrayFloatNear(expected, tolerance))); 261 } 262 } 263 }; 264 265 TEST_F(SVDFOpTest, BlackBoxTestRank1) { 266 SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3, 267 /*memory_size=*/10, /*rank=*/1); 268 svdf.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347, 269 0.22197971, 0.12416199, 0.27901134, 0.27557442, 270 0.3905206, -0.36137494, -0.06634006, -0.10640851}); 271 272 svdf.SetWeightsTime( 273 {-0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156, 274 0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199, 275 276 0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518, 277 -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296, 278 279 -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236, 280 0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846, 281 282 -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166, 283 -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657}); 284 285 VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input), 286 &svdf); 287 } 288 289 TEST_F(SVDFOpTest, BlackBoxTestRank2) { 290 SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3, 291 /*memory_size=*/10, /*rank=*/2); 292 svdf.SetWeightsFeature({-0.31930989, 0.0079667, 0.39296314, 0.37613347, 293 0.12416199, 0.15785322, 0.27901134, 0.3905206, 294 0.21931258, -0.36137494, -0.10640851, 0.31053296, 295 -0.36118156, -0.0976817, -0.36916667, 0.22197971, 296 0.15294972, 0.38031587, 0.27557442, 0.39635518, 297 -0.21580373, -0.06634006, -0.02702999, 0.27072677}); 298 299 svdf.SetWeightsTime( 300 {-0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156, 301 0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199, 302 303 0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518, 304 -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296, 305 306 -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236, 307 0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846, 308 309 -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166, 310 -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657, 311 312 -0.14884081, 0.19931212, -0.36002168, 0.34663299, -0.11405486, 313 0.12672701, 0.39463779, -0.07886535, -0.06384811, 0.08249187, 314 315 -0.26816407, -0.19905911, 0.29211238, 0.31264046, -0.28664589, 316 0.05698794, 0.11613581, 0.14078894, 0.02187902, -0.21781836, 317 318 -0.15567942, 0.08693647, -0.38256618, 0.36580828, -0.22922277, 319 -0.0226903, 0.12878349, -0.28122205, -0.10850525, -0.11955214, 320 321 0.27179423, -0.04710215, 0.31069002, 0.22672787, 0.09580326, 322 0.08682203, 0.1258215, 0.1851041, 0.29228821, 0.12366763}); 323 324 VerifyGoldens(svdf_input, svdf_golden_output_rank_2, sizeof(svdf_input), 325 &svdf); 326 } 327 328 TEST_F(SVDFOpTest, BlackBoxTestHybridRank1Uint8) { 329 HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3, 330 /*memory_size=*/10, /*rank=*/1, TensorType_UINT8); 331 svdf.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347, 332 0.22197971, 0.12416199, 0.27901134, 0.27557442, 333 0.3905206, -0.36137494, -0.06634006, -0.10640851}); 334 335 svdf.SetWeightsTime( 336 {-0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156, 337 0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199, 338 339 0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518, 340 -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296, 341 342 -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236, 343 0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846, 344 345 -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166, 346 -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657}); 347 348 VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input), 349 &svdf, 350 /*tolerance=*/0.002945); 351 } 352 353 TEST_F(SVDFOpTest, BlackBoxTestHybridRank2Uint8) { 354 HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3, 355 /*memory_size=*/10, /*rank=*/2, TensorType_UINT8); 356 svdf.SetWeightsFeature({-0.31930989, 0.0079667, 0.39296314, 0.37613347, 357 0.12416199, 0.15785322, 0.27901134, 0.3905206, 358 0.21931258, -0.36137494, -0.10640851, 0.31053296, 359 -0.36118156, -0.0976817, -0.36916667, 0.22197971, 360 0.15294972, 0.38031587, 0.27557442, 0.39635518, 361 -0.21580373, -0.06634006, -0.02702999, 0.27072677}); 362 363 svdf.SetWeightsTime( 364 {-0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156, 365 0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199, 366 367 0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518, 368 -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296, 369 370 -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236, 371 0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846, 372 373 -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166, 374 -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657, 375 376 -0.14884081, 0.19931212, -0.36002168, 0.34663299, -0.11405486, 377 0.12672701, 0.39463779, -0.07886535, -0.06384811, 0.08249187, 378 379 -0.26816407, -0.19905911, 0.29211238, 0.31264046, -0.28664589, 380 0.05698794, 0.11613581, 0.14078894, 0.02187902, -0.21781836, 381 382 -0.15567942, 0.08693647, -0.38256618, 0.36580828, -0.22922277, 383 -0.0226903, 0.12878349, -0.28122205, -0.10850525, -0.11955214, 384 385 0.27179423, -0.04710215, 0.31069002, 0.22672787, 0.09580326, 386 0.08682203, 0.1258215, 0.1851041, 0.29228821, 0.12366763}); 387 388 VerifyGoldens(svdf_input, svdf_golden_output_rank_2, sizeof(svdf_input), 389 &svdf, 390 /*tolerance=*/0.00625109); 391 } 392 393 TEST_F(SVDFOpTest, BlackBoxTestHybridRank1Int8) { 394 HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3, 395 /*memory_size=*/10, /*rank=*/1, TensorType_INT8); 396 svdf.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347, 397 0.22197971, 0.12416199, 0.27901134, 0.27557442, 398 0.3905206, -0.36137494, -0.06634006, -0.10640851}); 399 400 svdf.SetWeightsTime( 401 {-0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156, 402 0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199, 403 404 0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518, 405 -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296, 406 407 -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236, 408 0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846, 409 410 -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166, 411 -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657}); 412 413 VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input), 414 &svdf, 415 /*tolerance=*/0.002945); 416 } 417 418 TEST_F(SVDFOpTest, BlackBoxTestHybridRank2Int8) { 419 HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3, 420 /*memory_size=*/10, /*rank=*/2, TensorType_INT8); 421 svdf.SetWeightsFeature({-0.31930989, 0.0079667, 0.39296314, 0.37613347, 422 0.12416199, 0.15785322, 0.27901134, 0.3905206, 423 0.21931258, -0.36137494, -0.10640851, 0.31053296, 424 -0.36118156, -0.0976817, -0.36916667, 0.22197971, 425 0.15294972, 0.38031587, 0.27557442, 0.39635518, 426 -0.21580373, -0.06634006, -0.02702999, 0.27072677}); 427 428 svdf.SetWeightsTime( 429 {-0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156, 430 0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199, 431 432 0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518, 433 -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296, 434 435 -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236, 436 0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846, 437 438 -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166, 439 -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657, 440 441 -0.14884081, 0.19931212, -0.36002168, 0.34663299, -0.11405486, 442 0.12672701, 0.39463779, -0.07886535, -0.06384811, 0.08249187, 443 444 -0.26816407, -0.19905911, 0.29211238, 0.31264046, -0.28664589, 445 0.05698794, 0.11613581, 0.14078894, 0.02187902, -0.21781836, 446 447 -0.15567942, 0.08693647, -0.38256618, 0.36580828, -0.22922277, 448 -0.0226903, 0.12878349, -0.28122205, -0.10850525, -0.11955214, 449 450 0.27179423, -0.04710215, 0.31069002, 0.22672787, 0.09580326, 451 0.08682203, 0.1258215, 0.1851041, 0.29228821, 0.12366763}); 452 453 VerifyGoldens(svdf_input, svdf_golden_output_rank_2, sizeof(svdf_input), 454 &svdf, 455 /*tolerance=*/0.00625109); 456 } 457 458 } // namespace 459 } // namespace tflite 460 461 int main(int argc, char** argv) { 462 ::tflite::LogToStderr(); 463 ::testing::InitGoogleTest(&argc, argv); 464 return RUN_ALL_TESTS(); 465 } 466