Home | History | Annotate | Download | only in example
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #include "tensorflow/core/example/example_parser_configuration.h"
     16 
     17 #include <vector>
     18 
     19 #include "tensorflow/core/example/feature.pb_text.h"
     20 #include "tensorflow/core/framework/attr_value.pb.h"
     21 #include "tensorflow/core/framework/node_def.pb.h"
     22 #include "tensorflow/core/framework/numeric_op.h"
     23 #include "tensorflow/core/framework/register_types.h"
     24 #include "tensorflow/core/framework/tensor.pb.h"
     25 #include "tensorflow/core/lib/core/errors.h"
     26 #include "tensorflow/core/lib/strings/strcat.h"
     27 #include "tensorflow/core/platform/logging.h"
     28 #include "tensorflow/core/platform/protobuf.h"
     29 
     30 namespace tensorflow {
     31 
     32 Status FindNodeIndexByName(const tensorflow::GraphDef& graph,
     33                            const string& node_name, int* node_idx) {
     34   for (int i = 0; i < graph.node_size(); ++i) {
     35     const auto& node = graph.node(i);
     36     if (node.name() == node_name) {
     37       *node_idx = i;
     38       return Status::OK();
     39     }
     40   }
     41   return errors::InvalidArgument(node_name, " not found in GraphDef");
     42 }
     43 
     44 Status ExtractExampleParserConfiguration(
     45     const tensorflow::GraphDef& graph, const string& node_name,
     46     tensorflow::Session* session,
     47     std::vector<FixedLenFeature>* fixed_len_features,
     48     std::vector<VarLenFeature>* var_len_features) {
     49   int node_idx;
     50   TF_RETURN_IF_ERROR(FindNodeIndexByName(graph, node_name, &node_idx));
     51 
     52   const auto& node = graph.node(node_idx);
     53   if (node.op() != "ParseExample") {
     54     return errors::InvalidArgument(node_name, " node is not a ParseExample op");
     55   }
     56 
     57   auto& attr_map = node.attr();
     58   auto num_sparse = attr_map.at("Nsparse").i();
     59   auto num_dense = attr_map.at("Ndense").i();
     60   fixed_len_features->resize(num_dense);
     61   var_len_features->resize(num_sparse);
     62 
     63   auto tdense = attr_map.at("Tdense");
     64   auto dense_shapes = attr_map.at("dense_shapes");
     65   auto sparse_types = attr_map.at("sparse_types");
     66 
     67   // Consistency check attributes.
     68   if (tdense.list().type_size() != num_dense) {
     69     return errors::InvalidArgument("Node attr Tdense has ",
     70                                    tdense.list().type_size(),
     71                                    " elements != Ndense attr: ", num_dense);
     72   }
     73 
     74   if (dense_shapes.list().shape_size() != num_dense) {
     75     return errors::InvalidArgument("Node attr dense_shapes has ",
     76                                    dense_shapes.list().shape_size(),
     77                                    " elements != Ndense attr: ", num_dense);
     78   }
     79 
     80   if (sparse_types.list().type_size() != num_sparse) {
     81     return errors::InvalidArgument("Node attr sparse_types has ",
     82                                    sparse_types.list().type_size(),
     83                                    " elements != NSparse attr: ", num_sparse);
     84   }
     85 
     86   for (int i = 0; i < tdense.list().type_size(); ++i) {
     87     (*fixed_len_features)[i].dtype = tdense.list().type(i);
     88     // Convert TensorShapeProto to TensorShape.
     89     (*fixed_len_features)[i].shape = TensorShape(dense_shapes.list().shape(i));
     90   }
     91 
     92   for (int i = 0; i < sparse_types.list().type_size(); ++i) {
     93     (*var_len_features)[i].dtype = sparse_types.list().type(i);
     94   }
     95 
     96   // We must fetch the configuration input tensors to the ParseExample op.
     97   // Skipping index = 0, which is the serialized proto input.
     98   std::vector<string> fetch_names(node.input_size() - 1);
     99   for (int i = 1; i < node.input_size(); ++i) {
    100     fetch_names[i - 1] = node.input(i);
    101   }
    102 
    103   std::vector<Tensor> op_input_tensors;
    104 
    105   TF_RETURN_IF_ERROR(session->Run({},               // no_inputs,
    106                                   fetch_names, {},  // no target_node_names,
    107                                   &op_input_tensors));
    108 
    109   // The input tensors are laid out sequentially in a flat manner.
    110   // Here are the various start offsets.
    111   int sparse_keys_start = 1;
    112   int dense_keys_start = sparse_keys_start + num_sparse;
    113   int dense_defaults_start = dense_keys_start + num_dense;
    114 
    115   for (int i = 0; i < num_sparse; ++i) {
    116     int input_idx = sparse_keys_start + i;
    117     (*var_len_features)[i].key = op_input_tensors[input_idx].scalar<string>()();
    118   }
    119 
    120   for (int i = 0; i < num_dense; ++i) {
    121     FixedLenFeature& config = (*fixed_len_features)[i];
    122     int dense_keys_offset = dense_keys_start + i;
    123     config.key = op_input_tensors[dense_keys_offset].scalar<string>()();
    124 
    125     int defaults_offset = dense_defaults_start + i;
    126     config.default_value = op_input_tensors[defaults_offset];
    127   }
    128 
    129   // The output tensors are laid out sequentially in a flat manner.
    130   // Here are the various start offsets.
    131   int sparse_indices_output_start = 0;
    132   int sparse_values_output_start = sparse_indices_output_start + num_sparse;
    133   int sparse_shapes_output_start = sparse_values_output_start + num_sparse;
    134   int dense_values_output_start = sparse_shapes_output_start + num_sparse;
    135 
    136   string node_output_prefix = strings::StrCat(node_name, ":");
    137 
    138   for (int i = 0; i < num_sparse; ++i) {
    139     VarLenFeature& config = (*var_len_features)[i];
    140 
    141     int indices_offset = sparse_indices_output_start + i;
    142     config.indices_output_tensor_name =
    143         strings::StrCat(node_output_prefix, indices_offset);
    144 
    145     int values_offset = sparse_values_output_start + i;
    146     config.values_output_tensor_name =
    147         strings::StrCat(node_output_prefix, values_offset);
    148 
    149     int shapes_offset = sparse_shapes_output_start + i;
    150     config.shapes_output_tensor_name =
    151         strings::StrCat(node_output_prefix, shapes_offset);
    152   }
    153 
    154   for (int i = 0; i < num_dense; ++i) {
    155     int output_idx = dense_values_output_start + i;
    156     (*fixed_len_features)[i].values_output_tensor_name =
    157         strings::StrCat(node_output_prefix, output_idx);
    158   }
    159   return Status::OK();
    160 }
    161 
    162 Status ExampleParserConfigurationProtoToFeatureVectors(
    163     const ExampleParserConfiguration& config_proto,
    164     std::vector<FixedLenFeature>* fixed_len_features,
    165     std::vector<VarLenFeature>* var_len_features) {
    166   const auto& feature_map = config_proto.feature_map();
    167   for (auto it = feature_map.cbegin(); it != feature_map.cend(); ++it) {
    168     string key = it->first;
    169     const auto& config = it->second;
    170     if (config.has_fixed_len_feature()) {
    171       const auto& fixed_config = config.fixed_len_feature();
    172       FixedLenFeature f;
    173       f.key = key;
    174       f.dtype = fixed_config.dtype();
    175       f.shape = TensorShape(fixed_config.shape());
    176       Tensor default_value(f.dtype, f.shape);
    177       if (!default_value.FromProto(fixed_config.default_value())) {
    178         return errors::InvalidArgument(
    179             "Invalid default_value in config proto ",
    180             fixed_config.default_value().DebugString());
    181       }
    182       f.default_value = default_value;
    183       f.values_output_tensor_name = fixed_config.values_output_tensor_name();
    184       fixed_len_features->push_back(f);
    185     } else {
    186       const auto& var_len_config = config.var_len_feature();
    187       VarLenFeature v;
    188       v.key = key;
    189       v.dtype = var_len_config.dtype();
    190       v.values_output_tensor_name = var_len_config.values_output_tensor_name();
    191       v.indices_output_tensor_name =
    192           var_len_config.indices_output_tensor_name();
    193       v.shapes_output_tensor_name = var_len_config.shapes_output_tensor_name();
    194       var_len_features->push_back(v);
    195     }
    196   }
    197   return Status::OK();
    198 }
    199 
    200 }  // namespace tensorflow
    201