1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "smartselect/cached-features.h" 18 19 #include "gmock/gmock.h" 20 #include "gtest/gtest.h" 21 22 namespace libtextclassifier { 23 namespace { 24 25 class TestingCachedFeatures : public CachedFeatures { 26 public: 27 using CachedFeatures::CachedFeatures; 28 using CachedFeatures::RemapV0FeatureVector; 29 }; 30 31 TEST(CachedFeaturesTest, Simple) { 32 std::vector<Token> tokens; 33 tokens.push_back(Token()); 34 tokens.push_back(Token()); 35 tokens.push_back(Token("Hello", 0, 1)); 36 tokens.push_back(Token("World", 1, 2)); 37 tokens.push_back(Token("today!", 2, 3)); 38 tokens.push_back(Token()); 39 tokens.push_back(Token()); 40 41 std::vector<std::vector<int>> sparse_features(tokens.size()); 42 for (int i = 0; i < sparse_features.size(); ++i) { 43 sparse_features[i].push_back(i); 44 } 45 std::vector<std::vector<float>> dense_features(tokens.size()); 46 for (int i = 0; i < dense_features.size(); ++i) { 47 dense_features[i].push_back(-i); 48 } 49 50 TestingCachedFeatures feature_extractor( 51 tokens, /*context_size=*/2, sparse_features, dense_features, 52 [](const std::vector<int>& sparse_features, 53 const std::vector<float>& dense_features, float* features) { 54 features[0] = sparse_features[0]; 55 features[1] = sparse_features[0]; 56 features[2] = dense_features[0]; 57 features[3] = dense_features[0]; 58 features[4] = 123; 59 return true; 60 }, 61 5); 62 63 VectorSpan<float> features; 64 VectorSpan<Token> output_tokens; 65 EXPECT_TRUE(feature_extractor.Get(2, &features, &output_tokens)); 66 for (int i = 0; i < 5; i++) { 67 EXPECT_EQ(features[i * 5 + 0], i) << "Feature " << i; 68 EXPECT_EQ(features[i * 5 + 1], i) << "Feature " << i; 69 EXPECT_EQ(features[i * 5 + 2], -i) << "Feature " << i; 70 EXPECT_EQ(features[i * 5 + 3], -i) << "Feature " << i; 71 EXPECT_EQ(features[i * 5 + 4], 123) << "Feature " << i; 72 } 73 } 74 75 TEST(CachedFeaturesTest, InvalidInput) { 76 std::vector<Token> tokens; 77 tokens.push_back(Token()); 78 tokens.push_back(Token()); 79 tokens.push_back(Token("Hello", 0, 1)); 80 tokens.push_back(Token("World", 1, 2)); 81 tokens.push_back(Token("today!", 2, 3)); 82 tokens.push_back(Token()); 83 tokens.push_back(Token()); 84 85 std::vector<std::vector<int>> sparse_features(tokens.size()); 86 std::vector<std::vector<float>> dense_features(tokens.size()); 87 88 TestingCachedFeatures feature_extractor( 89 tokens, /*context_size=*/2, sparse_features, dense_features, 90 [](const std::vector<int>& sparse_features, 91 const std::vector<float>& dense_features, 92 float* features) { return true; }, 93 /*feature_vector_size=*/5); 94 95 VectorSpan<float> features; 96 VectorSpan<Token> output_tokens; 97 EXPECT_FALSE(feature_extractor.Get(-1000, &features, &output_tokens)); 98 EXPECT_FALSE(feature_extractor.Get(-1, &features, &output_tokens)); 99 EXPECT_FALSE(feature_extractor.Get(0, &features, &output_tokens)); 100 EXPECT_TRUE(feature_extractor.Get(2, &features, &output_tokens)); 101 EXPECT_TRUE(feature_extractor.Get(4, &features, &output_tokens)); 102 EXPECT_FALSE(feature_extractor.Get(5, &features, &output_tokens)); 103 EXPECT_FALSE(feature_extractor.Get(500, &features, &output_tokens)); 104 } 105 106 TEST(CachedFeaturesTest, RemapV0FeatureVector) { 107 std::vector<Token> tokens; 108 tokens.push_back(Token()); 109 tokens.push_back(Token()); 110 tokens.push_back(Token("Hello", 0, 1)); 111 tokens.push_back(Token("World", 1, 2)); 112 tokens.push_back(Token("today!", 2, 3)); 113 tokens.push_back(Token()); 114 tokens.push_back(Token()); 115 116 std::vector<std::vector<int>> sparse_features(tokens.size()); 117 std::vector<std::vector<float>> dense_features(tokens.size()); 118 119 TestingCachedFeatures feature_extractor( 120 tokens, /*context_size=*/2, sparse_features, dense_features, 121 [](const std::vector<int>& sparse_features, 122 const std::vector<float>& dense_features, 123 float* features) { return true; }, 124 /*feature_vector_size=*/5); 125 126 std::vector<float> features_orig(5 * 5); 127 for (int i = 0; i < features_orig.size(); i++) { 128 features_orig[i] = i; 129 } 130 VectorSpan<float> features; 131 132 feature_extractor.SetV0FeatureMode(0); 133 features = VectorSpan<float>(features_orig); 134 feature_extractor.RemapV0FeatureVector(&features); 135 EXPECT_EQ( 136 std::vector<float>({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 137 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}), 138 std::vector<float>(features.begin(), features.end())); 139 140 feature_extractor.SetV0FeatureMode(2); 141 features = VectorSpan<float>(features_orig); 142 feature_extractor.RemapV0FeatureVector(&features); 143 EXPECT_EQ(std::vector<float>({0, 1, 5, 6, 10, 11, 15, 16, 20, 21, 2, 3, 4, 144 7, 8, 9, 12, 13, 14, 17, 18, 19, 22, 23, 24}), 145 std::vector<float>(features.begin(), features.end())); 146 } 147 148 } // namespace 149 } // namespace libtextclassifier 150