1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 // Unit test for TFLite Sequential RNN op. 16 17 #include <iomanip> 18 #include <vector> 19 20 #include <gmock/gmock.h> 21 #include <gtest/gtest.h> 22 #include "tensorflow/contrib/lite/interpreter.h" 23 #include "tensorflow/contrib/lite/kernels/register.h" 24 #include "tensorflow/contrib/lite/kernels/test_util.h" 25 #include "tensorflow/contrib/lite/model.h" 26 27 namespace tflite { 28 namespace { 29 30 using ::testing::ElementsAreArray; 31 32 static float rnn_input[] = { 33 0.23689353, 0.285385, 0.037029743, -0.19858193, -0.27569133, 34 0.43773448, 0.60379338, 0.35562468, -0.69424844, -0.93421471, 35 -0.87287879, 0.37144363, -0.62476718, 0.23791671, 0.40060222, 36 0.1356622, -0.99774903, -0.98858172, -0.38952237, -0.47685933, 37 0.31073618, 0.71511042, -0.63767755, -0.31729108, 0.33468103, 38 0.75801885, 0.30660987, -0.37354088, 0.77002847, -0.62747043, 39 -0.68572164, 0.0069220066, 0.65791464, 0.35130811, 0.80834007, 40 -0.61777675, -0.21095741, 0.41213346, 0.73784804, 0.094794154, 41 0.47791874, 0.86496925, -0.53376222, 0.85315156, 0.10288584, 42 0.86684, -0.011186242, 0.10513687, 0.87825835, 0.59929144, 43 0.62827742, 0.18899453, 0.31440187, 0.99059987, 0.87170351, 44 -0.35091716, 0.74861872, 0.17831337, 0.2755419, 0.51864719, 45 0.55084288, 0.58982027, -0.47443086, 0.20875752, -0.058871567, 46 -0.66609079, 0.59098077, 0.73017097, 0.74604273, 0.32882881, 47 -0.17503482, 0.22396147, 0.19379807, 0.29120302, 0.077113032, 48 -0.70331609, 0.15804303, -0.93407321, 0.40182066, 0.036301374, 49 0.66521823, 0.0300982, -0.7747041, -0.02038002, 0.020698071, 50 -0.90300065, 0.62870288, -0.23068321, 0.27531278, -0.095755219, 51 -0.712036, -0.17384434, -0.50593495, -0.18646687, -0.96508682, 52 0.43519354, 0.14744234, 0.62589407, 0.1653645, -0.10651493, 53 -0.045277178, 0.99032974, -0.88255352, -0.85147917, 0.28153265, 54 0.19455957, -0.55479527, -0.56042433, 0.26048636, 0.84702539, 55 0.47587705, -0.074295521, -0.12287641, 0.70117295, 0.90532446, 56 0.89782166, 0.79817224, 0.53402734, -0.33286154, 0.073485017, 57 -0.56172788, -0.044897556, 0.89964068, -0.067662835, 0.76863563, 58 0.93455386, -0.6324693, -0.083922029}; 59 60 static float rnn_golden_output[] = { 61 0.496726, 0, 0.965996, 0, 0.0584254, 0, 62 0, 0.12315, 0, 0, 0.612266, 0.456601, 63 0, 0.52286, 1.16099, 0.0291232, 64 65 0, 0, 0.524901, 0, 0, 0, 66 0, 1.02116, 0, 1.35762, 0, 0.356909, 67 0.436415, 0.0355727, 0, 0, 68 69 0, 0, 0, 0.262335, 0, 0, 70 0, 1.33992, 0, 2.9739, 0, 0, 71 1.31914, 2.66147, 0, 0, 72 73 0.942568, 0, 0, 0, 0.025507, 0, 74 0, 0, 0.321429, 0.569141, 1.25274, 1.57719, 75 0.8158, 1.21805, 0.586239, 0.25427, 76 77 1.04436, 0, 0.630725, 0, 0.133801, 0.210693, 78 0.363026, 0, 0.533426, 0, 1.25926, 0.722707, 79 0, 1.22031, 1.30117, 0.495867, 80 81 0.222187, 0, 0.72725, 0, 0.767003, 0, 82 0, 0.147835, 0, 0, 0, 0.608758, 83 0.469394, 0.00720298, 0.927537, 0, 84 85 0.856974, 0.424257, 0, 0, 0.937329, 0, 86 0, 0, 0.476425, 0, 0.566017, 0.418462, 87 0.141911, 0.996214, 1.13063, 0, 88 89 0.967899, 0, 0, 0, 0.0831304, 0, 90 0, 1.00378, 0, 0, 0, 1.44818, 91 1.01768, 0.943891, 0.502745, 0, 92 93 0.940135, 0, 0, 0, 0, 0, 94 0, 2.13243, 0, 0.71208, 0.123918, 1.53907, 95 1.30225, 1.59644, 0.70222, 0, 96 97 0.804329, 0, 0.430576, 0, 0.505872, 0.509603, 98 0.343448, 0, 0.107756, 0.614544, 1.44549, 1.52311, 99 0.0454298, 0.300267, 0.562784, 0.395095, 100 101 0.228154, 0, 0.675323, 0, 1.70536, 0.766217, 102 0, 0, 0, 0.735363, 0.0759267, 1.91017, 103 0.941888, 0, 0, 0, 104 105 0, 0, 1.5909, 0, 0, 0, 106 0, 0.5755, 0, 0.184687, 0, 1.56296, 107 0.625285, 0, 0, 0, 108 109 0, 0, 0.0857888, 0, 0, 0, 110 0, 0.488383, 0.252786, 0, 0, 0, 111 1.02817, 1.85665, 0, 0, 112 113 0.00981836, 0, 1.06371, 0, 0, 0, 114 0, 0, 0, 0.290445, 0.316406, 0, 115 0.304161, 1.25079, 0.0707152, 0, 116 117 0.986264, 0.309201, 0, 0, 0, 0, 118 0, 1.64896, 0.346248, 0, 0.918175, 0.78884, 119 0.524981, 1.92076, 2.07013, 0.333244, 120 121 0.415153, 0.210318, 0, 0, 0, 0, 122 0, 2.02616, 0, 0.728256, 0.84183, 0.0907453, 123 0.628881, 3.58099, 1.49974, 0}; 124 125 class UnidirectionalRNNOpModel : public SingleOpModel { 126 public: 127 UnidirectionalRNNOpModel(int batches, int sequence_len, int units, int size, 128 bool time_major) 129 : batches_(batches), 130 sequence_len_(sequence_len), 131 units_(units), 132 input_size_(size) { 133 input_ = AddInput(TensorType_FLOAT32); 134 weights_ = AddInput(TensorType_FLOAT32); 135 recurrent_weights_ = AddInput(TensorType_FLOAT32); 136 bias_ = AddInput(TensorType_FLOAT32); 137 hidden_state_ = AddOutput(TensorType_FLOAT32); 138 output_ = AddOutput(TensorType_FLOAT32); 139 SetBuiltinOp(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN, 140 BuiltinOptions_SequenceRNNOptions, 141 CreateSequenceRNNOptions(builder_, time_major, 142 ActivationFunctionType_RELU) 143 .Union()); 144 if (time_major) { 145 BuildInterpreter({{sequence_len_, batches_, input_size_}, 146 {units_, input_size_}, 147 {units_, units_}, 148 {units_}}); 149 } else { 150 BuildInterpreter({{batches_, sequence_len_, input_size_}, 151 {units_, input_size_}, 152 {units_, units_}, 153 {units_}}); 154 } 155 } 156 157 void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); } 158 159 void SetWeights(std::initializer_list<float> f) { 160 PopulateTensor(weights_, f); 161 } 162 163 void SetRecurrentWeights(std::initializer_list<float> f) { 164 PopulateTensor(recurrent_weights_, f); 165 } 166 167 void SetInput(std::initializer_list<float> data) { 168 PopulateTensor(input_, data); 169 } 170 171 void SetInput(int offset, float* begin, float* end) { 172 PopulateTensor(input_, offset, begin, end); 173 } 174 175 void ResetHiddenState() { 176 const int zero_buffer_size = units_ * batches_; 177 std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]); 178 memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); 179 PopulateTensor(hidden_state_, 0, zero_buffer.get(), 180 zero_buffer.get() + zero_buffer_size); 181 } 182 183 std::vector<float> GetOutput() { return ExtractVector<float>(output_); } 184 185 int input_size() { return input_size_; } 186 int num_units() { return units_; } 187 int num_batches() { return batches_; } 188 int sequence_len() { return sequence_len_; } 189 190 private: 191 int input_; 192 int weights_; 193 int recurrent_weights_; 194 int bias_; 195 int hidden_state_; 196 int output_; 197 198 int batches_; 199 int sequence_len_; 200 int units_; 201 int input_size_; 202 }; 203 204 // TODO(mirkov): add another test which directly compares to TF once TOCO 205 // supports the conversion from dynamic_rnn with BasicRNNCell. 206 TEST(FullyConnectedOpTest, BlackBoxTest) { 207 UnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, 208 /*units=*/16, /*size=*/8, /*time_major=*/false); 209 rnn.SetWeights( 210 {0.461459, 0.153381, 0.529743, -0.00371218, 0.676267, -0.211346, 211 0.317493, 0.969689, -0.343251, 0.186423, 0.398151, 0.152399, 212 0.448504, 0.317662, 0.523556, -0.323514, 0.480877, 0.333113, 213 -0.757714, -0.674487, -0.643585, 0.217766, -0.0251462, 0.79512, 214 -0.595574, -0.422444, 0.371572, -0.452178, -0.556069, -0.482188, 215 -0.685456, -0.727851, 0.841829, 0.551535, -0.232336, 0.729158, 216 -0.00294906, -0.69754, 0.766073, -0.178424, 0.369513, -0.423241, 217 0.548547, -0.0152023, -0.757482, -0.85491, 0.251331, -0.989183, 218 0.306261, -0.340716, 0.886103, -0.0726757, -0.723523, -0.784303, 219 0.0354295, 0.566564, -0.485469, -0.620498, 0.832546, 0.697884, 220 -0.279115, 0.294415, -0.584313, 0.548772, 0.0648819, 0.968726, 221 0.723834, -0.0080452, -0.350386, -0.272803, 0.115121, -0.412644, 222 -0.824713, -0.992843, -0.592904, -0.417893, 0.863791, -0.423461, 223 -0.147601, -0.770664, -0.479006, 0.654782, 0.587314, -0.639158, 224 0.816969, -0.337228, 0.659878, 0.73107, 0.754768, -0.337042, 225 0.0960841, 0.368357, 0.244191, -0.817703, -0.211223, 0.442012, 226 0.37225, -0.623598, -0.405423, 0.455101, 0.673656, -0.145345, 227 -0.511346, -0.901675, -0.81252, -0.127006, 0.809865, -0.721884, 228 0.636255, 0.868989, -0.347973, -0.10179, -0.777449, 0.917274, 229 0.819286, 0.206218, -0.00785118, 0.167141, 0.45872, 0.972934, 230 -0.276798, 0.837861, 0.747958, -0.0151566, -0.330057, -0.469077, 231 0.277308, 0.415818}); 232 233 rnn.SetBias({0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068, 234 -0.23566568, -0.389184, 0.47481549, -0.4791103, 0.29931796, 235 0.10463274, 0.83918178, 0.37197268, 0.61957061, 0.3956964, 236 -0.37609905}); 237 238 rnn.SetRecurrentWeights({0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 239 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 240 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 241 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 242 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 243 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 244 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 245 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 246 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 247 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 248 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 249 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 250 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 251 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 252 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 253 0.1}); 254 255 rnn.ResetHiddenState(); 256 const int input_sequence_size = rnn.input_size() * rnn.sequence_len(); 257 float* batch_start = rnn_input; 258 float* batch_end = batch_start + input_sequence_size; 259 rnn.SetInput(0, batch_start, batch_end); 260 rnn.SetInput(input_sequence_size, batch_start, batch_end); 261 262 rnn.Invoke(); 263 264 float* golden_start = rnn_golden_output; 265 float* golden_end = golden_start + rnn.num_units() * rnn.sequence_len(); 266 std::vector<float> expected; 267 expected.insert(expected.end(), golden_start, golden_end); 268 expected.insert(expected.end(), golden_start, golden_end); 269 270 EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); 271 } 272 273 TEST(FullyConnectedOpTest, TimeMajorBlackBoxTest) { 274 UnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, 275 /*units=*/16, /*size=*/8, /*time_major=*/true); 276 rnn.SetWeights( 277 {0.461459, 0.153381, 0.529743, -0.00371218, 0.676267, -0.211346, 278 0.317493, 0.969689, -0.343251, 0.186423, 0.398151, 0.152399, 279 0.448504, 0.317662, 0.523556, -0.323514, 0.480877, 0.333113, 280 -0.757714, -0.674487, -0.643585, 0.217766, -0.0251462, 0.79512, 281 -0.595574, -0.422444, 0.371572, -0.452178, -0.556069, -0.482188, 282 -0.685456, -0.727851, 0.841829, 0.551535, -0.232336, 0.729158, 283 -0.00294906, -0.69754, 0.766073, -0.178424, 0.369513, -0.423241, 284 0.548547, -0.0152023, -0.757482, -0.85491, 0.251331, -0.989183, 285 0.306261, -0.340716, 0.886103, -0.0726757, -0.723523, -0.784303, 286 0.0354295, 0.566564, -0.485469, -0.620498, 0.832546, 0.697884, 287 -0.279115, 0.294415, -0.584313, 0.548772, 0.0648819, 0.968726, 288 0.723834, -0.0080452, -0.350386, -0.272803, 0.115121, -0.412644, 289 -0.824713, -0.992843, -0.592904, -0.417893, 0.863791, -0.423461, 290 -0.147601, -0.770664, -0.479006, 0.654782, 0.587314, -0.639158, 291 0.816969, -0.337228, 0.659878, 0.73107, 0.754768, -0.337042, 292 0.0960841, 0.368357, 0.244191, -0.817703, -0.211223, 0.442012, 293 0.37225, -0.623598, -0.405423, 0.455101, 0.673656, -0.145345, 294 -0.511346, -0.901675, -0.81252, -0.127006, 0.809865, -0.721884, 295 0.636255, 0.868989, -0.347973, -0.10179, -0.777449, 0.917274, 296 0.819286, 0.206218, -0.00785118, 0.167141, 0.45872, 0.972934, 297 -0.276798, 0.837861, 0.747958, -0.0151566, -0.330057, -0.469077, 298 0.277308, 0.415818}); 299 300 rnn.SetBias({0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068, 301 -0.23566568, -0.389184, 0.47481549, -0.4791103, 0.29931796, 302 0.10463274, 0.83918178, 0.37197268, 0.61957061, 0.3956964, 303 -0.37609905}); 304 305 rnn.SetRecurrentWeights({0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 306 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 307 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 308 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 309 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 310 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 311 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 312 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 313 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 314 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 315 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 316 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 317 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 318 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 319 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 320 0.1}); 321 322 rnn.ResetHiddenState(); 323 for (int i = 0; i < rnn.sequence_len(); i++) { 324 float* batch_start = rnn_input + i * rnn.input_size(); 325 float* batch_end = batch_start + rnn.input_size(); 326 // The two batches are identical. 327 rnn.SetInput(2 * i * rnn.input_size(), batch_start, batch_end); 328 rnn.SetInput((2 * i + 1) * rnn.input_size(), batch_start, batch_end); 329 } 330 331 rnn.Invoke(); 332 333 std::vector<float> expected; 334 for (int i = 0; i < rnn.sequence_len(); i++) { 335 float* golden_batch_start = rnn_golden_output + i * rnn.num_units(); 336 float* golden_batch_end = golden_batch_start + rnn.num_units(); 337 expected.insert(expected.end(), golden_batch_start, golden_batch_end); 338 expected.insert(expected.end(), golden_batch_start, golden_batch_end); 339 } 340 341 EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); 342 } 343 344 } // namespace 345 } // namespace tflite 346 347 int main(int argc, char** argv) { 348 // On Linux, add: tflite::LogToStderr(); 349 ::testing::InitGoogleTest(&argc, argv); 350 return RUN_ALL_TESTS(); 351 } 352