Home | History | Annotate | Download | only in example
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #include "tensorflow/core/example/example_parser_configuration.h"
     16 
     17 #include <memory>
     18 
     19 #include "tensorflow/core/framework/attr_value.pb.h"
     20 #include "tensorflow/core/framework/node_def.pb.h"
     21 #include "tensorflow/core/framework/tensor_testutil.h"
     22 #include "tensorflow/core/lib/core/status_test_util.h"
     23 #include "tensorflow/core/lib/io/path.h"
     24 #include "tensorflow/core/platform/protobuf.h"
     25 #include "tensorflow/core/platform/test.h"
     26 #include "tensorflow/core/public/session_options.h"
     27 #include "tensorflow/core/util/example_proto_helper.h"
     28 
     29 namespace tensorflow {
     30 namespace {
     31 
// Reads the entire contents of `filename` via `env` into `*output`.
// Any non-OK status from ReadFileToString aborts the process through
// TF_CHECK_OK, so callers never observe a partially-read result.
void ReadFileToStringOrDie(Env* env, const string& filename, string* output) {
  TF_CHECK_OK(ReadFileToString(env, filename, output));
}
     35 
     36 std::unique_ptr<Session> CreateSession() {
     37   SessionOptions options;
     38   (*options.config.mutable_device_count())["CPU"] = 2;
     39   return std::unique_ptr<Session>(NewSession(options));
     40 }
     41 
     42 class ExtractExampleParserConfigurationTest : public ::testing::Test {
     43  protected:
     44   void SetUp() override {
     45     string proto_string;
     46     string filename =
     47         io::JoinPath(testing::TensorFlowSrcRoot(),
     48                      "core/example/testdata/parse_example_graph_def.pbtxt");
     49     ReadFileToStringOrDie(Env::Default(), filename, &proto_string);
     50     protobuf::TextFormat::ParseFromString(proto_string, &graph_def_);
     51     session_ = CreateSession();
     52     TF_CHECK_OK(session_->Create(graph_def_));
     53   }
     54 
     55   NodeDef* parse_example_node() {
     56     for (auto& node : *graph_def_.mutable_node()) {
     57       if (node.name() == "ParseExample/ParseExample") {
     58         return &node;
     59       }
     60     }
     61     return nullptr;
     62   }
     63 
     64   GraphDef graph_def_;
     65   std::unique_ptr<Session> session_;
     66 };
     67 
     68 TEST_F(ExtractExampleParserConfigurationTest, OpNotFound) {
     69   std::vector<FixedLenFeature> dense_vec;
     70   std::vector<VarLenFeature> sparse_vec;
     71   Status status = ExtractExampleParserConfiguration(
     72       graph_def_, "BlarseExample/ParseExample", session_.get(), &dense_vec,
     73       &sparse_vec);
     74 
     75   EXPECT_EQ(status.code(), error::INVALID_ARGUMENT);
     76 }
     77 
     78 TEST_F(ExtractExampleParserConfigurationTest, InconsistentAttrNsparse) {
     79   std::vector<FixedLenFeature> dense_vec;
     80   std::vector<VarLenFeature> sparse_vec;
     81 
     82   NodeDef* node = parse_example_node();
     83   auto mutable_attr = node->mutable_attr();
     84   (*mutable_attr)["Nsparse"].set_i(3);
     85 
     86   Status status = ExtractExampleParserConfiguration(
     87       graph_def_, "ParseExample/ParseExample", session_.get(), &dense_vec,
     88       &sparse_vec);
     89 
     90   EXPECT_EQ(status.code(), error::INVALID_ARGUMENT);
     91 }
     92 
     93 TEST_F(ExtractExampleParserConfigurationTest, InconsistentAttrNdense) {
     94   std::vector<FixedLenFeature> dense_vec;
     95   std::vector<VarLenFeature> sparse_vec;
     96 
     97   NodeDef* node = parse_example_node();
     98   auto mutable_attr = node->mutable_attr();
     99   (*mutable_attr)["Ndense"].set_i(2);
    100 
    101   Status status = ExtractExampleParserConfiguration(
    102       graph_def_, "ParseExample/ParseExample", session_.get(), &dense_vec,
    103       &sparse_vec);
    104 
    105   EXPECT_EQ(status.code(), error::INVALID_ARGUMENT);
    106 }
    107 
    108 TEST_F(ExtractExampleParserConfigurationTest, Basic) {
    109   std::vector<FixedLenFeature> dense_vec;
    110   std::vector<VarLenFeature> sparse_vec;
    111   Status status = ExtractExampleParserConfiguration(
    112       graph_def_, "ParseExample/ParseExample", session_.get(), &dense_vec,
    113       &sparse_vec);
    114 
    115   EXPECT_EQ(Status::OK(), status);
    116   EXPECT_EQ(2, sparse_vec.size());
    117   EXPECT_EQ(3, dense_vec.size());
    118 
    119   EXPECT_EQ("sf0", sparse_vec[0].key);
    120   EXPECT_EQ(DT_STRING, sparse_vec[0].dtype);
    121   EXPECT_EQ("ParseExample/ParseExample:0",
    122             sparse_vec[0].indices_output_tensor_name);
    123   EXPECT_EQ("ParseExample/ParseExample:2",
    124             sparse_vec[0].values_output_tensor_name);
    125   EXPECT_EQ("ParseExample/ParseExample:4",
    126             sparse_vec[0].shapes_output_tensor_name);
    127 
    128   EXPECT_EQ("sf1", sparse_vec[1].key);
    129   EXPECT_EQ(DT_STRING, sparse_vec[1].dtype);
    130   EXPECT_EQ("ParseExample/ParseExample:1",
    131             sparse_vec[1].indices_output_tensor_name);
    132   EXPECT_EQ("ParseExample/ParseExample:3",
    133             sparse_vec[1].values_output_tensor_name);
    134   EXPECT_EQ("ParseExample/ParseExample:5",
    135             sparse_vec[1].shapes_output_tensor_name);
    136 
    137   EXPECT_EQ("x", dense_vec[0].key);
    138   EXPECT_EQ(DT_FLOAT, dense_vec[0].dtype);
    139   EXPECT_EQ("ParseExample/ParseExample:6",
    140             dense_vec[0].values_output_tensor_name);
    141 
    142   EXPECT_EQ("y", dense_vec[1].key);
    143   EXPECT_EQ(DT_FLOAT, dense_vec[1].dtype);
    144   EXPECT_EQ("ParseExample/ParseExample:7",
    145             dense_vec[1].values_output_tensor_name);
    146 
    147   EXPECT_EQ("z", dense_vec[2].key);
    148   EXPECT_EQ(DT_FLOAT, dense_vec[2].dtype);
    149   EXPECT_EQ("ParseExample/ParseExample:8",
    150             dense_vec[2].values_output_tensor_name);
    151 }
    152 
// Text-format ExampleParserConfiguration parsed by the
// ExampleParserConfigurationProtoToFeatureVectorsTest fixture. It declares
// one fixed-length DT_FLOAT feature "x" (shape [1], default value 33.0) and
// one variable-length DT_STRING feature "y". Each feature is mapped to
// ParseExample/ParseExample output tensor names. The literal's exact
// contents are the test input; the Basic test asserts on these values.
static const char kExampleParseConfigurationProto[] = R"( feature_map {
  key: "x"
  value {
    fixed_len_feature {
      dtype: DT_FLOAT
      shape {
        dim {
          size: 1
        }
      }
      default_value {
        dtype: DT_FLOAT
        tensor_shape {
          dim {
            size: 1
          }
        }
        float_val: 33.0
      }
      values_output_tensor_name: "ParseExample/ParseExample:3"
    }
  }
}
feature_map {
  key: "y"
  value {
    var_len_feature {
      dtype: DT_STRING
      values_output_tensor_name: "ParseExample/ParseExample:1"
      indices_output_tensor_name: "ParseExample/ParseExample:0"
      shapes_output_tensor_name: "ParseExample/ParseExample:2"
    }
  }
}
)";
    188 
    189 class ExampleParserConfigurationProtoToFeatureVectorsTest
    190     : public ::testing::Test {
    191  protected:
    192   void SetUp() override {
    193     CHECK(protobuf::TextFormat::ParseFromString(kExampleParseConfigurationProto,
    194                                                 &config_proto_));
    195   }
    196   ExampleParserConfiguration config_proto_;
    197 };
    198 
    199 TEST_F(ExampleParserConfigurationProtoToFeatureVectorsTest, Basic) {
    200   std::vector<FixedLenFeature> fixed_len_features;
    201   std::vector<VarLenFeature> var_len_features;
    202   TF_ASSERT_OK(ExampleParserConfigurationProtoToFeatureVectors(
    203       config_proto_, &fixed_len_features, &var_len_features));
    204   ASSERT_EQ(1, fixed_len_features.size());
    205   ASSERT_EQ(1, var_len_features.size());
    206 
    207   const FixedLenFeature& f = fixed_len_features[0];
    208   ASSERT_EQ(DT_FLOAT, f.dtype);
    209   ASSERT_EQ("x", f.key);
    210   ASSERT_EQ("ParseExample/ParseExample:3", f.values_output_tensor_name);
    211 
    212   TensorShape expected_shape({1});
    213   ASSERT_EQ(expected_shape.dims(), f.shape.dims());
    214   ASSERT_EQ(1, f.shape.dim_size(0));
    215 
    216   Tensor expected_default(DT_FLOAT, TensorShape({1}));
    217   test::FillIota<float>(&expected_default, 33.0);
    218   test::ExpectTensorEqual<float>(expected_default, f.default_value);
    219 
    220   const VarLenFeature& v = var_len_features[0];
    221   ASSERT_EQ(DT_STRING, v.dtype);
    222   ASSERT_EQ("ParseExample/ParseExample:0", v.indices_output_tensor_name);
    223   ASSERT_EQ("ParseExample/ParseExample:1", v.values_output_tensor_name);
    224   ASSERT_EQ("ParseExample/ParseExample:2", v.shapes_output_tensor_name);
    225 }
    226 
    227 }  // namespace
    228 }  // namespace tensorflow
    229