/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/example/example_parser_configuration.h"

#include <vector>

#include "tensorflow/core/example/feature.pb_text.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"

namespace tensorflow {

// Finds the index in `graph` of the node named `node_name`. Returns
// InvalidArgument if no such node exists.
Status FindNodeIndexByName(const tensorflow::GraphDef& graph,
                           const string& node_name, int* node_idx) {
  for (int i = 0; i < graph.node_size(); ++i) {
    const auto& node = graph.node(i);
    if (node.name() == node_name) {
      *node_idx = i;
      return Status::OK();
    }
  }
  return errors::InvalidArgument(node_name, " not found in GraphDef");
}

// Extracts the parse configuration from a ParseExample node: validates the
// node's attributes, fetches its key and default-value input tensors via the
// given session, and records the output tensor names for each feature.
Status ExtractExampleParserConfiguration(
    const tensorflow::GraphDef& graph, const string& node_name,
    tensorflow::Session* session,
    std::vector<FixedLenFeature>* fixed_len_features,
    std::vector<VarLenFeature>* var_len_features) {
  int node_idx;
  TF_RETURN_IF_ERROR(FindNodeIndexByName(graph, node_name, &node_idx));

  const auto& node = graph.node(node_idx);
  if (node.op() != "ParseExample") {
    return errors::InvalidArgument(node_name, " node is not a ParseExample op");
  }

  auto& attr_map = node.attr();
  auto num_sparse = attr_map.at("Nsparse").i();
  auto num_dense = attr_map.at("Ndense").i();
  fixed_len_features->resize(num_dense);
  var_len_features->resize(num_sparse);

  auto tdense = attr_map.at("Tdense");
  auto dense_shapes = attr_map.at("dense_shapes");
  auto sparse_types = attr_map.at("sparse_types");

  // Consistency check attributes.
  if (tdense.list().type_size() != num_dense) {
    return errors::InvalidArgument("Node attr Tdense has ",
                                   tdense.list().type_size(),
                                   " elements != Ndense attr: ", num_dense);
  }

  if (dense_shapes.list().shape_size() != num_dense) {
    return errors::InvalidArgument("Node attr dense_shapes has ",
                                   dense_shapes.list().shape_size(),
                                   " elements != Ndense attr: ", num_dense);
  }

  if (sparse_types.list().type_size() != num_sparse) {
    return errors::InvalidArgument("Node attr sparse_types has ",
                                   sparse_types.list().type_size(),
                                   " elements != Nsparse attr: ", num_sparse);
  }

  for (int i = 0; i < tdense.list().type_size(); ++i) {
    (*fixed_len_features)[i].dtype = tdense.list().type(i);
    // Convert TensorShapeProto to TensorShape.
    (*fixed_len_features)[i].shape = TensorShape(dense_shapes.list().shape(i));
  }

  for (int i = 0; i < sparse_types.list().type_size(); ++i) {
    (*var_len_features)[i].dtype = sparse_types.list().type(i);
  }

  // We must fetch the configuration input tensors to the ParseExample op.
  // Skipping index = 0, which is the serialized proto input.
  std::vector<string> fetch_names(node.input_size() - 1);
  for (int i = 1; i < node.input_size(); ++i) {
    fetch_names[i - 1] = node.input(i);
  }

  std::vector<Tensor> op_input_tensors;

  TF_RETURN_IF_ERROR(session->Run({},  // no inputs
                                  fetch_names, {},  // no target nodes
                                  &op_input_tensors));

  // The input tensors are laid out sequentially in a flat manner.
  // Here are the various start offsets.
  int sparse_keys_start = 1;
  int dense_keys_start = sparse_keys_start + num_sparse;
  int dense_defaults_start = dense_keys_start + num_dense;

  for (int i = 0; i < num_sparse; ++i) {
    int input_idx = sparse_keys_start + i;
    (*var_len_features)[i].key = op_input_tensors[input_idx].scalar<string>()();
  }

  for (int i = 0; i < num_dense; ++i) {
    FixedLenFeature& config = (*fixed_len_features)[i];
    int dense_keys_offset = dense_keys_start + i;
    config.key = op_input_tensors[dense_keys_offset].scalar<string>()();

    int defaults_offset = dense_defaults_start + i;
    config.default_value = op_input_tensors[defaults_offset];
  }

  // The output tensors are laid out sequentially in a flat manner.
  // Here are the various start offsets.
  int sparse_indices_output_start = 0;
  int sparse_values_output_start = sparse_indices_output_start + num_sparse;
  int sparse_shapes_output_start = sparse_values_output_start + num_sparse;
  int dense_values_output_start = sparse_shapes_output_start + num_sparse;

  string node_output_prefix = strings::StrCat(node_name, ":");

  for (int i = 0; i < num_sparse; ++i) {
    VarLenFeature& config = (*var_len_features)[i];

    int indices_offset = sparse_indices_output_start + i;
    config.indices_output_tensor_name =
        strings::StrCat(node_output_prefix, indices_offset);

    int values_offset = sparse_values_output_start + i;
    config.values_output_tensor_name =
        strings::StrCat(node_output_prefix, values_offset);

    int shapes_offset = sparse_shapes_output_start + i;
    config.shapes_output_tensor_name =
        strings::StrCat(node_output_prefix, shapes_offset);
  }

  for (int i = 0; i < num_dense; ++i) {
    int output_idx = dense_values_output_start + i;
    (*fixed_len_features)[i].values_output_tensor_name =
        strings::StrCat(node_output_prefix, output_idx);
  }
  return Status::OK();
}

// Converts an ExampleParserConfiguration proto into the equivalent
// FixedLenFeature and VarLenFeature vectors.
Status ExampleParserConfigurationProtoToFeatureVectors(
    const ExampleParserConfiguration& config_proto,
    std::vector<FixedLenFeature>* fixed_len_features,
    std::vector<VarLenFeature>* var_len_features) {
  const auto& feature_map = config_proto.feature_map();
  for (auto it = feature_map.cbegin(); it != feature_map.cend(); ++it) {
    string key = it->first;
    const auto& config = it->second;
    if (config.has_fixed_len_feature()) {
      const auto& fixed_config = config.fixed_len_feature();
      FixedLenFeature f;
      f.key = key;
      f.dtype = fixed_config.dtype();
      f.shape = TensorShape(fixed_config.shape());
      Tensor default_value(f.dtype, f.shape);
      if (!default_value.FromProto(fixed_config.default_value())) {
        return errors::InvalidArgument(
            "Invalid default_value in config proto ",
            fixed_config.default_value().DebugString());
      }
      f.default_value = default_value;
      f.values_output_tensor_name = fixed_config.values_output_tensor_name();
      fixed_len_features->push_back(f);
    } else {
      const auto& var_len_config = config.var_len_feature();
      VarLenFeature v;
      v.key = key;
      v.dtype = var_len_config.dtype();
      v.values_output_tensor_name = var_len_config.values_output_tensor_name();
      v.indices_output_tensor_name =
          var_len_config.indices_output_tensor_name();
      v.shapes_output_tensor_name = var_len_config.shapes_output_tensor_name();
      var_len_features->push_back(v);
    }
  }
  return Status::OK();
}

}  // namespace tensorflow