1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include "tensorflow/core/example/example_parser_configuration.h" 16 17 #include <memory> 18 19 #include "tensorflow/core/framework/attr_value.pb.h" 20 #include "tensorflow/core/framework/node_def.pb.h" 21 #include "tensorflow/core/framework/tensor_testutil.h" 22 #include "tensorflow/core/lib/core/status_test_util.h" 23 #include "tensorflow/core/lib/io/path.h" 24 #include "tensorflow/core/platform/protobuf.h" 25 #include "tensorflow/core/platform/test.h" 26 #include "tensorflow/core/public/session_options.h" 27 #include "tensorflow/core/util/example_proto_helper.h" 28 29 namespace tensorflow { 30 namespace { 31 32 void ReadFileToStringOrDie(Env* env, const string& filename, string* output) { 33 TF_CHECK_OK(ReadFileToString(env, filename, output)); 34 } 35 36 std::unique_ptr<Session> CreateSession() { 37 SessionOptions options; 38 (*options.config.mutable_device_count())["CPU"] = 2; 39 return std::unique_ptr<Session>(NewSession(options)); 40 } 41 42 class ExtractExampleParserConfigurationTest : public ::testing::Test { 43 protected: 44 void SetUp() override { 45 string proto_string; 46 string filename = 47 io::JoinPath(testing::TensorFlowSrcRoot(), 48 "core/example/testdata/parse_example_graph_def.pbtxt"); 49 ReadFileToStringOrDie(Env::Default(), filename, &proto_string); 50 protobuf::TextFormat::ParseFromString(proto_string, &graph_def_); 51 session_ = CreateSession(); 52 TF_CHECK_OK(session_->Create(graph_def_)); 53 } 54 55 NodeDef* parse_example_node() { 56 for (auto& node : *graph_def_.mutable_node()) { 57 if (node.name() == "ParseExample/ParseExample") { 58 return &node; 59 } 60 } 61 return nullptr; 62 } 63 64 GraphDef graph_def_; 65 std::unique_ptr<Session> session_; 66 }; 67 68 TEST_F(ExtractExampleParserConfigurationTest, OpNotFound) { 69 std::vector<FixedLenFeature> dense_vec; 70 std::vector<VarLenFeature> sparse_vec; 71 Status status = ExtractExampleParserConfiguration( 72 graph_def_, "BlarseExample/ParseExample", session_.get(), &dense_vec, 73 &sparse_vec); 74 75 EXPECT_EQ(status.code(), error::INVALID_ARGUMENT); 76 } 77 78 TEST_F(ExtractExampleParserConfigurationTest, InconsistentAttrNsparse) { 79 std::vector<FixedLenFeature> dense_vec; 80 std::vector<VarLenFeature> sparse_vec; 81 82 NodeDef* node = parse_example_node(); 83 auto mutable_attr = node->mutable_attr(); 84 (*mutable_attr)["Nsparse"].set_i(3); 85 86 Status status = ExtractExampleParserConfiguration( 87 graph_def_, "ParseExample/ParseExample", session_.get(), &dense_vec, 88 &sparse_vec); 89 90 EXPECT_EQ(status.code(), error::INVALID_ARGUMENT); 91 } 92 93 TEST_F(ExtractExampleParserConfigurationTest, InconsistentAttrNdense) { 94 std::vector<FixedLenFeature> dense_vec; 95 std::vector<VarLenFeature> sparse_vec; 96 97 NodeDef* node = parse_example_node(); 98 auto mutable_attr = node->mutable_attr(); 99 (*mutable_attr)["Ndense"].set_i(2); 100 101 Status status = ExtractExampleParserConfiguration( 102 graph_def_, "ParseExample/ParseExample", session_.get(), &dense_vec, 103 &sparse_vec); 104 105 EXPECT_EQ(status.code(), error::INVALID_ARGUMENT); 106 } 107 108 TEST_F(ExtractExampleParserConfigurationTest, Basic) { 109 std::vector<FixedLenFeature> dense_vec; 110 std::vector<VarLenFeature> sparse_vec; 111 Status status = ExtractExampleParserConfiguration( 112 graph_def_, "ParseExample/ParseExample", session_.get(), &dense_vec, 113 &sparse_vec); 114 115 EXPECT_EQ(Status::OK(), status); 116 EXPECT_EQ(2, sparse_vec.size()); 117 EXPECT_EQ(3, dense_vec.size()); 118 119 EXPECT_EQ("sf0", sparse_vec[0].key); 120 EXPECT_EQ(DT_STRING, sparse_vec[0].dtype); 121 EXPECT_EQ("ParseExample/ParseExample:0", 122 sparse_vec[0].indices_output_tensor_name); 123 EXPECT_EQ("ParseExample/ParseExample:2", 124 sparse_vec[0].values_output_tensor_name); 125 EXPECT_EQ("ParseExample/ParseExample:4", 126 sparse_vec[0].shapes_output_tensor_name); 127 128 EXPECT_EQ("sf1", sparse_vec[1].key); 129 EXPECT_EQ(DT_STRING, sparse_vec[1].dtype); 130 EXPECT_EQ("ParseExample/ParseExample:1", 131 sparse_vec[1].indices_output_tensor_name); 132 EXPECT_EQ("ParseExample/ParseExample:3", 133 sparse_vec[1].values_output_tensor_name); 134 EXPECT_EQ("ParseExample/ParseExample:5", 135 sparse_vec[1].shapes_output_tensor_name); 136 137 EXPECT_EQ("x", dense_vec[0].key); 138 EXPECT_EQ(DT_FLOAT, dense_vec[0].dtype); 139 EXPECT_EQ("ParseExample/ParseExample:6", 140 dense_vec[0].values_output_tensor_name); 141 142 EXPECT_EQ("y", dense_vec[1].key); 143 EXPECT_EQ(DT_FLOAT, dense_vec[1].dtype); 144 EXPECT_EQ("ParseExample/ParseExample:7", 145 dense_vec[1].values_output_tensor_name); 146 147 EXPECT_EQ("z", dense_vec[2].key); 148 EXPECT_EQ(DT_FLOAT, dense_vec[2].dtype); 149 EXPECT_EQ("ParseExample/ParseExample:8", 150 dense_vec[2].values_output_tensor_name); 151 } 152 153 static const char kExampleParseConfigurationProto[] = R"( feature_map { 154 key: "x" 155 value { 156 fixed_len_feature { 157 dtype: DT_FLOAT 158 shape { 159 dim { 160 size: 1 161 } 162 } 163 default_value { 164 dtype: DT_FLOAT 165 tensor_shape { 166 dim { 167 size: 1 168 } 169 } 170 float_val: 33.0 171 } 172 values_output_tensor_name: "ParseExample/ParseExample:3" 173 } 174 } 175 } 176 feature_map { 177 key: "y" 178 value { 179 var_len_feature { 180 dtype: DT_STRING 181 values_output_tensor_name: "ParseExample/ParseExample:1" 182 indices_output_tensor_name: "ParseExample/ParseExample:0" 183 shapes_output_tensor_name: "ParseExample/ParseExample:2" 184 } 185 } 186 } 187 )"; 188 189 class ExampleParserConfigurationProtoToFeatureVectorsTest 190 : public ::testing::Test { 191 protected: 192 void SetUp() override { 193 CHECK(protobuf::TextFormat::ParseFromString(kExampleParseConfigurationProto, 194 &config_proto_)); 195 } 196 ExampleParserConfiguration config_proto_; 197 }; 198 199 TEST_F(ExampleParserConfigurationProtoToFeatureVectorsTest, Basic) { 200 std::vector<FixedLenFeature> fixed_len_features; 201 std::vector<VarLenFeature> var_len_features; 202 TF_ASSERT_OK(ExampleParserConfigurationProtoToFeatureVectors( 203 config_proto_, &fixed_len_features, &var_len_features)); 204 ASSERT_EQ(1, fixed_len_features.size()); 205 ASSERT_EQ(1, var_len_features.size()); 206 207 const FixedLenFeature& f = fixed_len_features[0]; 208 ASSERT_EQ(DT_FLOAT, f.dtype); 209 ASSERT_EQ("x", f.key); 210 ASSERT_EQ("ParseExample/ParseExample:3", f.values_output_tensor_name); 211 212 TensorShape expected_shape({1}); 213 ASSERT_EQ(expected_shape.dims(), f.shape.dims()); 214 ASSERT_EQ(1, f.shape.dim_size(0)); 215 216 Tensor expected_default(DT_FLOAT, TensorShape({1})); 217 test::FillIota<float>(&expected_default, 33.0); 218 test::ExpectTensorEqual<float>(expected_default, f.default_value); 219 220 const VarLenFeature& v = var_len_features[0]; 221 ASSERT_EQ(DT_STRING, v.dtype); 222 ASSERT_EQ("ParseExample/ParseExample:0", v.indices_output_tensor_name); 223 ASSERT_EQ("ParseExample/ParseExample:1", v.values_output_tensor_name); 224 ASSERT_EQ("ParseExample/ParseExample:2", v.shapes_output_tensor_name); 225 } 226 227 } // namespace 228 } // namespace tensorflow 229