Home | History | Annotate | Download | only in tests
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "smartselect/cached-features.h"
     18 
     19 #include "gmock/gmock.h"
     20 #include "gtest/gtest.h"
     21 
     22 namespace libtextclassifier {
     23 namespace {
     24 
     25 class TestingCachedFeatures : public CachedFeatures {
     26  public:
     27   using CachedFeatures::CachedFeatures;
     28   using CachedFeatures::RemapV0FeatureVector;
     29 };
     30 
     31 TEST(CachedFeaturesTest, Simple) {
     32   std::vector<Token> tokens;
     33   tokens.push_back(Token());
     34   tokens.push_back(Token());
     35   tokens.push_back(Token("Hello", 0, 1));
     36   tokens.push_back(Token("World", 1, 2));
     37   tokens.push_back(Token("today!", 2, 3));
     38   tokens.push_back(Token());
     39   tokens.push_back(Token());
     40 
     41   std::vector<std::vector<int>> sparse_features(tokens.size());
     42   for (int i = 0; i < sparse_features.size(); ++i) {
     43     sparse_features[i].push_back(i);
     44   }
     45   std::vector<std::vector<float>> dense_features(tokens.size());
     46   for (int i = 0; i < dense_features.size(); ++i) {
     47     dense_features[i].push_back(-i);
     48   }
     49 
     50   TestingCachedFeatures feature_extractor(
     51       tokens, /*context_size=*/2, sparse_features, dense_features,
     52       [](const std::vector<int>& sparse_features,
     53          const std::vector<float>& dense_features, float* features) {
     54         features[0] = sparse_features[0];
     55         features[1] = sparse_features[0];
     56         features[2] = dense_features[0];
     57         features[3] = dense_features[0];
     58         features[4] = 123;
     59         return true;
     60       },
     61       5);
     62 
     63   VectorSpan<float> features;
     64   VectorSpan<Token> output_tokens;
     65   EXPECT_TRUE(feature_extractor.Get(2, &features, &output_tokens));
     66   for (int i = 0; i < 5; i++) {
     67     EXPECT_EQ(features[i * 5 + 0], i) << "Feature " << i;
     68     EXPECT_EQ(features[i * 5 + 1], i) << "Feature " << i;
     69     EXPECT_EQ(features[i * 5 + 2], -i) << "Feature " << i;
     70     EXPECT_EQ(features[i * 5 + 3], -i) << "Feature " << i;
     71     EXPECT_EQ(features[i * 5 + 4], 123) << "Feature " << i;
     72   }
     73 }
     74 
     75 TEST(CachedFeaturesTest, InvalidInput) {
     76   std::vector<Token> tokens;
     77   tokens.push_back(Token());
     78   tokens.push_back(Token());
     79   tokens.push_back(Token("Hello", 0, 1));
     80   tokens.push_back(Token("World", 1, 2));
     81   tokens.push_back(Token("today!", 2, 3));
     82   tokens.push_back(Token());
     83   tokens.push_back(Token());
     84 
     85   std::vector<std::vector<int>> sparse_features(tokens.size());
     86   std::vector<std::vector<float>> dense_features(tokens.size());
     87 
     88   TestingCachedFeatures feature_extractor(
     89       tokens, /*context_size=*/2, sparse_features, dense_features,
     90       [](const std::vector<int>& sparse_features,
     91          const std::vector<float>& dense_features,
     92          float* features) { return true; },
     93       /*feature_vector_size=*/5);
     94 
     95   VectorSpan<float> features;
     96   VectorSpan<Token> output_tokens;
     97   EXPECT_FALSE(feature_extractor.Get(-1000, &features, &output_tokens));
     98   EXPECT_FALSE(feature_extractor.Get(-1, &features, &output_tokens));
     99   EXPECT_FALSE(feature_extractor.Get(0, &features, &output_tokens));
    100   EXPECT_TRUE(feature_extractor.Get(2, &features, &output_tokens));
    101   EXPECT_TRUE(feature_extractor.Get(4, &features, &output_tokens));
    102   EXPECT_FALSE(feature_extractor.Get(5, &features, &output_tokens));
    103   EXPECT_FALSE(feature_extractor.Get(500, &features, &output_tokens));
    104 }
    105 
    106 TEST(CachedFeaturesTest, RemapV0FeatureVector) {
    107   std::vector<Token> tokens;
    108   tokens.push_back(Token());
    109   tokens.push_back(Token());
    110   tokens.push_back(Token("Hello", 0, 1));
    111   tokens.push_back(Token("World", 1, 2));
    112   tokens.push_back(Token("today!", 2, 3));
    113   tokens.push_back(Token());
    114   tokens.push_back(Token());
    115 
    116   std::vector<std::vector<int>> sparse_features(tokens.size());
    117   std::vector<std::vector<float>> dense_features(tokens.size());
    118 
    119   TestingCachedFeatures feature_extractor(
    120       tokens, /*context_size=*/2, sparse_features, dense_features,
    121       [](const std::vector<int>& sparse_features,
    122          const std::vector<float>& dense_features,
    123          float* features) { return true; },
    124       /*feature_vector_size=*/5);
    125 
    126   std::vector<float> features_orig(5 * 5);
    127   for (int i = 0; i < features_orig.size(); i++) {
    128     features_orig[i] = i;
    129   }
    130   VectorSpan<float> features;
    131 
    132   feature_extractor.SetV0FeatureMode(0);
    133   features = VectorSpan<float>(features_orig);
    134   feature_extractor.RemapV0FeatureVector(&features);
    135   EXPECT_EQ(
    136       std::vector<float>({0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12,
    137                           13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}),
    138       std::vector<float>(features.begin(), features.end()));
    139 
    140   feature_extractor.SetV0FeatureMode(2);
    141   features = VectorSpan<float>(features_orig);
    142   feature_extractor.RemapV0FeatureVector(&features);
    143   EXPECT_EQ(std::vector<float>({0, 1, 5, 6,  10, 11, 15, 16, 20, 21, 2,  3, 4,
    144                                 7, 8, 9, 12, 13, 14, 17, 18, 19, 22, 23, 24}),
    145             std::vector<float>(features.begin(), features.end()));
    146 }
    147 
    148 }  // namespace
    149 }  // namespace libtextclassifier
    150