1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 // Unit test for TFLite RNN op. 16 17 #include <iomanip> 18 #include <vector> 19 20 #include <gmock/gmock.h> 21 #include <gtest/gtest.h> 22 #include "tensorflow/contrib/lite/interpreter.h" 23 #include "tensorflow/contrib/lite/kernels/register.h" 24 #include "tensorflow/contrib/lite/kernels/test_util.h" 25 #include "tensorflow/contrib/lite/model.h" 26 27 namespace tflite { 28 namespace { 29 30 using ::testing::ElementsAreArray; 31 32 static float rnn_input[] = { 33 0.23689353, 0.285385, 0.037029743, -0.19858193, -0.27569133, 34 0.43773448, 0.60379338, 0.35562468, -0.69424844, -0.93421471, 35 -0.87287879, 0.37144363, -0.62476718, 0.23791671, 0.40060222, 36 0.1356622, -0.99774903, -0.98858172, -0.38952237, -0.47685933, 37 0.31073618, 0.71511042, -0.63767755, -0.31729108, 0.33468103, 38 0.75801885, 0.30660987, -0.37354088, 0.77002847, -0.62747043, 39 -0.68572164, 0.0069220066, 0.65791464, 0.35130811, 0.80834007, 40 -0.61777675, -0.21095741, 0.41213346, 0.73784804, 0.094794154, 41 0.47791874, 0.86496925, -0.53376222, 0.85315156, 0.10288584, 42 0.86684, -0.011186242, 0.10513687, 0.87825835, 0.59929144, 43 0.62827742, 0.18899453, 0.31440187, 0.99059987, 0.87170351, 44 -0.35091716, 0.74861872, 0.17831337, 0.2755419, 0.51864719, 45 0.55084288, 0.58982027, -0.47443086, 0.20875752, -0.058871567, 46 -0.66609079, 0.59098077, 0.73017097, 0.74604273, 0.32882881, 47 -0.17503482, 0.22396147, 0.19379807, 0.29120302, 0.077113032, 48 -0.70331609, 0.15804303, -0.93407321, 0.40182066, 0.036301374, 49 0.66521823, 0.0300982, -0.7747041, -0.02038002, 0.020698071, 50 -0.90300065, 0.62870288, -0.23068321, 0.27531278, -0.095755219, 51 -0.712036, -0.17384434, -0.50593495, -0.18646687, -0.96508682, 52 0.43519354, 0.14744234, 0.62589407, 0.1653645, -0.10651493, 53 -0.045277178, 0.99032974, -0.88255352, -0.85147917, 0.28153265, 54 0.19455957, -0.55479527, -0.56042433, 0.26048636, 0.84702539, 55 0.47587705, -0.074295521, -0.12287641, 0.70117295, 0.90532446, 56 0.89782166, 0.79817224, 0.53402734, -0.33286154, 0.073485017, 57 -0.56172788, -0.044897556, 0.89964068, -0.067662835, 0.76863563, 58 0.93455386, -0.6324693, -0.083922029}; 59 60 static float rnn_golden_output[] = { 61 0.496726, 0, 0.965996, 0, 0.0584254, 0, 62 0, 0.12315, 0, 0, 0.612266, 0.456601, 63 0, 0.52286, 1.16099, 0.0291232, 64 65 0, 0, 0.524901, 0, 0, 0, 66 0, 1.02116, 0, 1.35762, 0, 0.356909, 67 0.436415, 0.0355727, 0, 0, 68 69 0, 0, 0, 0.262335, 0, 0, 70 0, 1.33992, 0, 2.9739, 0, 0, 71 1.31914, 2.66147, 0, 0, 72 73 0.942568, 0, 0, 0, 0.025507, 0, 74 0, 0, 0.321429, 0.569141, 1.25274, 1.57719, 75 0.8158, 1.21805, 0.586239, 0.25427, 76 77 1.04436, 0, 0.630725, 0, 0.133801, 0.210693, 78 0.363026, 0, 0.533426, 0, 1.25926, 0.722707, 79 0, 1.22031, 1.30117, 0.495867, 80 81 0.222187, 0, 0.72725, 0, 0.767003, 0, 82 0, 0.147835, 0, 0, 0, 0.608758, 83 0.469394, 0.00720298, 0.927537, 0, 84 85 0.856974, 0.424257, 0, 0, 0.937329, 0, 86 0, 0, 0.476425, 0, 0.566017, 0.418462, 87 0.141911, 0.996214, 1.13063, 0, 88 89 0.967899, 0, 0, 0, 0.0831304, 0, 90 0, 1.00378, 0, 0, 0, 1.44818, 91 1.01768, 0.943891, 0.502745, 0, 92 93 0.940135, 0, 0, 0, 0, 0, 94 0, 2.13243, 0, 0.71208, 0.123918, 1.53907, 95 1.30225, 1.59644, 0.70222, 0, 96 97 0.804329, 0, 0.430576, 0, 0.505872, 0.509603, 98 0.343448, 0, 0.107756, 0.614544, 1.44549, 1.52311, 99 0.0454298, 0.300267, 0.562784, 0.395095, 100 101 0.228154, 0, 0.675323, 0, 1.70536, 0.766217, 102 0, 0, 0, 0.735363, 0.0759267, 1.91017, 103 0.941888, 0, 0, 0, 104 105 0, 0, 1.5909, 0, 0, 0, 106 0, 0.5755, 0, 0.184687, 0, 1.56296, 107 0.625285, 0, 0, 0, 108 109 0, 0, 0.0857888, 0, 0, 0, 110 0, 0.488383, 0.252786, 0, 0, 0, 111 1.02817, 1.85665, 0, 0, 112 113 0.00981836, 0, 1.06371, 0, 0, 0, 114 0, 0, 0, 0.290445, 0.316406, 0, 115 0.304161, 1.25079, 0.0707152, 0, 116 117 0.986264, 0.309201, 0, 0, 0, 0, 118 0, 1.64896, 0.346248, 0, 0.918175, 0.78884, 119 0.524981, 1.92076, 2.07013, 0.333244, 120 121 0.415153, 0.210318, 0, 0, 0, 0, 122 0, 2.02616, 0, 0.728256, 0.84183, 0.0907453, 123 0.628881, 3.58099, 1.49974, 0}; 124 125 class RNNOpModel : public SingleOpModel { 126 public: 127 RNNOpModel(int batches, int units, int size) 128 : batches_(batches), units_(units), input_size_(size) { 129 input_ = AddInput(TensorType_FLOAT32); 130 weights_ = AddInput(TensorType_FLOAT32); 131 recurrent_weights_ = AddInput(TensorType_FLOAT32); 132 bias_ = AddInput(TensorType_FLOAT32); 133 hidden_state_ = AddOutput(TensorType_FLOAT32); 134 output_ = AddOutput(TensorType_FLOAT32); 135 SetBuiltinOp( 136 BuiltinOperator_RNN, BuiltinOptions_RNNOptions, 137 CreateRNNOptions(builder_, ActivationFunctionType_RELU).Union()); 138 BuildInterpreter({{batches_, input_size_}, 139 {units_, input_size_}, 140 {units_, units_}, 141 {units_}}); 142 } 143 144 void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); } 145 146 void SetWeights(std::initializer_list<float> f) { 147 PopulateTensor(weights_, f); 148 } 149 150 void SetRecurrentWeights(std::initializer_list<float> f) { 151 PopulateTensor(recurrent_weights_, f); 152 } 153 154 void SetInput(std::initializer_list<float> data) { 155 PopulateTensor(input_, data); 156 } 157 158 void SetInput(int offset, float* begin, float* end) { 159 PopulateTensor(input_, offset, begin, end); 160 } 161 162 void ResetHiddenState() { 163 const int zero_buffer_size = units_ * batches_; 164 std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]); 165 memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); 166 PopulateTensor(hidden_state_, 0, zero_buffer.get(), 167 zero_buffer.get() + zero_buffer_size); 168 } 169 170 std::vector<float> GetOutput() { return ExtractVector<float>(output_); } 171 172 int input_size() { return input_size_; } 173 int num_units() { return units_; } 174 int num_batches() { return batches_; } 175 176 private: 177 int input_; 178 int weights_; 179 int recurrent_weights_; 180 int bias_; 181 int hidden_state_; 182 int output_; 183 184 int batches_; 185 int units_; 186 int input_size_; 187 }; 188 189 TEST(FullyConnectedOpTest, BlackBoxTest) { 190 RNNOpModel rnn(2, 16, 8); 191 rnn.SetWeights( 192 {0.461459, 0.153381, 0.529743, -0.00371218, 0.676267, -0.211346, 193 0.317493, 0.969689, -0.343251, 0.186423, 0.398151, 0.152399, 194 0.448504, 0.317662, 0.523556, -0.323514, 0.480877, 0.333113, 195 -0.757714, -0.674487, -0.643585, 0.217766, -0.0251462, 0.79512, 196 -0.595574, -0.422444, 0.371572, -0.452178, -0.556069, -0.482188, 197 -0.685456, -0.727851, 0.841829, 0.551535, -0.232336, 0.729158, 198 -0.00294906, -0.69754, 0.766073, -0.178424, 0.369513, -0.423241, 199 0.548547, -0.0152023, -0.757482, -0.85491, 0.251331, -0.989183, 200 0.306261, -0.340716, 0.886103, -0.0726757, -0.723523, -0.784303, 201 0.0354295, 0.566564, -0.485469, -0.620498, 0.832546, 0.697884, 202 -0.279115, 0.294415, -0.584313, 0.548772, 0.0648819, 0.968726, 203 0.723834, -0.0080452, -0.350386, -0.272803, 0.115121, -0.412644, 204 -0.824713, -0.992843, -0.592904, -0.417893, 0.863791, -0.423461, 205 -0.147601, -0.770664, -0.479006, 0.654782, 0.587314, -0.639158, 206 0.816969, -0.337228, 0.659878, 0.73107, 0.754768, -0.337042, 207 0.0960841, 0.368357, 0.244191, -0.817703, -0.211223, 0.442012, 208 0.37225, -0.623598, -0.405423, 0.455101, 0.673656, -0.145345, 209 -0.511346, -0.901675, -0.81252, -0.127006, 0.809865, -0.721884, 210 0.636255, 0.868989, -0.347973, -0.10179, -0.777449, 0.917274, 211 0.819286, 0.206218, -0.00785118, 0.167141, 0.45872, 0.972934, 212 -0.276798, 0.837861, 0.747958, -0.0151566, -0.330057, -0.469077, 213 0.277308, 0.415818}); 214 215 rnn.SetBias({0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068, 216 -0.23566568, -0.389184, 0.47481549, -0.4791103, 0.29931796, 217 0.10463274, 0.83918178, 0.37197268, 0.61957061, 0.3956964, 218 -0.37609905}); 219 220 rnn.SetRecurrentWeights({0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 221 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 222 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 223 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 224 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 225 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 226 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 227 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 228 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 229 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 230 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 231 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 232 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 233 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 234 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 235 0.1}); 236 237 rnn.ResetHiddenState(); 238 const int input_sequence_size = sizeof(rnn_input) / sizeof(float) / 239 (rnn.input_size() * rnn.num_batches()); 240 241 for (int i = 0; i < input_sequence_size; i++) { 242 float* batch_start = rnn_input + i * rnn.input_size(); 243 float* batch_end = batch_start + rnn.input_size(); 244 rnn.SetInput(0, batch_start, batch_end); 245 rnn.SetInput(rnn.input_size(), batch_start, batch_end); 246 247 rnn.Invoke(); 248 249 float* golden_start = rnn_golden_output + i * rnn.num_units(); 250 float* golden_end = golden_start + rnn.num_units(); 251 std::vector<float> expected; 252 expected.insert(expected.end(), golden_start, golden_end); 253 expected.insert(expected.end(), golden_start, golden_end); 254 255 EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); 256 } 257 } 258 259 } // namespace 260 } // namespace tflite 261 262 int main(int argc, char** argv) { 263 ::tflite::LogToStderr(); 264 ::testing::InitGoogleTest(&argc, argv); 265 return RUN_ALL_TESTS(); 266 } 267