1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 // Unit test for TFLite sparse lookup op. 16 17 #include <cmath> 18 #include <vector> 19 20 #include <gmock/gmock.h> 21 #include <gtest/gtest.h> 22 #include "tensorflow/lite/interpreter.h" 23 #include "tensorflow/lite/kernels/register.h" 24 #include "tensorflow/lite/kernels/test_util.h" 25 #include "tensorflow/lite/model.h" 26 27 namespace tflite { 28 namespace { 29 30 using ::testing::ElementsAreArray; 31 32 class EmbeddingLookupSparseOpModel : public SingleOpModel { 33 public: 34 EmbeddingLookupSparseOpModel(CombinerType type, 35 std::initializer_list<int> lookup_shape, 36 std::initializer_list<int> indices_shape, 37 std::initializer_list<int> dense_shape_shape, 38 std::initializer_list<int> value_shape) { 39 lookup_ = AddInput(TensorType_INT32); 40 indices_ = AddInput(TensorType_INT32); 41 dense_shape_ = AddInput(TensorType_INT32); 42 weights_ = AddInput(TensorType_FLOAT32); 43 value_ = AddInput(TensorType_FLOAT32); 44 output_ = AddOutput(TensorType_FLOAT32); 45 SetBuiltinOp(BuiltinOperator_EMBEDDING_LOOKUP_SPARSE, 46 BuiltinOptions_EmbeddingLookupSparseOptions, 47 CreateEmbeddingLookupSparseOptions(builder_, type).Union()); 48 BuildInterpreter({lookup_shape, indices_shape, dense_shape_shape, 49 lookup_shape, value_shape}); 50 } 51 52 void SetInput(std::initializer_list<int> lookup_data, 53 std::initializer_list<int> indices_data, 54 std::initializer_list<int> dense_shape_data, 55 std::initializer_list<float> weights_data) { 56 PopulateTensor(lookup_, lookup_data); 57 PopulateTensor(indices_, indices_data); 58 PopulateTensor(dense_shape_, dense_shape_data); 59 PopulateTensor(weights_, weights_data); 60 } 61 62 void Set3DWeightMatrix(const std::function<float(int, int, int)>& function) { 63 TfLiteTensor* tensor = interpreter_->tensor(value_); 64 int rows = tensor->dims->data[0]; 65 int columns = tensor->dims->data[1]; 66 int features = tensor->dims->data[2]; 67 for (int i = 0; i < rows; i++) { 68 for (int j = 0; j < columns; j++) { 69 for (int k = 0; k < features; k++) { 70 tensor->data.f[(i * columns + j) * features + k] = function(i, j, k); 71 } 72 } 73 } 74 } 75 76 std::vector<float> GetOutput() { return ExtractVector<float>(output_); } 77 78 private: 79 int lookup_; 80 int weights_; 81 int indices_; 82 int dense_shape_; 83 int value_; 84 int output_; 85 }; 86 87 TEST(EmbeddingLookupOpTest, SimpleTest) { 88 EmbeddingLookupSparseOpModel m(CombinerType_SUM, {3}, {3, 2}, {2}, {4, 3, 2}); 89 m.SetInput({1, 3, 0}, {0, 0, 2, 0, 2, 1}, {3, 2}, {1.0, 2.0, 4.0}); 90 m.Set3DWeightMatrix( 91 [](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; }); 92 m.Invoke(); 93 94 EXPECT_THAT(m.GetOutput(), 95 ElementsAreArray(ArrayFloatNear({ 96 1.00, 1.01, 1.10, 1.11, 1.20, 1.21, // Row 1 97 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, // - 98 6.00, 6.06, 6.60, 6.66, 7.20, 7.26, // 2 * Row 3 + 4 * Row 0 99 }))); 100 } 101 102 TEST(EmbeddingLookupOpTest, SimpleTestMean) { 103 EmbeddingLookupSparseOpModel m(CombinerType_MEAN, {3}, {3, 2}, {2}, 104 {4, 3, 2}); 105 m.SetInput({1, 3, 0}, {0, 0, 2, 0, 2, 1}, {3, 2}, {1.0, 2.0, 4.0}); 106 m.Set3DWeightMatrix( 107 [](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; }); 108 m.Invoke(); 109 110 EXPECT_THAT(m.GetOutput(), 111 ElementsAreArray(ArrayFloatNear({ 112 1.00, 1.01, 1.10, 1.11, 1.20, 1.21, // Row 1 113 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, // - 114 1.00, 1.01, 1.10, 1.11, 1.20, 1.21, // 2 * Row 3 + 4 * Row 0 115 }))); 116 } 117 118 TEST(EmbeddingLookupOpTest, SimpleTestSqrtn) { 119 EmbeddingLookupSparseOpModel m(CombinerType_SQRTN, {3}, {3, 2}, {2}, 120 {4, 3, 2}); 121 m.SetInput({1, 3, 0}, {0, 0, 2, 0, 2, 1}, {3, 2}, {1.0, 2.0, 4.0}); 122 m.Set3DWeightMatrix( 123 [](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; }); 124 m.Invoke(); 125 126 EXPECT_THAT(m.GetOutput(), 127 ElementsAreArray(ArrayFloatNear({ 128 1.00, 1.01, 1.10, 1.11, 1.20, 1.21, // Row 1 129 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, // - 130 6.00f / std::sqrt(20.0f), 6.06f / std::sqrt(20.0f), 131 6.60f / std::sqrt(20.0f), 6.66f / std::sqrt(20.0f), 132 7.20f / std::sqrt(20.0f), 133 7.26f / std::sqrt(20.0f), // 2 * Row 3 + 4 * Row 0, // 2 * 134 // Row 3 + 4 * Row 0 135 }))); 136 } 137 138 TEST(EmbeddingLookupOpTest, Indices3DTest) { 139 EmbeddingLookupSparseOpModel m(CombinerType_SUM, {3}, {3, 3}, {3}, {4, 3, 2}); 140 m.SetInput({1, 3, 0}, {0, 0, 0, 2, 0, 0, 2, 0, 1}, {3, 2, 2}, 141 {1.0, 2.0, 4.0}); 142 m.Set3DWeightMatrix( 143 [](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; }); 144 m.Invoke(); 145 146 EXPECT_THAT(m.GetOutput(), 147 ElementsAreArray(ArrayFloatNear({ 148 1.00, 1.01, 1.10, 1.11, 1.20, 1.21, 0.00, 0.00, 0.00, 149 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 150 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 6.00, 6.06, 6.60, 151 6.66, 7.20, 7.26, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 152 }))); 153 } 154 155 } // namespace 156 } // namespace tflite 157 158 int main(int argc, char** argv) { 159 ::tflite::LogToStderr(); 160 ::testing::InitGoogleTest(&argc, argv); 161 return RUN_ALL_TESTS(); 162 } 163