Home | History | Annotate | Download | only in example
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #include "tensorflow/core/example/feature_util.h"
     16 
     17 #include <vector>
     18 
     19 #include "tensorflow/core/example/example.pb.h"
     20 #include "tensorflow/core/platform/test.h"
     21 #include "tensorflow/core/platform/types.h"
     22 
     23 namespace tensorflow {
     24 namespace {
     25 
     26 const float kTolerance = 1e-5;
     27 
     28 TEST(GetFeatureValuesInt64Test, ReadsASingleValue) {
     29   Example example;
     30   (*example.mutable_features()->mutable_feature())["tag"]
     31       .mutable_int64_list()
     32       ->add_value(42);
     33 
     34   auto tag = GetFeatureValues<protobuf_int64>("tag", example);
     35 
     36   ASSERT_EQ(1, tag.size());
     37   EXPECT_EQ(42, tag.Get(0));
     38 }
     39 
     40 TEST(GetFeatureValuesInt64Test, ReadsASingleValueFromFeature) {
     41   Feature feature;
     42   feature.mutable_int64_list()->add_value(42);
     43 
     44   auto values = GetFeatureValues<protobuf_int64>(feature);
     45 
     46   ASSERT_EQ(1, values.size());
     47   EXPECT_EQ(42, values.Get(0));
     48 }
     49 
     50 TEST(GetFeatureValuesInt64Test, WritesASingleValue) {
     51   Example example;
     52 
     53   GetFeatureValues<protobuf_int64>("tag", &example)->Add(42);
     54 
     55   ASSERT_EQ(1,
     56             example.features().feature().at("tag").int64_list().value_size());
     57   EXPECT_EQ(42, example.features().feature().at("tag").int64_list().value(0));
     58 }
     59 
     60 TEST(GetFeatureValuesInt64Test, WritesASingleValueToFeature) {
     61   Feature feature;
     62 
     63   GetFeatureValues<protobuf_int64>(&feature)->Add(42);
     64 
     65   ASSERT_EQ(1, feature.int64_list().value_size());
     66   EXPECT_EQ(42, feature.int64_list().value(0));
     67 }
     68 
     69 TEST(GetFeatureValuesInt64Test, CheckUntypedFieldExistence) {
     70   Example example;
     71   ASSERT_FALSE(HasFeature("tag", example));
     72 
     73   GetFeatureValues<protobuf_int64>("tag", &example)->Add(0);
     74 
     75   EXPECT_TRUE(HasFeature("tag", example));
     76 }
     77 
     78 TEST(GetFeatureValuesInt64Test, CheckTypedFieldExistence) {
     79   Example example;
     80 
     81   GetFeatureValues<float>("tag", &example)->Add(3.14);
     82   ASSERT_FALSE(HasFeature<protobuf_int64>("tag", example));
     83 
     84   GetFeatureValues<protobuf_int64>("tag", &example)->Add(42);
     85 
     86   EXPECT_TRUE(HasFeature<protobuf_int64>("tag", example));
     87   auto tag_ro = GetFeatureValues<protobuf_int64>("tag", example);
     88   ASSERT_EQ(1, tag_ro.size());
     89   EXPECT_EQ(42, tag_ro.Get(0));
     90 }
     91 
     92 TEST(GetFeatureValuesInt64Test, CopyIterableToAField) {
     93   Example example;
     94   std::vector<int> values{1, 2, 3};
     95 
     96   std::copy(values.begin(), values.end(),
     97             protobuf::RepeatedFieldBackInserter(
     98                 GetFeatureValues<protobuf_int64>("tag", &example)));
     99 
    100   auto tag_ro = GetFeatureValues<protobuf_int64>("tag", example);
    101   ASSERT_EQ(3, tag_ro.size());
    102   EXPECT_EQ(1, tag_ro.Get(0));
    103   EXPECT_EQ(2, tag_ro.Get(1));
    104   EXPECT_EQ(3, tag_ro.Get(2));
    105 }
    106 
    107 TEST(GetFeatureValuesFloatTest, ReadsASingleValueFromFeature) {
    108   Feature feature;
    109   feature.mutable_float_list()->add_value(3.14);
    110 
    111   auto values = GetFeatureValues<float>(feature);
    112 
    113   ASSERT_EQ(1, values.size());
    114   EXPECT_NEAR(3.14, values.Get(0), kTolerance);
    115 }
    116 
    117 TEST(GetFeatureValuesFloatTest, ReadsASingleValue) {
    118   Example example;
    119   (*example.mutable_features()->mutable_feature())["tag"]
    120       .mutable_float_list()
    121       ->add_value(3.14);
    122 
    123   auto tag = GetFeatureValues<float>("tag", example);
    124 
    125   ASSERT_EQ(1, tag.size());
    126   EXPECT_NEAR(3.14, tag.Get(0), kTolerance);
    127 }
    128 
    129 TEST(GetFeatureValuesFloatTest, WritesASingleValueToFeature) {
    130   Feature feature;
    131 
    132   GetFeatureValues<float>(&feature)->Add(3.14);
    133 
    134   ASSERT_EQ(1, feature.float_list().value_size());
    135   EXPECT_NEAR(3.14, feature.float_list().value(0), kTolerance);
    136 }
    137 
    138 TEST(GetFeatureValuesFloatTest, WritesASingleValue) {
    139   Example example;
    140 
    141   GetFeatureValues<float>("tag", &example)->Add(3.14);
    142 
    143   ASSERT_EQ(1,
    144             example.features().feature().at("tag").float_list().value_size());
    145   EXPECT_NEAR(3.14,
    146               example.features().feature().at("tag").float_list().value(0),
    147               kTolerance);
    148 }
    149 
    150 TEST(GetFeatureValuesFloatTest, CheckTypedFieldExistence) {
    151   Example example;
    152 
    153   GetFeatureValues<protobuf_int64>("tag", &example)->Add(42);
    154   ASSERT_FALSE(HasFeature<float>("tag", example));
    155 
    156   GetFeatureValues<float>("tag", &example)->Add(3.14);
    157 
    158   EXPECT_TRUE(HasFeature<float>("tag", example));
    159   auto tag_ro = GetFeatureValues<float>("tag", example);
    160   ASSERT_EQ(1, tag_ro.size());
    161   EXPECT_NEAR(3.14, tag_ro.Get(0), kTolerance);
    162 }
    163 
    164 TEST(GetFeatureValuesFloatTest, CheckTypedFieldExistenceForDeprecatedMethod) {
    165   Example example;
    166 
    167   GetFeatureValues<protobuf_int64>("tag", &example)->Add(42);
    168   ASSERT_FALSE(ExampleHasFeature<float>("tag", example));
    169 
    170   GetFeatureValues<float>("tag", &example)->Add(3.14);
    171 
    172   EXPECT_TRUE(ExampleHasFeature<float>("tag", example));
    173   auto tag_ro = GetFeatureValues<float>("tag", example);
    174   ASSERT_EQ(1, tag_ro.size());
    175   EXPECT_NEAR(3.14, tag_ro.Get(0), kTolerance);
    176 }
    177 
    178 TEST(GetFeatureValuesStringTest, ReadsASingleValueFromFeature) {
    179   Feature feature;
    180   feature.mutable_bytes_list()->add_value("FOO");
    181 
    182   auto values = GetFeatureValues<string>(feature);
    183 
    184   ASSERT_EQ(1, values.size());
    185   EXPECT_EQ("FOO", values.Get(0));
    186 }
    187 
    188 TEST(GetFeatureValuesStringTest, ReadsASingleValue) {
    189   Example example;
    190   (*example.mutable_features()->mutable_feature())["tag"]
    191       .mutable_bytes_list()
    192       ->add_value("FOO");
    193 
    194   auto tag = GetFeatureValues<string>("tag", example);
    195 
    196   ASSERT_EQ(1, tag.size());
    197   EXPECT_EQ("FOO", tag.Get(0));
    198 }
    199 
    200 TEST(GetFeatureValuesStringTest, WritesASingleValueToFeature) {
    201   Feature feature;
    202 
    203   *GetFeatureValues<string>(&feature)->Add() = "FOO";
    204 
    205   ASSERT_EQ(1, feature.bytes_list().value_size());
    206   EXPECT_EQ("FOO", feature.bytes_list().value(0));
    207 }
    208 
    209 TEST(GetFeatureValuesStringTest, WritesASingleValue) {
    210   Example example;
    211 
    212   *GetFeatureValues<string>("tag", &example)->Add() = "FOO";
    213 
    214   ASSERT_EQ(1,
    215             example.features().feature().at("tag").bytes_list().value_size());
    216   EXPECT_EQ("FOO",
    217             example.features().feature().at("tag").bytes_list().value(0));
    218 }
    219 
    220 TEST(GetFeatureValuesStringTest, CheckTypedFieldExistence) {
    221   Example example;
    222 
    223   GetFeatureValues<protobuf_int64>("tag", &example)->Add(42);
    224   ASSERT_FALSE(HasFeature<string>("tag", example));
    225 
    226   *GetFeatureValues<string>("tag", &example)->Add() = "FOO";
    227 
    228   EXPECT_TRUE(HasFeature<string>("tag", example));
    229   auto tag_ro = GetFeatureValues<string>("tag", example);
    230   ASSERT_EQ(1, tag_ro.size());
    231   EXPECT_EQ("FOO", tag_ro.Get(0));
    232 }
    233 
    234 TEST(AppendFeatureValuesTest, FloatValuesFromContainer) {
    235   Example example;
    236 
    237   std::vector<double> values{1.1, 2.2, 3.3};
    238   AppendFeatureValues(values, "tag", &example);
    239 
    240   auto tag_ro = GetFeatureValues<float>("tag", example);
    241   ASSERT_EQ(3, tag_ro.size());
    242   EXPECT_NEAR(1.1, tag_ro.Get(0), kTolerance);
    243   EXPECT_NEAR(2.2, tag_ro.Get(1), kTolerance);
    244   EXPECT_NEAR(3.3, tag_ro.Get(2), kTolerance);
    245 }
    246 
    247 TEST(AppendFeatureValuesTest, FloatValuesUsingInitializerList) {
    248   Example example;
    249 
    250   AppendFeatureValues({1.1, 2.2, 3.3}, "tag", &example);
    251 
    252   auto tag_ro = GetFeatureValues<float>("tag", example);
    253   ASSERT_EQ(3, tag_ro.size());
    254   EXPECT_NEAR(1.1, tag_ro.Get(0), kTolerance);
    255   EXPECT_NEAR(2.2, tag_ro.Get(1), kTolerance);
    256   EXPECT_NEAR(3.3, tag_ro.Get(2), kTolerance);
    257 }
    258 
    259 TEST(AppendFeatureValuesTest, Int64ValuesUsingInitializerList) {
    260   Example example;
    261 
    262   std::vector<protobuf_int64> values{1, 2, 3};
    263   AppendFeatureValues(values, "tag", &example);
    264 
    265   auto tag_ro = GetFeatureValues<protobuf_int64>("tag", example);
    266   ASSERT_EQ(3, tag_ro.size());
    267   EXPECT_EQ(1, tag_ro.Get(0));
    268   EXPECT_EQ(2, tag_ro.Get(1));
    269   EXPECT_EQ(3, tag_ro.Get(2));
    270 }
    271 
    272 TEST(AppendFeatureValuesTest, StringValuesUsingInitializerList) {
    273   Example example;
    274 
    275   AppendFeatureValues({"FOO", "BAR", "BAZ"}, "tag", &example);
    276 
    277   auto tag_ro = GetFeatureValues<string>("tag", example);
    278   ASSERT_EQ(3, tag_ro.size());
    279   EXPECT_EQ("FOO", tag_ro.Get(0));
    280   EXPECT_EQ("BAR", tag_ro.Get(1));
    281   EXPECT_EQ("BAZ", tag_ro.Get(2));
    282 }
    283 
    284 TEST(AppendFeatureValuesTest, StringVariablesUsingInitializerList) {
    285   Example example;
    286 
    287   string string1("FOO");
    288   string string2("BAR");
    289   string string3("BAZ");
    290 
    291   AppendFeatureValues({string1, string2, string3}, "tag", &example);
    292 
    293   auto tag_ro = GetFeatureValues<string>("tag", example);
    294   ASSERT_EQ(3, tag_ro.size());
    295   EXPECT_EQ("FOO", tag_ro.Get(0));
    296   EXPECT_EQ("BAR", tag_ro.Get(1));
    297   EXPECT_EQ("BAZ", tag_ro.Get(2));
    298 }
    299 
    300 TEST(SequenceExampleTest, ReadsASingleValueFromContext) {
    301   SequenceExample se;
    302   (*se.mutable_context()->mutable_feature())["tag"]
    303       .mutable_int64_list()
    304       ->add_value(42);
    305 
    306   auto values = GetFeatureValues<protobuf_int64>("tag", se.context());
    307 
    308   ASSERT_EQ(1, values.size());
    309   EXPECT_EQ(42, values.Get(0));
    310 }
    311 
    312 TEST(SequenceExampleTest, WritesASingleValueToContext) {
    313   SequenceExample se;
    314 
    315   GetFeatureValues<protobuf_int64>("tag", se.mutable_context())->Add(42);
    316 
    317   ASSERT_EQ(1, se.context().feature().at("tag").int64_list().value_size());
    318   EXPECT_EQ(42, se.context().feature().at("tag").int64_list().value(0));
    319 }
    320 
    321 TEST(SequenceExampleTest, AppendFeatureValuesToContextSingleArg) {
    322   SequenceExample se;
    323 
    324   AppendFeatureValues({1.1, 2.2, 3.3}, "tag", se.mutable_context());
    325 
    326   auto tag_ro = GetFeatureValues<float>("tag", se.context());
    327   ASSERT_EQ(3, tag_ro.size());
    328   EXPECT_NEAR(1.1, tag_ro.Get(0), kTolerance);
    329   EXPECT_NEAR(2.2, tag_ro.Get(1), kTolerance);
    330   EXPECT_NEAR(3.3, tag_ro.Get(2), kTolerance);
    331 }
    332 
    333 TEST(SequenceExampleTest, CheckTypedFieldExistence) {
    334   SequenceExample se;
    335 
    336   GetFeatureValues<float>("tag", se.mutable_context())->Add(3.14);
    337   ASSERT_FALSE(HasFeature<protobuf_int64>("tag", se.context()));
    338 
    339   GetFeatureValues<protobuf_int64>("tag", se.mutable_context())->Add(42);
    340 
    341   EXPECT_TRUE(HasFeature<protobuf_int64>("tag", se.context()));
    342   auto tag_ro = GetFeatureValues<protobuf_int64>("tag", se.context());
    343   ASSERT_EQ(1, tag_ro.size());
    344   EXPECT_EQ(42, tag_ro.Get(0));
    345 }
    346 
    347 TEST(SequenceExampleTest, ReturnsExistingFeatureLists) {
    348   SequenceExample se;
    349   (*se.mutable_feature_lists()->mutable_feature_list())["tag"]
    350       .mutable_feature()
    351       ->Add();
    352 
    353   auto feature = GetFeatureList("tag", se);
    354 
    355   ASSERT_EQ(1, feature.size());
    356 }
    357 
    358 TEST(SequenceExampleTest, CreatesNewFeatureLists) {
    359   SequenceExample se;
    360 
    361   GetFeatureList("tag", &se)->Add();
    362 
    363   EXPECT_EQ(1, se.feature_lists().feature_list().at("tag").feature_size());
    364 }
    365 
    366 TEST(SequenceExampleTest, CheckFeatureListExistence) {
    367   SequenceExample se;
    368   ASSERT_FALSE(HasFeatureList("tag", se));
    369 
    370   GetFeatureList("tag", &se)->Add();
    371 
    372   ASSERT_TRUE(HasFeatureList("tag", se));
    373 }
    374 
    375 TEST(SequenceExampleTest, AppendFeatureValuesWithInitializerList) {
    376   SequenceExample se;
    377 
    378   AppendFeatureValues({1, 2, 3}, "ids", se.mutable_context());
    379   AppendFeatureValues({"cam1-0", "cam2-0"},
    380                       GetFeatureList("images", &se)->Add());
    381   AppendFeatureValues({"cam1-1", "cam2-2"},
    382                       GetFeatureList("images", &se)->Add());
    383 
    384   EXPECT_EQ(se.DebugString(),
    385             "context {\n"
    386             "  feature {\n"
    387             "    key: \"ids\"\n"
    388             "    value {\n"
    389             "      int64_list {\n"
    390             "        value: 1\n"
    391             "        value: 2\n"
    392             "        value: 3\n"
    393             "      }\n"
    394             "    }\n"
    395             "  }\n"
    396             "}\n"
    397             "feature_lists {\n"
    398             "  feature_list {\n"
    399             "    key: \"images\"\n"
    400             "    value {\n"
    401             "      feature {\n"
    402             "        bytes_list {\n"
    403             "          value: \"cam1-0\"\n"
    404             "          value: \"cam2-0\"\n"
    405             "        }\n"
    406             "      }\n"
    407             "      feature {\n"
    408             "        bytes_list {\n"
    409             "          value: \"cam1-1\"\n"
    410             "          value: \"cam2-2\"\n"
    411             "        }\n"
    412             "      }\n"
    413             "    }\n"
    414             "  }\n"
    415             "}\n");
    416 }
    417 
    418 TEST(SequenceExampleTest, AppendFeatureValuesWithVectors) {
    419   SequenceExample se;
    420 
    421   std::vector<float> readings{1.0, 2.5, 5.0};
    422   AppendFeatureValues(readings, GetFeatureList("movie_ratings", &se)->Add());
    423 
    424   EXPECT_EQ(se.DebugString(),
    425             "feature_lists {\n"
    426             "  feature_list {\n"
    427             "    key: \"movie_ratings\"\n"
    428             "    value {\n"
    429             "      feature {\n"
    430             "        float_list {\n"
    431             "          value: 1\n"
    432             "          value: 2.5\n"
    433             "          value: 5\n"
    434             "        }\n"
    435             "      }\n"
    436             "    }\n"
    437             "  }\n"
    438             "}\n");
    439 }
    440 
    441 }  // namespace
    442 }  // namespace tensorflow
    443