1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "RNN.h" 18 19 #include "NeuralNetworksWrapper.h" 20 #include "gmock/gmock-matchers.h" 21 #include "gtest/gtest.h" 22 23 namespace android { 24 namespace nn { 25 namespace wrapper { 26 27 using ::testing::Each; 28 using ::testing::FloatNear; 29 using ::testing::Matcher; 30 31 namespace { 32 33 std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values, 34 float max_abs_error = 1.e-5) { 35 std::vector<Matcher<float>> matchers; 36 matchers.reserve(values.size()); 37 for (const float& v : values) { 38 matchers.emplace_back(FloatNear(v, max_abs_error)); 39 } 40 return matchers; 41 } 42 43 static float rnn_input[] = { 44 0.23689353, 0.285385, 0.037029743, -0.19858193, -0.27569133, 45 0.43773448, 0.60379338, 0.35562468, -0.69424844, -0.93421471, 46 -0.87287879, 0.37144363, -0.62476718, 0.23791671, 0.40060222, 47 0.1356622, -0.99774903, -0.98858172, -0.38952237, -0.47685933, 48 0.31073618, 0.71511042, -0.63767755, -0.31729108, 0.33468103, 49 0.75801885, 0.30660987, -0.37354088, 0.77002847, -0.62747043, 50 -0.68572164, 0.0069220066, 0.65791464, 0.35130811, 0.80834007, 51 -0.61777675, -0.21095741, 0.41213346, 0.73784804, 0.094794154, 52 0.47791874, 0.86496925, -0.53376222, 0.85315156, 0.10288584, 53 0.86684, -0.011186242, 0.10513687, 0.87825835, 0.59929144, 54 0.62827742, 0.18899453, 0.31440187, 0.99059987, 0.87170351, 55 -0.35091716, 0.74861872, 0.17831337, 0.2755419, 0.51864719, 56 0.55084288, 0.58982027, -0.47443086, 0.20875752, -0.058871567, 57 -0.66609079, 0.59098077, 0.73017097, 0.74604273, 0.32882881, 58 -0.17503482, 0.22396147, 0.19379807, 0.29120302, 0.077113032, 59 -0.70331609, 0.15804303, -0.93407321, 0.40182066, 0.036301374, 60 0.66521823, 0.0300982, -0.7747041, -0.02038002, 0.020698071, 61 -0.90300065, 0.62870288, -0.23068321, 0.27531278, -0.095755219, 62 -0.712036, -0.17384434, -0.50593495, -0.18646687, -0.96508682, 63 0.43519354, 0.14744234, 0.62589407, 0.1653645, -0.10651493, 64 -0.045277178, 0.99032974, -0.88255352, -0.85147917, 0.28153265, 65 0.19455957, -0.55479527, -0.56042433, 0.26048636, 0.84702539, 66 0.47587705, -0.074295521, -0.12287641, 0.70117295, 0.90532446, 67 0.89782166, 0.79817224, 0.53402734, -0.33286154, 0.073485017, 68 -0.56172788, -0.044897556, 0.89964068, -0.067662835, 0.76863563, 69 0.93455386, -0.6324693, -0.083922029}; 70 71 static float rnn_golden_output[] = { 72 0.496726, 0, 0.965996, 0, 0.0584254, 0, 73 0, 0.12315, 0, 0, 0.612266, 0.456601, 74 0, 0.52286, 1.16099, 0.0291232, 75 76 0, 0, 0.524901, 0, 0, 0, 77 0, 1.02116, 0, 1.35762, 0, 0.356909, 78 0.436415, 0.0355727, 0, 0, 79 80 0, 0, 0, 0.262335, 0, 0, 81 0, 1.33992, 0, 2.9739, 0, 0, 82 1.31914, 2.66147, 0, 0, 83 84 0.942568, 0, 0, 0, 0.025507, 0, 85 0, 0, 0.321429, 0.569141, 1.25274, 1.57719, 86 0.8158, 1.21805, 0.586239, 0.25427, 87 88 1.04436, 0, 0.630725, 0, 0.133801, 0.210693, 89 0.363026, 0, 0.533426, 0, 1.25926, 0.722707, 90 0, 1.22031, 1.30117, 0.495867, 91 92 0.222187, 0, 0.72725, 0, 0.767003, 0, 93 0, 0.147835, 0, 0, 0, 0.608758, 94 0.469394, 0.00720298, 0.927537, 0, 95 96 0.856974, 0.424257, 0, 0, 0.937329, 0, 97 0, 0, 0.476425, 0, 0.566017, 0.418462, 98 0.141911, 0.996214, 1.13063, 0, 99 100 0.967899, 0, 0, 0, 0.0831304, 0, 101 0, 1.00378, 0, 0, 0, 1.44818, 102 1.01768, 0.943891, 0.502745, 0, 103 104 0.940135, 0, 0, 0, 0, 0, 105 0, 2.13243, 0, 0.71208, 0.123918, 1.53907, 106 1.30225, 1.59644, 0.70222, 0, 107 108 0.804329, 0, 0.430576, 0, 0.505872, 0.509603, 109 0.343448, 0, 0.107756, 0.614544, 1.44549, 1.52311, 110 0.0454298, 0.300267, 0.562784, 0.395095, 111 112 0.228154, 0, 0.675323, 0, 1.70536, 0.766217, 113 0, 0, 0, 0.735363, 0.0759267, 1.91017, 114 0.941888, 0, 0, 0, 115 116 0, 0, 1.5909, 0, 0, 0, 117 0, 0.5755, 0, 0.184687, 0, 1.56296, 118 0.625285, 0, 0, 0, 119 120 0, 0, 0.0857888, 0, 0, 0, 121 0, 0.488383, 0.252786, 0, 0, 0, 122 1.02817, 1.85665, 0, 0, 123 124 0.00981836, 0, 1.06371, 0, 0, 0, 125 0, 0, 0, 0.290445, 0.316406, 0, 126 0.304161, 1.25079, 0.0707152, 0, 127 128 0.986264, 0.309201, 0, 0, 0, 0, 129 0, 1.64896, 0.346248, 0, 0.918175, 0.78884, 130 0.524981, 1.92076, 2.07013, 0.333244, 131 132 0.415153, 0.210318, 0, 0, 0, 0, 133 0, 2.02616, 0, 0.728256, 0.84183, 0.0907453, 134 0.628881, 3.58099, 1.49974, 0}; 135 136 } // anonymous namespace 137 138 #define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(ACTION) \ 139 ACTION(Input) \ 140 ACTION(Weights) \ 141 ACTION(RecurrentWeights) \ 142 ACTION(Bias) \ 143 ACTION(HiddenStateIn) 144 145 // For all output and intermediate states 146 #define FOR_ALL_OUTPUT_TENSORS(ACTION) \ 147 ACTION(HiddenStateOut) \ 148 ACTION(Output) 149 150 class BasicRNNOpModel { 151 public: 152 BasicRNNOpModel(uint32_t batches, uint32_t units, uint32_t size) 153 : batches_(batches), 154 units_(units), 155 input_size_(size), 156 activation_(kActivationRelu) { 157 std::vector<uint32_t> inputs; 158 159 OperandType InputTy(Type::TENSOR_FLOAT32, {batches_, input_size_}); 160 inputs.push_back(model_.addOperand(&InputTy)); 161 OperandType WeightTy(Type::TENSOR_FLOAT32, {units_, input_size_}); 162 inputs.push_back(model_.addOperand(&WeightTy)); 163 OperandType RecurrentWeightTy(Type::TENSOR_FLOAT32, {units_, units_}); 164 inputs.push_back(model_.addOperand(&RecurrentWeightTy)); 165 OperandType BiasTy(Type::TENSOR_FLOAT32, {units_}); 166 inputs.push_back(model_.addOperand(&BiasTy)); 167 OperandType HiddenStateTy(Type::TENSOR_FLOAT32, {batches_, units_}); 168 inputs.push_back(model_.addOperand(&HiddenStateTy)); 169 OperandType ActionParamTy(Type::INT32, {1}); 170 inputs.push_back(model_.addOperand(&ActionParamTy)); 171 172 std::vector<uint32_t> outputs; 173 174 outputs.push_back(model_.addOperand(&HiddenStateTy)); 175 OperandType OutputTy(Type::TENSOR_FLOAT32, {batches_, units_}); 176 outputs.push_back(model_.addOperand(&OutputTy)); 177 178 Input_.insert(Input_.end(), batches_ * input_size_, 0.f); 179 HiddenStateIn_.insert(HiddenStateIn_.end(), batches_ * units_, 0.f); 180 HiddenStateOut_.insert(HiddenStateOut_.end(), batches_ * units_, 0.f); 181 Output_.insert(Output_.end(), batches_ * units_, 0.f); 182 183 model_.addOperation(ANEURALNETWORKS_RNN, inputs, outputs); 184 model_.identifyInputsAndOutputs(inputs, outputs); 185 186 model_.finish(); 187 } 188 189 #define DefineSetter(X) \ 190 void Set##X(const std::vector<float>& f) { \ 191 X##_.insert(X##_.end(), f.begin(), f.end()); \ 192 } 193 194 FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter); 195 196 #undef DefineSetter 197 198 void SetInput(int offset, float* begin, float* end) { 199 for (; begin != end; begin++, offset++) { 200 Input_[offset] = *begin; 201 } 202 } 203 204 void ResetHiddenState() { 205 std::fill(HiddenStateIn_.begin(), HiddenStateIn_.end(), 0.f); 206 std::fill(HiddenStateOut_.begin(), HiddenStateOut_.end(), 0.f); 207 } 208 209 const std::vector<float>& GetOutput() const { return Output_; } 210 211 uint32_t input_size() const { return input_size_; } 212 uint32_t num_units() const { return units_; } 213 uint32_t num_batches() const { return batches_; } 214 215 void Invoke() { 216 ASSERT_TRUE(model_.isValid()); 217 218 HiddenStateIn_.swap(HiddenStateOut_); 219 220 Compilation compilation(&model_); 221 compilation.finish(); 222 Execution execution(&compilation); 223 #define SetInputOrWeight(X) \ 224 ASSERT_EQ(execution.setInput(RNN::k##X##Tensor, X##_.data(), \ 225 sizeof(float) * X##_.size()), \ 226 Result::NO_ERROR); 227 228 FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight); 229 230 #undef SetInputOrWeight 231 232 #define SetOutput(X) \ 233 ASSERT_EQ(execution.setOutput(RNN::k##X##Tensor, X##_.data(), \ 234 sizeof(float) * X##_.size()), \ 235 Result::NO_ERROR); 236 237 FOR_ALL_OUTPUT_TENSORS(SetOutput); 238 239 #undef SetOutput 240 241 ASSERT_EQ(execution.setInput(RNN::kActivationParam, &activation_, 242 sizeof(activation_)), 243 Result::NO_ERROR); 244 245 ASSERT_EQ(execution.compute(), Result::NO_ERROR); 246 } 247 248 private: 249 Model model_; 250 251 const uint32_t batches_; 252 const uint32_t units_; 253 const uint32_t input_size_; 254 255 const int activation_; 256 257 #define DefineTensor(X) std::vector<float> X##_; 258 259 FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineTensor); 260 FOR_ALL_OUTPUT_TENSORS(DefineTensor); 261 262 #undef DefineTensor 263 }; 264 265 TEST(RNNOpTest, BlackBoxTest) { 266 BasicRNNOpModel rnn(2, 16, 8); 267 rnn.SetWeights( 268 {0.461459, 0.153381, 0.529743, -0.00371218, 0.676267, -0.211346, 269 0.317493, 0.969689, -0.343251, 0.186423, 0.398151, 0.152399, 270 0.448504, 0.317662, 0.523556, -0.323514, 0.480877, 0.333113, 271 -0.757714, -0.674487, -0.643585, 0.217766, -0.0251462, 0.79512, 272 -0.595574, -0.422444, 0.371572, -0.452178, -0.556069, -0.482188, 273 -0.685456, -0.727851, 0.841829, 0.551535, -0.232336, 0.729158, 274 -0.00294906, -0.69754, 0.766073, -0.178424, 0.369513, -0.423241, 275 0.548547, -0.0152023, -0.757482, -0.85491, 0.251331, -0.989183, 276 0.306261, -0.340716, 0.886103, -0.0726757, -0.723523, -0.784303, 277 0.0354295, 0.566564, -0.485469, -0.620498, 0.832546, 0.697884, 278 -0.279115, 0.294415, -0.584313, 0.548772, 0.0648819, 0.968726, 279 0.723834, -0.0080452, -0.350386, -0.272803, 0.115121, -0.412644, 280 -0.824713, -0.992843, -0.592904, -0.417893, 0.863791, -0.423461, 281 -0.147601, -0.770664, -0.479006, 0.654782, 0.587314, -0.639158, 282 0.816969, -0.337228, 0.659878, 0.73107, 0.754768, -0.337042, 283 0.0960841, 0.368357, 0.244191, -0.817703, -0.211223, 0.442012, 284 0.37225, -0.623598, -0.405423, 0.455101, 0.673656, -0.145345, 285 -0.511346, -0.901675, -0.81252, -0.127006, 0.809865, -0.721884, 286 0.636255, 0.868989, -0.347973, -0.10179, -0.777449, 0.917274, 287 0.819286, 0.206218, -0.00785118, 0.167141, 0.45872, 0.972934, 288 -0.276798, 0.837861, 0.747958, -0.0151566, -0.330057, -0.469077, 289 0.277308, 0.415818}); 290 291 rnn.SetBias({0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068, 292 -0.23566568, -0.389184, 0.47481549, -0.4791103, 0.29931796, 293 0.10463274, 0.83918178, 0.37197268, 0.61957061, 0.3956964, 294 -0.37609905}); 295 296 rnn.SetRecurrentWeights({0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 297 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 298 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 299 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 300 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 301 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 302 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 303 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 304 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 305 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 306 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 307 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 308 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 309 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 310 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 311 0.1}); 312 313 rnn.ResetHiddenState(); 314 const int input_sequence_size = sizeof(rnn_input) / sizeof(float) / 315 (rnn.input_size() * rnn.num_batches()); 316 317 for (int i = 0; i < input_sequence_size; i++) { 318 float* batch_start = rnn_input + i * rnn.input_size(); 319 float* batch_end = batch_start + rnn.input_size(); 320 rnn.SetInput(0, batch_start, batch_end); 321 rnn.SetInput(rnn.input_size(), batch_start, batch_end); 322 323 rnn.Invoke(); 324 325 float* golden_start = rnn_golden_output + i * rnn.num_units(); 326 float* golden_end = golden_start + rnn.num_units(); 327 std::vector<float> expected; 328 expected.insert(expected.end(), golden_start, golden_end); 329 expected.insert(expected.end(), golden_start, golden_end); 330 331 EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); 332 } 333 } 334 335 } // namespace wrapper 336 } // namespace nn 337 } // namespace android 338