Home | History | Annotate | Download | only in util
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Extract parse_example op configuration to a proto."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 from tensorflow.core.example import example_parser_configuration_pb2
     22 from tensorflow.python.framework import tensor_shape
     23 from tensorflow.python.framework import tensor_util
     24 
     25 
     26 def extract_example_parser_configuration(parse_example_op, sess):
     27   """Returns an ExampleParserConfig proto.
     28 
     29   Args:
     30     parse_example_op: A ParseExample `Operation`
     31     sess: A tf.Session needed to obtain some configuration values.
     32   Returns:
     33     A ExampleParserConfig proto.
     34 
     35   Raises:
     36     ValueError: If attributes are inconsistent.
     37   """
     38   config = example_parser_configuration_pb2.ExampleParserConfiguration()
     39 
     40   num_sparse = parse_example_op.get_attr("Nsparse")
     41   num_dense = parse_example_op.get_attr("Ndense")
     42   total_features = num_dense + num_sparse
     43 
     44   sparse_types = parse_example_op.get_attr("sparse_types")
     45   dense_types = parse_example_op.get_attr("Tdense")
     46   dense_shapes = parse_example_op.get_attr("dense_shapes")
     47 
     48   if len(sparse_types) != num_sparse:
     49     raise ValueError("len(sparse_types) attribute does not match "
     50                      "Nsparse attribute (%d vs %d)" %
     51                      (len(sparse_types), num_sparse))
     52 
     53   if len(dense_types) != num_dense:
     54     raise ValueError("len(dense_types) attribute does not match "
     55                      "Ndense attribute (%d vs %d)" %
     56                      (len(dense_types), num_dense))
     57 
     58   if len(dense_shapes) != num_dense:
     59     raise ValueError("len(dense_shapes) attribute does not match "
     60                      "Ndense attribute (%d vs %d)" %
     61                      (len(dense_shapes), num_dense))
     62 
     63   # Skip over the serialized input, and the names input.
     64   fetch_list = parse_example_op.inputs[2:]
     65 
     66   # Fetch total_features key names and num_dense default values.
     67   if len(fetch_list) != (total_features + num_dense):
     68     raise ValueError("len(fetch_list) does not match total features + "
     69                      "num_dense (%d vs %d)" %
     70                      (len(fetch_list), (total_features + num_dense)))
     71 
     72   fetched = sess.run(fetch_list)
     73 
     74   if len(fetched) != len(fetch_list):
     75     raise ValueError("len(fetched) does not match len(fetch_list) "
     76                      "(%d vs %d)" % (len(fetched), len(fetch_list)))
     77 
     78   # Fetch indices.
     79   sparse_keys_start = 0
     80   dense_keys_start = sparse_keys_start + num_sparse
     81   dense_def_start = dense_keys_start + num_dense
     82 
     83   # Output tensor indices.
     84   sparse_indices_start = 0
     85   sparse_values_start = num_sparse
     86   sparse_shapes_start = sparse_values_start + num_sparse
     87   dense_values_start = sparse_shapes_start + num_sparse
     88 
     89   # Dense features.
     90   for i in range(num_dense):
     91     key = fetched[dense_keys_start + i]
     92     feature_config = config.feature_map[key]
     93     # Convert the default value numpy array fetched from the session run
     94     # into a TensorProto.
     95     fixed_config = feature_config.fixed_len_feature
     96 
     97     fixed_config.default_value.CopyFrom(
     98         tensor_util.make_tensor_proto(fetched[dense_def_start + i]))
     99     # Convert the shape from the attributes
    100     # into a TensorShapeProto.
    101     fixed_config.shape.CopyFrom(
    102         tensor_shape.TensorShape(dense_shapes[i]).as_proto())
    103 
    104     fixed_config.dtype = int(dense_types[i])
    105     # Get the output tensor name.
    106     fixed_config.values_output_tensor_name = parse_example_op.outputs[
    107         dense_values_start + i].name
    108 
    109   # Sparse features.
    110   for i in range(num_sparse):
    111     key = fetched[sparse_keys_start + i]
    112     feature_config = config.feature_map[key]
    113     var_len_feature = feature_config.var_len_feature
    114     var_len_feature.dtype = int(sparse_types[i])
    115     var_len_feature.indices_output_tensor_name = parse_example_op.outputs[
    116         sparse_indices_start + i].name
    117     var_len_feature.values_output_tensor_name = parse_example_op.outputs[
    118         sparse_values_start + i].name
    119     var_len_feature.shapes_output_tensor_name = parse_example_op.outputs[
    120         sparse_shapes_start + i].name
    121 
    122   return config
    123