Home | History | Annotate | Download | only in ops
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 #==============================================================================
     15 """Lookup operations."""
     16 # pylint: disable=g-bad-name
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import collections
     22 import functools
     23 import six
     24 
     25 from tensorflow.python.compat import compat as fwd_compat
     26 from tensorflow.python.eager import context
     27 from tensorflow.python.framework import constant_op
     28 from tensorflow.python.framework import dtypes
     29 from tensorflow.python.framework import ops
     30 from tensorflow.python.framework import sparse_tensor
     31 from tensorflow.python.framework import tensor_shape
     32 from tensorflow.python.framework import tensor_util
     33 from tensorflow.python.ops import array_ops
     34 from tensorflow.python.ops import control_flow_ops
     35 from tensorflow.python.ops import gen_lookup_ops
     36 from tensorflow.python.ops import math_ops
     37 from tensorflow.python.ops import string_ops
     38 # go/tf-wildcard-import
     39 # pylint: disable=wildcard-import
     40 from tensorflow.python.ops.gen_lookup_ops import *
     41 from tensorflow.python.training.saver import BaseSaverBuilder
     42 # pylint: enable=wildcard-import
     43 from tensorflow.python.training.tracking import base as trackable_base
     44 from tensorflow.python.training.tracking import tracking as trackable
     45 from tensorflow.python.util import compat
     46 from tensorflow.python.util.deprecation import deprecated
     47 from tensorflow.python.util.tf_export import tf_export
     48 
     49 
     50 @tf_export(v1=["initialize_all_tables"])
     51 @deprecated(None, "Use `tf.tables_initializer` instead.")
     52 def initialize_all_tables(name="init_all_tables"):
     53   """Returns an Op that initializes all tables of the default graph.
     54 
     55   Args:
     56     name: Optional name for the initialization op.
     57 
     58   Returns:
     59     An Op that initializes all tables.  Note that if there are
     60     not tables the returned Op is a NoOp.
     61   """
     62   return tables_initializer(name)
     63 
     64 
     65 @tf_export(v1=["initializers.tables_initializer", "tables_initializer"])
     66 def tables_initializer(name="init_all_tables"):
     67   """Returns an Op that initializes all tables of the default graph.
     68 
     69   See the [Low Level Intro](https://www.tensorflow.org/guide/low_level_intro#feature_columns)
     70   guide, for an example of usage.
     71 
     72   Args:
     73     name: Optional name for the initialization op.
     74 
     75   Returns:
     76     An Op that initializes all tables.  Note that if there are
     77     not tables the returned Op is a NoOp.
     78   """
     79   initializers = ops.get_collection(ops.GraphKeys.TABLE_INITIALIZERS)
     80   if initializers:
     81     return control_flow_ops.group(*initializers, name=name)
     82   return control_flow_ops.no_op(name=name)
     83 
     84 
     85 def _check_table_dtypes(table, key_dtype, value_dtype):
     86   """Check that the given key_dtype and value_dtype matches the table dtypes.
     87 
     88   Args:
     89     table: The table to check types against to.
     90     key_dtype: The key data type to check.
     91     value_dtype: The value data type to check.
     92 
     93   Raises:
     94     TypeError: when 'key_dtype' or 'value_dtype' doesn't match the table data
     95       types.
     96   """
     97   if key_dtype.base_dtype != table.key_dtype:
     98     raise TypeError("Invalid key dtype, expected %s but got %s." %
     99                     (table.key_dtype, key_dtype))
    100   if value_dtype.base_dtype != table.value_dtype:
    101     raise TypeError("Invalid value dtype, expected %s but got %s." %
    102                     (table.value_dtype, value_dtype))
    103 
    104 
    105 class LookupInterface(trackable.TrackableResource):
    106   """Represent a lookup table that persists across different steps."""
    107 
    108   def __init__(self, key_dtype, value_dtype):
    109     """Construct a lookup table interface.
    110 
    111     Args:
    112       key_dtype: The table key type.
    113       value_dtype: The table value type.
    114     """
    115     self._key_dtype = dtypes.as_dtype(key_dtype)
    116     self._value_dtype = dtypes.as_dtype(value_dtype)
    117     super(LookupInterface, self).__init__()
    118 
    119   def _create_resource(self):
    120     raise NotImplementedError
    121 
    122   @property
    123   def key_dtype(self):
    124     """The table key dtype."""
    125     return self._key_dtype
    126 
    127   @property
    128   def value_dtype(self):
    129     """The table value dtype."""
    130     return self._value_dtype
    131 
    132   @property
    133   def name(self):
    134     """The name of the table."""
    135     return NotImplementedError
    136 
    137   def size(self, name=None):
    138     """Compute the number of elements in this table."""
    139     raise NotImplementedError
    140 
    141   def lookup(self, keys, name=None):
    142     """Looks up `keys` in a table, outputs the corresponding values."""
    143     raise NotImplementedError
    144 
    145 
class InitializableLookupTableBase(LookupInterface):
  """Initializable lookup table interface.

  Initializable lookup tables persist across different steps.
  """

  def __init__(self, default_value, initializer):
    """Construct a table object from a table reference.

    It requires a table initializer object (subclass of
    `TableInitializerBase`). It provides the table key and value types, as
    well as the op to initialize the table. The caller is responsible for
    executing the initialization op.

    Args:
      default_value: The value to use if a key is missing in the table.
      initializer: The table initializer to use.
    """
    super(InitializableLookupTableBase, self).__init__(initializer.key_dtype,
                                                       initializer.value_dtype)
    self._default_value = ops.convert_to_tensor(
        default_value, dtype=self._value_dtype)
    # The default value must be a scalar; merge_with raises otherwise.
    self._default_value.get_shape().merge_with(tensor_shape.scalar())
    if isinstance(initializer, trackable_base.Trackable):
      # Register the initializer as a dependency so it is saved/restored
      # along with this table.
      self._initializer = self._track_trackable(
          initializer, "_initializer")
    # Create the resource and its init op at the outermost graph/eager
    # context so the table outlives any function-building graph.
    with ops.init_scope():
      self._resource_handle = self._create_resource()
      self._init_op = self._initialize()

  def _initialize(self):
    # Delegates to the initializer, which builds (and, in the implementations
    # in this file, also collects) the table-initialization op.
    return self._initializer.initialize(self)

  @property
  def default_value(self):
    """The default value of the table."""
    return self._default_value

  def size(self, name=None):
    """Compute the number of elements in this table.

    Args:
      name: A name for the operation (optional).

    Returns:
      A scalar tensor containing the number of elements in this table.
    """
    with ops.name_scope(name, "%s_Size" % self.name, [self.resource_handle]):
      return gen_lookup_ops.lookup_table_size_v2(self.resource_handle)

  def lookup(self, keys, name=None):
    """Looks up `keys` in a table, outputs the corresponding values.

    The `default_value` is used for keys not present in the table.

    Args:
      keys: Keys to look up. May be either a `SparseTensor` or dense `Tensor`.
      name: A name for the operation (optional).

    Returns:
      A `SparseTensor` if keys are sparse, otherwise a dense `Tensor`.

    Raises:
      TypeError: when `keys` or `default_value` doesn't match the table data
        types.
    """
    # Sparse inputs are looked up through their flat `.values` tensor; the
    # indices and dense_shape are carried over unchanged below.
    key_tensor = keys
    if isinstance(keys, sparse_tensor.SparseTensor):
      key_tensor = keys.values

    if keys.dtype.base_dtype != self._key_dtype:
      raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
                      (self._key_dtype, keys.dtype))

    with ops.name_scope(
        name, "%s_Lookup" % self.name,
        (self.resource_handle, key_tensor, self._default_value)):
      values = gen_lookup_ops.lookup_table_find_v2(
          self.resource_handle, key_tensor, self._default_value)

    # The result has the same shape as the (dense) key tensor.
    values.set_shape(key_tensor.get_shape())
    if isinstance(keys, sparse_tensor.SparseTensor):
      return sparse_tensor.SparseTensor(keys.indices, values, keys.dense_shape)
    else:
      return values
    230 
    231 
class InitializableLookupTableBaseV1(InitializableLookupTableBase):
  """V1-compatible base that exposes the table's `initializer` op."""

  @property
  def initializer(self):
    # The init op created by InitializableLookupTableBase.__init__.
    return self._init_op
    237 
    238 
    239 @tf_export("lookup.StaticHashTable", v1=[])
    240 class StaticHashTable(InitializableLookupTableBase):
    241   """A generic hash table implementation.
    242 
    243   Example usage:
    244 
    245   ```python
    246   table = tf.lookup.StaticHashTable(
    247       tf.KeyValueTensorInitializer(keys, values), -1)
    248   out = table.lookup(input_tensor)
    249   table.init.run()
    250   print(out.eval())
    251   ```
    252   """
    253 
    254   def __init__(self, initializer, default_value, name=None):
    255     """Creates a non-initialized `HashTable` object.
    256 
    257     Creates a table, the type of its keys and values are specified by the
    258     initializer.
    259     Before using the table you will have to initialize it. After initialization
    260     the table will be immutable.
    261 
    262     Args:
    263       initializer: The table initializer to use. See `HashTable` kernel for
    264         supported key and value types.
    265       default_value: The value to use if a key is missing in the table.
    266       name: A name for the operation (optional).
    267 
    268     Returns:
    269       A `HashTable` object.
    270     """
    271     self._initializer = initializer
    272     self._default_value = default_value
    273     self._shared_name = self._initializer._shared_name  # pylint: disable=protected-access
    274     self._name = name or "hash_table"
    275     self._table_name = None
    276     super(StaticHashTable, self).__init__(default_value, initializer)
    277     self._value_shape = self._default_value.get_shape()
    278 
    279   def _create_resource(self):
    280     table_ref = gen_lookup_ops.hash_table_v2(
    281         shared_name=self._shared_name,
    282         key_dtype=self._initializer.key_dtype,
    283         value_dtype=self._initializer.value_dtype,
    284         name=self._name)
    285     if context.executing_eagerly():
    286       self._table_name = None
    287     else:
    288       self._table_name = table_ref.op.name.split("/")[-1]
    289     return table_ref
    290 
    291   @property
    292   def name(self):
    293     return self._table_name
    294 
    295   def export(self, name=None):
    296     """Returns tensors of all keys and values in the table.
    297 
    298     Args:
    299       name: A name for the operation (optional).
    300 
    301     Returns:
    302       A pair of tensors with the first tensor containing all keys and the
    303         second tensors containing all values in the table.
    304     """
    305     with ops.name_scope(name, "%s_Export" % self.name, [self.resource_handle]):
    306       exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2(
    307           self.resource_handle, self._key_dtype, self._value_dtype)
    308 
    309     exported_values.set_shape(exported_keys.get_shape().concatenate(
    310         self._value_shape))
    311     return exported_keys, exported_values
    312 
    313 
@tf_export(v1=["lookup.StaticHashTable"])
class StaticHashTableV1(StaticHashTable):
  """V1 export of `StaticHashTable` that exposes its `initializer` op."""

  @property
  def initializer(self):
    # Op that populates the table; run it before lookups in graph mode.
    return self._init_op
    320 
    321 
# For backwards compatibility. This will be removed in TF 2.0.
class HashTable(StaticHashTableV1):
  """Legacy alias of `StaticHashTableV1` keeping the old `init` property."""

  @property
  def init(self):
    # Deprecated spelling of `initializer`.
    return self.initializer
    328 
    329 
    330 class TableInitializerBase(trackable_base.Trackable):
    331   """Base class for lookup table initializers."""
    332 
    333   def __init__(self, key_dtype, value_dtype):
    334     """Construct a table initializer object.
    335 
    336     Args:
    337       key_dtype: Type of the table keys.
    338       value_dtype: Type of the table values.
    339     """
    340     self._key_dtype = dtypes.as_dtype(key_dtype)
    341     self._value_dtype = dtypes.as_dtype(value_dtype)
    342 
    343   @property
    344   def key_dtype(self):
    345     """The expected table key dtype."""
    346     return self._key_dtype
    347 
    348   @property
    349   def value_dtype(self):
    350     """The expected table value dtype."""
    351     return self._value_dtype
    352 
    353   def initialize(self, table):
    354     """Returns the table initialization op."""
    355     raise NotImplementedError
    356 
    357   @property
    358   def _shared_name(self):
    359     """Returns a shared name to be used by the table."""
    360     shared_name = ""
    361     if context.executing_eagerly():
    362       # Ensure a unique name when eager execution is enabled to avoid spurious
    363       # sharing issues.
    364       # TODO(rohanj): Use context.shared_name() instead.
    365       shared_name += str(ops.uid())
    366     return shared_name
    367 
    368 
    369 @tf_export("lookup.KeyValueTensorInitializer")
    370 class KeyValueTensorInitializer(TableInitializerBase):
    371   """Table initializers given `keys` and `values` tensors."""
    372 
    373   def __init__(self, keys, values, key_dtype=None, value_dtype=None, name=None):
    374     """Constructs a table initializer object based on keys and values tensors.
    375 
    376     Args:
    377       keys: The tensor for the keys.
    378       values: The tensor for the values.
    379       key_dtype: The `keys` data type. Used when `keys` is a python array.
    380       value_dtype: The `values` data type. Used when `values` is a python array.
    381       name: A name for the operation (optional).
    382     """
    383     with ops.init_scope():
    384       self._keys = ops.convert_to_tensor(keys, dtype=key_dtype, name="keys")
    385       self._values = ops.convert_to_tensor(
    386           values, dtype=value_dtype, name="values")
    387     self._name = name if name is not None else "key_value_init"
    388     if context.executing_eagerly():
    389       # Ensure a unique name when eager execution is enabled to avoid spurious
    390       # sharing issues.
    391       # TODO(rohanj): Use context.shared_name() instead.
    392       self._name += str(ops.uid())
    393 
    394     super(KeyValueTensorInitializer, self).__init__(self._keys.dtype,
    395                                                     self._values.dtype)
    396 
    397   def initialize(self, table):
    398     """Initializes the given `table` with `keys` and `values` tensors.
    399 
    400     Args:
    401       table: The table to initialize.
    402 
    403     Returns:
    404       The operation that initializes the table.
    405 
    406     Raises:
    407       TypeError: when the keys and values data types do not match the table
    408       key and value data types.
    409     """
    410     _check_table_dtypes(table, self._keys.dtype, self._values.dtype)
    411     with ops.name_scope(
    412         self._name, values=(table.resource_handle, self._keys, self._values)):
    413       if fwd_compat.forward_compatible(2018, 9, 19):
    414         init_op = gen_lookup_ops.lookup_table_import_v2(
    415             table.resource_handle, self._keys, self._values)
    416       else:
    417         # To maintain forward compatibiltiy, use the old implementation.
    418         init_op = gen_lookup_ops.initialize_table_v2(table.resource_handle,
    419                                                      self._keys, self._values)
    420     ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
    421     return init_op
    422 
    423 
class TextFileIndex(object):
  """Sentinel column indices for `TextFileInitializer`.

  WHOLE_LINE (-2): use the entire line content; expects dtype string.
  LINE_NUMBER (-1): use the zero-based line number; expects dtype int64.
  Non-negative values select a `delimiter`-split column instead.
  """
  WHOLE_LINE = -2
  LINE_NUMBER = -1
    427 
    428 
@tf_export("lookup.TextFileInitializer")
class TextFileInitializer(TableInitializerBase):
  """Table initializers from a text file.

  This initializer assigns one entry in the table for each line in the file.

  The key and value type of the table to initialize is given by `key_dtype` and
  `value_dtype`.

  The key and value content to get from each line is specified by
  the `key_index` and `value_index`.

  * `TextFileIndex.LINE_NUMBER` means use the line number starting from zero,
    expects data type int64.
  * `TextFileIndex.WHOLE_LINE` means use the whole line content, expects data
    type string.
  * A value `>=0` means use the index (starting at zero) of the split line based
      on `delimiter`.

  For example if we have a file with the following content:

  ```
  emerson 10
  lake 20
  palmer 30
  ```

  The following snippet initializes a table with the first column as keys and
  second column as values:

  * `emerson -> 10`
  * `lake -> 20`
  * `palmer -> 30`

  ```python
  table = tf.lookup.StaticHashTable(tf.lookup.TextFileInitializer(
      "test.txt", tf.string, 0, tf.int64, 1, delimiter=" "), -1)
  ...
  # In graph mode, initialize the table before use, e.g. by running the ops
  # returned by tf.tables_initializer().
  ```

  Similarly to initialize the whole line as keys and the line number as values.

  * `emerson 10 -> 0`
  * `lake 20 -> 1`
  * `palmer 30 -> 2`

  ```python
  table = tf.lookup.StaticHashTable(tf.lookup.TextFileInitializer(
      "test.txt", tf.string, tf.lookup.TextFileIndex.WHOLE_LINE,
      tf.int64, tf.lookup.TextFileIndex.LINE_NUMBER, delimiter=" "), -1)
  ...
  # In graph mode, initialize the table before use, e.g. by running the ops
  # returned by tf.tables_initializer().
  ```
  """

  def __init__(self,
               filename,
               key_dtype,
               key_index,
               value_dtype,
               value_index,
               vocab_size=None,
               delimiter="\t",
               name=None):
    """Constructs a table initializer object to populate from a text file.

    It generates one key-value pair per line. The type of table key and
    value are specified by `key_dtype` and `value_dtype`, respectively.
    Similarly the content of the key and value are specified by the key_index
    and value_index.

    - TextFileIndex.LINE_NUMBER means use the line number starting from zero,
      expects data type int64.
    - TextFileIndex.WHOLE_LINE means use the whole line content, expects data
      type string.
    - A value >=0 means use the index (starting at zero) of the split line based
      on `delimiter`.

    Args:
      filename: The filename of the text file to be used for initialization.
        The path must be accessible from wherever the graph is initialized
        (eg. trainer or eval workers). The filename may be a scalar `Tensor`.
      key_dtype: The `key` data type.
      key_index: the index that represents information of a line to get the
        table 'key' values from.
      value_dtype: The `value` data type.
      value_index: the index that represents information of a line to get the
        table 'value' values from.
      vocab_size: The number of elements in the file, if known.
      delimiter: The delimiter to separate fields in a line.
      name: A name for the operation (optional).

    Raises:
      ValueError: when the filename is empty, or when the table key and value
      data types do not match the expected data types.
    """
    if not isinstance(filename, ops.Tensor) and not filename:
      raise ValueError("Filename required for %s." % name)

    self._filename_arg = filename
    key_dtype = dtypes.as_dtype(key_dtype)
    value_dtype = dtypes.as_dtype(value_dtype)

    # Indices below -2 are outside the sentinel range defined by
    # TextFileIndex (WHOLE_LINE == -2, LINE_NUMBER == -1).
    if key_index < -2:
      raise ValueError("Invalid key index %s." % (key_index))

    if key_index == TextFileIndex.LINE_NUMBER and key_dtype != dtypes.int64:
      raise ValueError("Signature mismatch. Keys must be dtype %s, got %s." %
                       (dtypes.int64, key_dtype))
    if ((key_index == TextFileIndex.WHOLE_LINE) and
        (not key_dtype.is_integer) and (key_dtype != dtypes.string)):
      raise ValueError(
          "Signature mismatch. Keys must be integer or string, got %s." %
          key_dtype)
    if value_index < -2:
      raise ValueError("Invalid value index %s." % (value_index))

    if value_index == TextFileIndex.LINE_NUMBER and value_dtype != dtypes.int64:
      raise ValueError("Signature mismatch. Values must be dtype %s, got %s." %
                       (dtypes.int64, value_dtype))
    if value_index == TextFileIndex.WHOLE_LINE and value_dtype != dtypes.string:
      raise ValueError("Signature mismatch. Values must be dtype %s, got %s." %
                       (dtypes.string, value_dtype))

    if (vocab_size is not None) and (vocab_size <= 0):
      raise ValueError("Invalid vocab_size %s." % vocab_size)

    self._key_index = key_index
    self._value_index = value_index
    self._vocab_size = vocab_size
    self._delimiter = delimiter
    self._name = name
    # Track the file as an asset so it is packaged with saved models.
    self._filename = self._track_trackable(
        trackable.TrackableAsset(filename),
        "_filename")

    super(TextFileInitializer, self).__init__(key_dtype, value_dtype)

  def initialize(self, table):
    """Initializes the table from a text file.

    Args:
      table: The table to be initialized.

    Returns:
      The operation that initializes the table.

    Raises:
      TypeError: when the keys and values data types do not match the table
      key and value data types.
    """
    _check_table_dtypes(table, self.key_dtype, self.value_dtype)
    with ops.name_scope(self._name, "text_file_init", (table.resource_handle,)):
      filename = ops.convert_to_tensor(
          self._filename, dtypes.string, name="asset_filepath")
      # A vocab_size of -1 tells the kernel the size is unspecified.
      init_op = gen_lookup_ops.initialize_table_from_text_file_v2(
          table.resource_handle, filename, self._key_index, self._value_index,
          -1 if self._vocab_size is None else self._vocab_size, self._delimiter)
    ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
    # If the filename tensor is anything other than a string constant (e.g.,
    # if it is a placeholder) then it does not make sense to track it as an
    # asset.
    if not context.executing_eagerly() and constant_op.is_constant(filename):
      ops.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, filename)
    return init_op

  @property
  def _shared_name(self):
    if self._vocab_size:
      # Keep the shared_name:
      # <table_type>_<filename>_<vocab_size>_<key_index>_<value_index>
      shared_name = "hash_table_%s_%d_%s_%s" % (
          self._filename_arg, self._vocab_size, self._key_index,
          self._value_index)
    else:
      # Keep the shared_name
      # <table_type>_<filename>_<key_index>_<value_index>
      shared_name = "hash_table_%s_%s_%s" % (self._filename_arg,
                                             self._key_index, self._value_index)
    return shared_name
    610 
    611 
    612 class TextFileStringTableInitializer(TextFileInitializer):
    613   """Table initializer for `int64` IDs to string tables from a text file."""
    614 
    615   def __init__(self,
    616                filename,
    617                key_column_index=TextFileIndex.LINE_NUMBER,
    618                value_column_index=TextFileIndex.WHOLE_LINE,
    619                vocab_size=None,
    620                delimiter="\t",
    621                name="text_file_string_table_init"):
    622     """Constructs an initializer for an id-to-string table from a text file.
    623 
    624     It populates a table that its key and value types are int64 and string,
    625     respectively. It generates one key-value pair per line.
    626     The content of the key and value are specified by `key_column_index`
    627     and `value_column_index`.
    628 
    629     - TextFileIndex.LINE_NUMBER means use the line number starting from zero,
    630       expects data type int64.
    631     - TextFileIndex.WHOLE_LINE means use the whole line content, expects data
    632       type string.
    633     - A value >=0 means use the index (starting at zero) of the split line based
    634       on `delimiter`.
    635 
    636     Args:
    637       filename: The filename of the text file to be used for initialization.
    638         The path must be accessible from wherever the graph is initialized
    639         (eg. trainer or eval workers). The filename may be a scalar `Tensor`.
    640       key_column_index: The column index from the text file to get the keys
    641         from. The default is to use the line number, starting from zero.
    642       value_column_index: The column index from the text file to get the
    643         values from. The default is to use the whole line content.
    644       vocab_size: The number of elements in the file, if known.
    645       delimiter: The delimiter to separate fields in a line.
    646       name: Optional name for the op.
    647 
    648     Raises:
    649       TypeError: when the filename is empty, or when the table key and value
    650       data types do not match the expected data types.
    651     """
    652     super(TextFileStringTableInitializer, self).__init__(
    653         filename,
    654         dtypes.int64,
    655         key_column_index,
    656         dtypes.string,
    657         value_column_index,
    658         vocab_size=vocab_size,
    659         delimiter=delimiter,
    660         name=name)
    661 
    662 
    663 class TextFileIdTableInitializer(TextFileInitializer):
    664   """Table initializer for string to `int64` IDs tables from a text file."""
    665 
    666   def __init__(self,
    667                filename,
    668                key_column_index=TextFileIndex.WHOLE_LINE,
    669                value_column_index=TextFileIndex.LINE_NUMBER,
    670                vocab_size=None,
    671                delimiter="\t",
    672                name="text_file_id_table_init",
    673                key_dtype=dtypes.string):
    674     """Constructs an initializer for an string-to-id table from a text file.
    675 
    676     It populates a table that its key and value types are string and int64,
    677     respectively. It generates one key-value pair per line.
    678     The content of the key and value are specified by the key_index
    679     and value_index.
    680 
    681     - TextFileIndex.LINE_NUMBER means use the line number starting from zero,
    682       expects data type int64.
    683     - TextFileIndex.WHOLE_LINE means use the whole line content, expects data
    684       type string.
    685     - A value >=0 means use the index (starting at zero) of the split line based
    686       on `delimiter`.
    687 
    688     Args:
    689       filename: The filename of the text file to be used for initialization.
    690         The path must be accessible from wherever the graph is initialized
    691         (eg. trainer or eval workers). The filename may be a scalar `Tensor`.
    692       key_column_index: The column index from the text file to get the `key`
    693         values from. The default is to use the whole line content.
    694       value_column_index: The column index from the text file to get the `value`
    695         values from. The default is to use the line number, starting from zero.
    696       vocab_size: The number of elements in the file, if known.
    697       delimiter: The delimiter to separate fields in a line.
    698       name: Optional name for the op.
    699       key_dtype: The `key` data type.
    700 
    701     Raises:
    702       TypeError: when the filename is empty, or when the table key and value
    703       data types do not match the expected data types.
    704     """
    705     super(TextFileIdTableInitializer, self).__init__(
    706         filename,
    707         key_dtype,
    708         key_column_index,
    709         dtypes.int64,
    710         value_column_index,
    711         vocab_size=vocab_size,
    712         delimiter=delimiter,
    713         name=name)
    714 
    715 
class HasherSpec(collections.namedtuple("HasherSpec", ["hasher", "key"])):
  """A structure for the spec of the hashing function to use for hash buckets.

  `hasher` is the name of the hashing function to use (eg. "fasthash",
  "stronghash").
  `key` is optional and specify the key to use for the hash function if
  supported, currently only used by a strong hash.

  Fields:
    hasher: The hasher name to use.
    key: The key to be used by the hashing function, if required.
  """
  # Empty __slots__ keeps instances as lightweight as the underlying tuple.
  __slots__ = ()
    729 
    730 
# Default hashing spec: the "fasthash" hasher, which takes no key.
FastHashSpec = HasherSpec("fasthash", None)  # pylint: disable=invalid-name
    732 
    733 
    734 class StrongHashSpec(HasherSpec):
    735   """A structure to specify a key of the strong keyed hash spec.
    736 
    737   The strong hash requires a `key`, which is a list of 2 unsigned integer
    738   numbers. These should be non-zero; random numbers generated from random.org
    739   would be a fine choice.
    740 
    741   Fields:
    742     key: The key to be used by the keyed hashing function.
    743   """
    744   __slots__ = ()
    745 
    746   def __new__(cls, key):
    747     if len(key) != 2:
    748       raise ValueError("key must have size 2, got %s." % len(key))
    749 
    750     if not isinstance(key[0], compat.integral_types) or not isinstance(
    751         key[1], compat.integral_types):
    752       raise TypeError("Invalid key %s. Must be unsigned integer values." % key)
    753 
    754     return super(cls, StrongHashSpec).__new__(cls, "stronghash", key)
    755 
    756 
    757 def _as_string(tensor):
    758   if dtypes.string == tensor.dtype.base_dtype:
    759     return tensor
    760   return string_ops.as_string(tensor)
    761 
    762 
    763 class IdTableWithHashBuckets(LookupInterface):
    764   """String to Id table wrapper that assigns out-of-vocabulary keys to buckets.
    765 
    766   For example, if an instance of `IdTableWithHashBuckets` is initialized with a
    767   string-to-id table that maps:
    768 
    769   * `emerson -> 0`
    770   * `lake -> 1`
    771   * `palmer -> 2`
    772 
    773   The `IdTableWithHashBuckets` object will performs the following mapping:
    774 
    775   * `emerson -> 0`
    776   * `lake -> 1`
    777   * `palmer -> 2`
    778   * `<other term> -> bucket_id`, where bucket_id will be between `3` and
    779   `3 + num_oov_buckets - 1`, calculated by:
    780   `hash(<term>) % num_oov_buckets + vocab_size`
    781 
    782   If input_tensor is `["emerson", "lake", "palmer", "king", "crimson"]`,
    783   the lookup result is `[0, 1, 2, 4, 7]`.
    784 
    785   If `table` is None, only out-of-vocabulary buckets are used.
    786 
    787   Example usage:
    788 
    789   ```python
    790   num_oov_buckets = 3
    791   input_tensor = tf.constant(["emerson", "lake", "palmer", "king", "crimnson"])
    792   table = tf.IdTableWithHashBuckets(
    793       tf.StaticHashTable(tf.TextFileIdTableInitializer(filename),
    794                          default_value),
    795       num_oov_buckets)
    796   out = table.lookup(input_tensor).
    797   table.init.run()
    798   print(out.eval())
    799   ```
    800 
    801   The hash function used for generating out-of-vocabulary buckets ID is handled
    802   by `hasher_spec`.
    803   """
    804 
    805   def __init__(self,
    806                table,
    807                num_oov_buckets,
    808                hasher_spec=FastHashSpec,
    809                name=None,
    810                key_dtype=None):
    811     """Construct a `IdTableWithHashBuckets` object.
    812 
    813     Args:
    814       table: Table that maps `tf.string` or `tf.int64` keys to `tf.int64` ids.
    815       num_oov_buckets: Number of buckets to use for out-of-vocabulary keys.
    816       hasher_spec: A `HasherSpec` to specify the hash function to use for
    817         assignation of out-of-vocabulary buckets  (optional).
    818       name: A name for the operation (optional).
    819       key_dtype: Data type of keys passed to `lookup`. Defaults to
    820         `table.key_dtype` if `table` is specified, otherwise `tf.string`.
    821         Must be string or integer, and must be castable to `table.key_dtype`.
    822 
    823     Raises:
    824       ValueError: when `table` in None and `num_oov_buckets` is not positive.
    825       TypeError: when `hasher_spec` is invalid.
    826     """
    827     # If a name ends with a '/' it is a "name scope", remove all trailing '/'
    828     # characters to use as table name.
    829     if name:
    830       name = name.rstrip("/")
    831     if table:
    832       if key_dtype is None:
    833         key_dtype = table.key_dtype
    834       supported_table_key_dtypes = (dtypes.int64, dtypes.string)
    835       if table.key_dtype not in supported_table_key_dtypes:
    836         raise TypeError("Invalid key dtype, expected one of %s, but got %s." %
    837                         (supported_table_key_dtypes, key_dtype))
    838       if table.key_dtype.is_integer != key_dtype.is_integer:
    839         raise TypeError("Invalid key dtype, expected %s but got %s." %
    840                         ("integer" if key_dtype.is_integer else "non-integer",
    841                          table.key_dtype))
    842       if table.value_dtype != dtypes.int64:
    843         raise TypeError("Invalid value dtype, expected %s but got %s." %
    844                         (dtypes.int64, table.value_dtype))
    845       self._table = table
    846       name = name or self._table.name
    847     else:
    848       if num_oov_buckets <= 0:
    849         raise ValueError("oov_buckets must be > 0 if no table is supplied.")
    850       key_dtype = dtypes.string if key_dtype is None else key_dtype
    851       self._table = None
    852       name = name or "hash_bucket"
    853     if (not key_dtype.is_integer) and (dtypes.string != key_dtype):
    854       raise TypeError(
    855           "Invalid key_dtype, expected integer or string, got %s." % key_dtype)
    856     self._num_oov_buckets = num_oov_buckets
    857 
    858     if not isinstance(hasher_spec, HasherSpec):
    859       raise TypeError(
    860           "hasher_spec must be of type HasherSpec, got %s" % hasher_spec)
    861     self._hasher_spec = hasher_spec
    862     if name:
    863       self._table_name = name.split("/")[-1]
    864     else:
    865       self._table_name = None
    866     super(IdTableWithHashBuckets, self).__init__(key_dtype, dtypes.int64)
    867 
    868   def _create_resource(self):
    869     if self._table is not None:
    870       return self._table._create_resource()  # pylint: disable=protected-access
    871     return None
    872 
    873   def _initialize(self):
    874     if self._table is not None:
    875       return self._table._initialize()  # pylint: disable=protected-access
    876     with ops.name_scope(None, "init"):
    877       return control_flow_ops.no_op()
    878 
    879   @property
    880   def initializer(self):
    881     if self._table is not None:
    882       return self._table._init_op  # pylint: disable=protected-access
    883     with ops.name_scope(None, "init"):
    884       return control_flow_ops.no_op()
    885 
    886   @property
    887   @deprecated("2018-12-15", "Use `initializer` instead.")
    888   def init(self):
    889     return self.initializer
    890 
    891   @property
    892   def resource_handle(self):
    893     if self._table is not None:
    894       return self._table.resource_handle
    895     return None
    896 
    897   @property
    898   def name(self):
    899     return self._table_name
    900 
    901   def size(self, name=None):
    902     """Compute the number of elements in this table."""
    903     with ops.name_scope(name, "%s_Size" % self.name):
    904       if self._table:
    905         tsize = self._table.size()
    906       else:
    907         tsize = ops.convert_to_tensor(0, dtype=dtypes.int64)
    908       return tsize + self._num_oov_buckets
    909 
    910   def _get_string_to_hash_bucket_fn(self, hasher_spec):
    911     """Returns the string_to_hash_bucket op to use based on `hasher_spec`."""
    912     if not isinstance(hasher_spec, HasherSpec):
    913       raise TypeError("hasher_spec must be of type HasherSpec %s" % hasher_spec)
    914     if hasher_spec.hasher == "fasthash":
    915       return string_ops.string_to_hash_bucket_fast
    916     if hasher_spec.hasher == "legacy":
    917       return string_ops.string_to_hash_bucket
    918     if hasher_spec.hasher == "stronghash":
    919       return functools.partial(
    920           string_ops.string_to_hash_bucket_strong, key=hasher_spec.key)
    921     raise ValueError("Unknown hasher %s" % hasher_spec.hasher)
    922 
    923   def lookup(self, keys, name=None):
    924     """Looks up `keys` in the table, outputs the corresponding values.
    925 
    926     It assigns out-of-vocabulary keys to buckets based in their hashes.
    927 
    928     Args:
    929       keys: Keys to look up. May be either a `SparseTensor` or dense `Tensor`.
    930       name: Optional name for the op.
    931 
    932     Returns:
    933       A `SparseTensor` if keys are sparse, otherwise a dense `Tensor`.
    934 
    935     Raises:
    936       TypeError: when `keys` doesn't match the table key data type.
    937     """
    938     if keys.dtype.base_dtype != self._key_dtype:
    939       raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
    940                       (self._key_dtype, keys.dtype))
    941     values = keys
    942     if isinstance(keys, sparse_tensor.SparseTensor):
    943       values = keys.values
    944     if self._table and (self._table.key_dtype.base_dtype == dtypes.int64):
    945       values = math_ops.cast(values, dtypes.int64)
    946 
    947     if self._num_oov_buckets == 0:
    948       ids = self._table.lookup(values, name=name)
    949     else:
    950       # TODO(yleon): Consider moving this functionality to its own kernel.
    951       with ops.name_scope(name, "%s_Lookup" % self.name):
    952         str_to_hash_bucket = self._get_string_to_hash_bucket_fn(
    953             self._hasher_spec)
    954         buckets = str_to_hash_bucket(
    955             _as_string(values),
    956             num_buckets=self._num_oov_buckets,
    957             name="hash_bucket")
    958         if self._table:
    959           ids = self._table.lookup(values)
    960           buckets = math_ops.add(buckets, self._table.size())
    961           is_id_non_default = math_ops.not_equal(ids, self._table.default_value)
    962           ids = array_ops.where(is_id_non_default, ids, buckets)
    963         else:
    964           ids = buckets
    965     if isinstance(keys, sparse_tensor.SparseTensor):
    966       return sparse_tensor.SparseTensor(keys.indices, ids, keys.dense_shape)
    967     return ids
    968 
    969 
    970 @tf_export("lookup.StaticVocabularyTable", v1=[])
    971 class StaticVocabularyTable(LookupInterface):
    972   """String to Id table wrapper that assigns out-of-vocabulary keys to buckets.
    973 
    974   For example, if an instance of `StaticVocabularyTable` is initialized with a
    975   string-to-id initializer that maps:
    976 
    977   * `emerson -> 0`
    978   * `lake -> 1`
    979   * `palmer -> 2`
    980 
    981   The `Vocabulary` object will performs the following mapping:
    982 
    983   * `emerson -> 0`
    984   * `lake -> 1`
    985   * `palmer -> 2`
    986   * `<other term> -> bucket_id`, where bucket_id will be between `3` and
    987   `3 + num_oov_buckets - 1`, calculated by:
    988   `hash(<term>) % num_oov_buckets + vocab_size`
    989 
    990   If input_tensor is `["emerson", "lake", "palmer", "king", "crimson"]`,
    991   the lookup result is `[0, 1, 2, 4, 7]`.
    992 
    993   If `initializer` is None, only out-of-vocabulary buckets are used.
    994 
    995   Example usage:
    996 
    997   ```python
    998   num_oov_buckets = 3
    999   input_tensor = tf.constant(["emerson", "lake", "palmer", "king", "crimnson"])
   1000   table = tf.lookup.StaticVocabularyTable(
   1001       tf.TextFileIdTableInitializer(filename), num_oov_buckets)
   1002   out = table.lookup(input_tensor).
   1003   table.init.run()
   1004   print(out.eval())
   1005   ```
   1006 
   1007   The hash function used for generating out-of-vocabulary buckets ID is
   1008   Fingerprint64.
   1009   """
   1010 
   1011   def __init__(self,
   1012                initializer,
   1013                num_oov_buckets,
   1014                lookup_key_dtype=None,
   1015                name=None):
   1016     """Construct a `StaticVocabularyTable` object.
   1017 
   1018     Args:
   1019       initializer: A TableInitializerBase object that contains the data used to
   1020         initialize the table. If None, then we only use out-of-vocab buckets.
   1021       num_oov_buckets: Number of buckets to use for out-of-vocabulary keys. Must
   1022         be greater than zero.
   1023       lookup_key_dtype: Data type of keys passed to `lookup`. Defaults to
   1024         `initializer.key_dtype` if `initializer` is specified, otherwise
   1025         `tf.string`. Must be string or integer, and must be castable to
   1026         `initializer.key_dtype`.
   1027       name: A name for the operation (optional).
   1028 
   1029     Raises:
   1030       ValueError: when `num_oov_buckets` is not positive.
   1031       TypeError: when lookup_key_dtype or initializer.key_dtype are not
   1032         integer or string. Also when initializer.value_dtype != int64.
   1033     """
   1034     if num_oov_buckets <= 0:
   1035       raise ValueError("oov_buckets must be > 0.")
   1036     # If a name ends with a '/' it is a "name scope", remove all trailing '/'
   1037     # characters to use as table name.
   1038     if name:
   1039       name = name.rstrip("/")
   1040     if initializer:
   1041       if lookup_key_dtype is None:
   1042         lookup_key_dtype = initializer.key_dtype
   1043       supported_table_key_dtypes = (dtypes.int64, dtypes.string)
   1044       if initializer.key_dtype not in supported_table_key_dtypes:
   1045         raise TypeError("Invalid key dtype, expected one of %s, but got %s." %
   1046                         (supported_table_key_dtypes, initializer.key_dtype))
   1047       if initializer.key_dtype.is_integer != lookup_key_dtype.is_integer:
   1048         raise TypeError(
   1049             "Invalid key dtype, expected %s but got %s." %
   1050             ("integer" if lookup_key_dtype.is_integer else "non-integer",
   1051              initializer.key_dtype))
   1052       if initializer.value_dtype != dtypes.int64:
   1053         raise TypeError("Invalid value dtype, expected %s but got %s." %
   1054                         (dtypes.int64, initializer.value_dtype))
   1055       self._table = HashTable(initializer, default_value=-1)
   1056       name = name or self._table.name
   1057     else:
   1058       lookup_key_dtype = dtypes.string
   1059       self._table = None
   1060       name = name or "hash_bucket"
   1061     if (not lookup_key_dtype.is_integer) and (dtypes.string !=
   1062                                               lookup_key_dtype):
   1063       raise TypeError("Invalid key_dtype, expected integer or string, got %s." %
   1064                       lookup_key_dtype)
   1065     self._num_oov_buckets = num_oov_buckets
   1066 
   1067     self._table_name = None
   1068     if name is not None:
   1069       self._table_name = name.split("/")[-1]
   1070     super(StaticVocabularyTable, self).__init__(lookup_key_dtype, dtypes.int64)
   1071 
   1072   def _create_resource(self):
   1073     if self._table is not None:
   1074       return self._table._create_resource()  # pylint: disable=protected-access
   1075     return None
   1076 
   1077   def _initialize(self):
   1078     if self._table is not None:
   1079       return self._table._initialize()  # pylint: disable=protected-access
   1080     with ops.name_scope(None, "init"):
   1081       return control_flow_ops.no_op()
   1082 
   1083   @property
   1084   def resource_handle(self):
   1085     if self._table is not None:
   1086       return self._table.resource_handle
   1087     return None
   1088 
   1089   @property
   1090   def name(self):
   1091     return self._table_name
   1092 
   1093   def size(self, name=None):
   1094     """Compute the number of elements in this table."""
   1095     with ops.name_scope(name, "%s_Size" % self.name):
   1096       if self._table:
   1097         tsize = self._table.size()
   1098       else:
   1099         tsize = ops.convert_to_tensor(0, dtype=dtypes.int64)
   1100       return tsize + self._num_oov_buckets
   1101 
   1102   def lookup(self, keys, name=None):
   1103     """Looks up `keys` in the table, outputs the corresponding values.
   1104 
   1105     It assigns out-of-vocabulary keys to buckets based in their hashes.
   1106 
   1107     Args:
   1108       keys: Keys to look up. May be either a `SparseTensor` or dense `Tensor`.
   1109       name: Optional name for the op.
   1110 
   1111     Returns:
   1112       A `SparseTensor` if keys are sparse, otherwise a dense `Tensor`.
   1113 
   1114     Raises:
   1115       TypeError: when `keys` doesn't match the table key data type.
   1116     """
   1117     if keys.dtype.base_dtype != self._key_dtype:
   1118       raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
   1119                       (self._key_dtype, keys.dtype))
   1120     values = keys
   1121     if isinstance(keys, sparse_tensor.SparseTensor):
   1122       values = keys.values
   1123     if self._table and (self._table.key_dtype.base_dtype == dtypes.int64):
   1124       values = math_ops.cast(values, dtypes.int64)
   1125 
   1126     # TODO(yleon): Consider moving this functionality to its own kernel.
   1127     with ops.name_scope(name, "%s_Lookup" % self.name):
   1128       buckets = string_ops.string_to_hash_bucket_fast(
   1129           _as_string(values),
   1130           num_buckets=self._num_oov_buckets,
   1131           name="hash_bucket")
   1132       if self._table:
   1133         ids = self._table.lookup(values)
   1134         buckets = math_ops.add(buckets, self._table.size())
   1135         is_id_non_default = math_ops.not_equal(ids, self._table.default_value)
   1136         ids = array_ops.where(is_id_non_default, ids, buckets)
   1137       else:
   1138         ids = buckets
   1139     if isinstance(keys, sparse_tensor.SparseTensor):
   1140       return sparse_tensor.SparseTensor(keys.indices, ids, keys.dense_shape)
   1141     return ids
   1142 
   1143 
   1144 @tf_export(v1=["lookup.StaticVocabularyTable"])
   1145 class StaticVocabularyTableV1(StaticVocabularyTable):
   1146 
   1147   @property
   1148   def initializer(self):
   1149     if self._table is not None:
   1150       return self._table._init_op  # pylint: disable=protected-access
   1151     with ops.name_scope(None, "init"):
   1152       return control_flow_ops.no_op()
   1153 
   1154 
   1155 def index_table_from_file(vocabulary_file=None,
   1156                           num_oov_buckets=0,
   1157                           vocab_size=None,
   1158                           default_value=-1,
   1159                           hasher_spec=FastHashSpec,
   1160                           key_dtype=dtypes.string,
   1161                           name=None,
   1162                           key_column_index=TextFileIndex.WHOLE_LINE,
   1163                           value_column_index=TextFileIndex.LINE_NUMBER,
   1164                           delimiter="\t"):
   1165   """Returns a lookup table that converts a string tensor into int64 IDs.
   1166 
   1167   This operation constructs a lookup table to convert tensor of strings into
   1168   int64 IDs. The mapping can be initialized from a vocabulary file specified in
   1169   `vocabulary_file`, where the whole line is the key and the zero-based line
   1170   number is the ID.
   1171 
   1172   Any lookup of an out-of-vocabulary token will return a bucket ID based on its
   1173   hash if `num_oov_buckets` is greater than zero. Otherwise it is assigned the
   1174   `default_value`.
   1175   The bucket ID range is
   1176   `[vocabulary size, vocabulary size + num_oov_buckets - 1]`.
   1177 
   1178   The underlying table must be initialized by calling
   1179   `session.run(tf.tables_initializer)` or `session.run(table.init)` once.
   1180 
   1181   To specify multi-column vocabulary files, use key_column_index and
   1182   value_column_index and delimiter.
   1183 
   1184   - TextFileIndex.LINE_NUMBER means use the line number starting from zero,
   1185     expects data type int64.
   1186   - TextFileIndex.WHOLE_LINE means use the whole line content, expects data
   1187     type string.
   1188   - A value >=0 means use the index (starting at zero) of the split line based
   1189     on `delimiter`.
   1190 
   1191   Sample Usages:
   1192 
   1193   If we have a vocabulary file "test.txt" with the following content:
   1194 
   1195   ```
   1196   emerson
   1197   lake
   1198   palmer
   1199   ```
   1200 
   1201   ```python
   1202   features = tf.constant(["emerson", "lake", "and", "palmer"])
   1203   table = tf.lookup.index_table_from_file(
   1204       vocabulary_file="test.txt", num_oov_buckets=1)
   1205   ids = table.lookup(features)
   1206   ...
   1207   tf.tables_initializer().run()
   1208 
   1209   ids.eval()  ==> [0, 1, 3, 2]  # where 3 is the out-of-vocabulary bucket
   1210   ```
   1211 
   1212   Args:
   1213     vocabulary_file: The vocabulary filename, may be a constant scalar `Tensor`.
   1214     num_oov_buckets: The number of out-of-vocabulary buckets.
   1215     vocab_size: Number of the elements in the vocabulary, if known.
   1216     default_value: The value to use for out-of-vocabulary feature values.
   1217       Defaults to -1.
   1218     hasher_spec: A `HasherSpec` to specify the hash function to use for
   1219       assignation of out-of-vocabulary buckets.
   1220     key_dtype: The `key` data type.
   1221     name: A name for this op (optional).
   1222     key_column_index: The column index from the text file to get the `key`
   1223       values from. The default is to use the whole line content.
   1224     value_column_index: The column index from the text file to get the `value`
   1225       values from. The default is to use the line number, starting from zero.
   1226     delimiter: The delimiter to separate fields in a line.
   1227 
   1228   Returns:
   1229     The lookup table to map a `key_dtype` `Tensor` to index `int64` `Tensor`.
   1230 
   1231   Raises:
   1232     ValueError: If `vocabulary_file` is not set.
   1233     ValueError: If `num_oov_buckets` is negative or `vocab_size` is not greater
   1234       than zero.
   1235   """
   1236   if vocabulary_file is None or (
   1237       isinstance(vocabulary_file, six.string_types) and not vocabulary_file):
   1238     raise ValueError("vocabulary_file must be specified and must not be empty.")
   1239   if num_oov_buckets < 0:
   1240     raise ValueError("num_oov_buckets must be greater or equal than 0, got %d."
   1241                      % num_oov_buckets)
   1242   if vocab_size is not None and vocab_size < 1:
   1243     vocab_file_value = vocabulary_file
   1244     if isinstance(vocabulary_file, ops.Tensor):
   1245       vocab_file_value = tensor_util.constant_value(vocabulary_file) or "?"
   1246     raise ValueError("vocab_size must be greater than 0, got %d. "
   1247                      "vocabulary_file: %s" % (vocab_size, vocab_file_value))
   1248   if (not key_dtype.is_integer) and (dtypes.string != key_dtype.base_dtype):
   1249     raise TypeError("Only integer and string keys are supported.")
   1250 
   1251   with ops.name_scope(name, "string_to_index"):
   1252     table = None
   1253     with ops.name_scope(None, "hash_table"):
   1254       init = TextFileIdTableInitializer(
   1255           vocabulary_file,
   1256           vocab_size=vocab_size,
   1257           key_dtype=dtypes.int64 if key_dtype.is_integer else key_dtype,
   1258           name="table_init",
   1259           key_column_index=key_column_index,
   1260           value_column_index=value_column_index,
   1261           delimiter=delimiter)
   1262 
   1263       table = StaticHashTableV1(init, default_value)
   1264     if num_oov_buckets:
   1265       table = IdTableWithHashBuckets(
   1266           table,
   1267           num_oov_buckets=num_oov_buckets,
   1268           hasher_spec=hasher_spec,
   1269           key_dtype=key_dtype)
   1270 
   1271     return table
   1272 
   1273 
   1274 def index_table_from_tensor(vocabulary_list,
   1275                             num_oov_buckets=0,
   1276                             default_value=-1,
   1277                             hasher_spec=FastHashSpec,
   1278                             dtype=dtypes.string,
   1279                             name=None):
   1280   """Returns a lookup table that converts a string tensor into int64 IDs.
   1281 
   1282   This operation constructs a lookup table to convert tensor of strings into
   1283   int64 IDs. The mapping can be initialized from a string `vocabulary_list` 1-D
   1284   tensor where each element is a key and corresponding index within the tensor
   1285   is the value.
   1286 
   1287   Any lookup of an out-of-vocabulary token will return a bucket ID based on its
   1288   hash if `num_oov_buckets` is greater than zero. Otherwise it is assigned the
   1289   `default_value`. The bucket ID range is
   1290   `[vocabulary list size, vocabulary list size + num_oov_buckets - 1]`.
   1291 
   1292   The underlying table must be initialized by calling
   1293   `session.run(tf.tables_initializer)` or `session.run(table.init)` once.
   1294 
   1295   Elements in `vocabulary_list` cannot have duplicates, otherwise when executing
   1296   the table initializer op, it will throw a `FailedPreconditionError`.
   1297 
   1298   Sample Usages:
   1299 
   1300   ```python
   1301   vocabulary_list = tf.constant(["emerson", "lake", "palmer"])
   1302   table = tf.lookup.index_table_from_tensor(
   1303       vocabulary_list=vocabulary_list, num_oov_buckets=1, default_value=-1)
   1304   features = tf.constant(["emerson", "lake", "and", "palmer"])
   1305   ids = table.lookup(features)
   1306   ...
   1307   tf.tables_initializer().run()
   1308 
   1309   ids.eval()  ==> [0, 1, 4, 2]
   1310   ```
   1311 
   1312   Args:
   1313     vocabulary_list: A 1-D `Tensor` that specifies the mapping of keys to
   1314       indices. The type of this object must be castable to `dtype`.
   1315     num_oov_buckets: The number of out-of-vocabulary buckets.
   1316     default_value: The value to use for out-of-vocabulary feature values.
   1317       Defaults to -1.
   1318     hasher_spec: A `HasherSpec` to specify the hash function to use for
   1319       assignment of out-of-vocabulary buckets.
   1320     dtype: The type of values passed to `lookup`. Only string and integers are
   1321       supported.
   1322     name: A name for this op (optional).
   1323 
   1324   Returns:
   1325     The lookup table to map an input `Tensor` to index `int64` `Tensor`.
   1326 
   1327   Raises:
   1328     ValueError: If `vocabulary_list` is invalid.
   1329     ValueError: If `num_oov_buckets` is negative.
   1330   """
   1331   if vocabulary_list is None:
   1332     raise ValueError("vocabulary_list must be specified.")
   1333 
   1334   if num_oov_buckets < 0:
   1335     raise ValueError("num_oov_buckets must be greater or equal than 0, got %d."
   1336                      % num_oov_buckets)
   1337 
   1338   if (not dtype.is_integer) and (dtypes.string != dtype.base_dtype):
   1339     raise TypeError("Only integer and string keys are supported.")
   1340 
   1341   with ops.name_scope(name, "string_to_index"):
   1342     keys = ops.convert_to_tensor(vocabulary_list)
   1343     if keys.dtype.is_integer != dtype.is_integer:
   1344       raise ValueError("Expected %s, got %s." %
   1345                        ("integer"
   1346                         if dtype.is_integer else "non-integer", keys.dtype))
   1347     if (not dtype.is_integer) and (keys.dtype.base_dtype != dtype):
   1348       raise ValueError("Expected %s, got %s." % (dtype, keys.dtype))
   1349     num_elements = array_ops.size(keys)
   1350     values = math_ops.cast(math_ops.range(num_elements), dtypes.int64)
   1351 
   1352     with ops.name_scope(None, "hash_table"):
   1353       table_keys = math_ops.cast(
   1354           keys, dtypes.int64) if keys.dtype.is_integer else keys
   1355       init = KeyValueTensorInitializer(
   1356           table_keys,
   1357           values,
   1358           table_keys.dtype.base_dtype,
   1359           dtypes.int64,
   1360           name="table_init")
   1361       table = StaticHashTableV1(init, default_value)
   1362     if num_oov_buckets:
   1363       table = IdTableWithHashBuckets(
   1364           table,
   1365           num_oov_buckets=num_oov_buckets,
   1366           hasher_spec=hasher_spec,
   1367           key_dtype=dtype)
   1368     return table
   1369 
   1370 
   1371 def index_to_string_table_from_file(vocabulary_file,
   1372                                     vocab_size=None,
   1373                                     default_value="UNK",
   1374                                     name=None,
   1375                                     key_column_index=TextFileIndex.LINE_NUMBER,
   1376                                     value_column_index=TextFileIndex.WHOLE_LINE,
   1377                                     delimiter="\t"):
   1378   """Returns a lookup table that maps a `Tensor` of indices into strings.
   1379 
   1380   This operation constructs a lookup table to map int64 indices into string
   1381   values. The table is initialized from a vocabulary file specified in
   1382   `vocabulary_file`, where the whole line is the value and the
   1383   zero-based line number is the index.
   1384 
   1385   Any input which does not have a corresponding index in the vocabulary file
   1386   (an out-of-vocabulary entry) is assigned the `default_value`
   1387 
   1388   The underlying table must be initialized by calling
   1389   `session.run(tf.tables_initializer)` or `session.run(table.init)` once.
   1390 
   1391   To specify multi-column vocabulary files, use key_column_index and
   1392   value_column_index and delimiter.
   1393 
   1394   - TextFileIndex.LINE_NUMBER means use the line number starting from zero,
   1395     expects data type int64.
   1396   - TextFileIndex.WHOLE_LINE means use the whole line content, expects data
   1397     type string.
   1398   - A value >=0 means use the index (starting at zero) of the split line based
   1399     on `delimiter`.
   1400 
   1401   Sample Usages:
   1402 
   1403   If we have a vocabulary file "test.txt" with the following content:
   1404 
   1405   ```
   1406   emerson
   1407   lake
   1408   palmer
   1409   ```
   1410 
   1411   ```python
   1412   indices = tf.constant([1, 5], tf.int64)
   1413   table = tf.lookup.index_to_string_table_from_file(
   1414       vocabulary_file="test.txt", default_value="UNKNOWN")
   1415   values = table.lookup(indices)
   1416   ...
   1417   tf.tables_initializer().run()
   1418 
   1419   values.eval() ==> ["lake", "UNKNOWN"]
   1420   ```
   1421 
   1422   Args:
   1423     vocabulary_file: The vocabulary filename, may be a constant scalar `Tensor`.
   1424     vocab_size: Number of the elements in the vocabulary, if known.
   1425     default_value: The value to use for out-of-vocabulary indices.
   1426     name: A name for this op (optional).
   1427     key_column_index: The column index from the text file to get the `key`
   1428       values from. The default is to use the line number, starting from zero.
   1429     value_column_index: The column index from the text file to get the `value`
   1430       values from. The default is to use the whole line content.
   1431     delimiter: The delimiter to separate fields in a line.
   1432 
   1433   Returns:
    The lookup table mapping `int64` index `Tensors` to the associated string
    values.
   1436 
   1437   Raises:
   1438     ValueError: when `vocabulary_file` is empty.
   1439     ValueError: when `vocab_size` is invalid.
   1440   """
   1441   if vocabulary_file is None or (
   1442       isinstance(vocabulary_file, six.string_types) and not vocabulary_file):
   1443     raise ValueError("vocabulary_file must be specified and must not be empty.")
   1444 
   1445   if vocab_size is not None and vocab_size < 1:
   1446     raise ValueError("vocab_size must be greater than 0, got %d." % vocab_size)
   1447 
   1448   with ops.name_scope(name, "index_to_string"):
   1449     init = TextFileStringTableInitializer(
   1450         vocabulary_file,
   1451         vocab_size=vocab_size,
   1452         name="table_init",
   1453         key_column_index=key_column_index,
   1454         value_column_index=value_column_index,
   1455         delimiter=delimiter)
   1456 
    # TODO(yleon): Use a more efficient structure.
   1458     return StaticHashTableV1(init, default_value)
   1459 
   1460 
   1461 def index_to_string_table_from_tensor(vocabulary_list,
   1462                                       default_value="UNK",
   1463                                       name=None):
   1464   """Returns a lookup table that maps a `Tensor` of indices into strings.
   1465 
   1466   This operation constructs a lookup table to map int64 indices into string
   1467   values. The mapping is initialized from a string `vocabulary_list` 1-D
   1468   `Tensor` where each element is a value and the corresponding index within the
   1469   tensor is the key.
   1470 
   1471   Any input which does not have a corresponding index in 'vocabulary_list'
   1472   (an out-of-vocabulary entry) is assigned the `default_value`
   1473 
   1474   The underlying table must be initialized by calling
   1475   `session.run(tf.tables_initializer)` or `session.run(table.init)` once.
   1476 
   1477   Elements in `vocabulary_list` cannot have duplicates, otherwise when executing
   1478   the table initializer op, it will throw a `FailedPreconditionError`.
   1479 
   1480   Sample Usages:
   1481 
   1482   ```python
   1483   vocabulary_list = tf.constant(["emerson", "lake", "palmer"])
   1484   indices = tf.constant([1, 5], tf.int64)
   1485   table = tf.lookup.index_to_string_table_from_tensor(
   1486       vocabulary_list, default_value="UNKNOWN")
   1487   values = table.lookup(indices)
   1488   ...
   1489   tf.tables_initializer().run()
   1490 
   1491   values.eval() ==> ["lake", "UNKNOWN"]
   1492   ```
   1493 
   1494   Args:
   1495     vocabulary_list: A 1-D string `Tensor` that specifies the strings to map
   1496       from indices.
   1497     default_value: The value to use for out-of-vocabulary indices.
   1498     name: A name for this op (optional).
   1499 
   1500   Returns:
   1501     The lookup table to map a string values associated to a given index `int64`
   1502     `Tensors`.
   1503 
   1504   Raises:
   1505     ValueError: when `vocabulary_list` is not set.
   1506   """
   1507 
   1508   if vocabulary_list is None:
   1509     raise ValueError("vocabulary_list must be specified.")
   1510 
   1511   with ops.name_scope(name, "index_to_string"):
   1512     vocabulary_list = ops.convert_to_tensor(vocabulary_list, dtypes.string)
   1513     num_elements = array_ops.size(vocabulary_list)
   1514     keys = math_ops.cast(math_ops.range(num_elements), dtypes.int64)
   1515 
   1516     init = KeyValueTensorInitializer(
   1517         keys, vocabulary_list, dtypes.int64, dtypes.string, name="table_init")
   1518     # TODO(yleon): Use a more effienct structure.
   1519     return StaticHashTableV1(init, default_value)
   1520 
   1521 
class MutableHashTable(LookupInterface):
  """A generic mutable hash table implementation.

  Data can be inserted by calling the insert method and removed by calling the
  remove method. It does not support initialization via the init method.

  Example usage:

  ```python
  table = tf.lookup.MutableHashTable(key_dtype=tf.string, value_dtype=tf.int64,
                                     default_value=-1)
  sess.run(table.insert(keys, values))
  out = table.lookup(query_keys)
  print(out.eval())
  ```
  """

  def __init__(self,
               key_dtype,
               value_dtype,
               default_value,
               name="MutableHashTable",
               checkpoint=True):
    """Creates an empty `MutableHashTable` object.

    Creates a table, the type of its keys and values are specified by key_dtype
    and value_dtype, respectively.

    Args:
      key_dtype: the type of the key tensors.
      value_dtype: the type of the value tensors.
      default_value: The value to use if a key is missing in the table.
      name: A name for the operation (optional).
      checkpoint: if True, the contents of the table are saved to and restored
        from checkpoints. If `shared_name` is empty for a checkpointed table, it
        is shared using the table node name.

    Returns:
      A `MutableHashTable` object.

    Raises:
      ValueError: If checkpoint is True and no name was specified.
    """
    # The default value's static shape decides which backing kernel is used:
    # scalar -> mutable_hash_table_v2, non-scalar ->
    # mutable_hash_table_of_tensors_v2 (see `_create_resource`).
    self._default_value = ops.convert_to_tensor(
        default_value, dtype=value_dtype)
    self._value_shape = self._default_value.get_shape()
    self._checkpoint = checkpoint
    self._key_dtype = key_dtype
    self._value_dtype = value_dtype
    self._name = name

    self._shared_name = None
    if context.executing_eagerly():
      # TODO(allenl): This will leak memory due to kernel caching by the
      # shared_name attribute value (but is better than the alternative of
      # sharing everything by default when executing eagerly; hopefully creating
      # tables in a loop is uncommon).
      # TODO(rohanj): Use context.shared_name() instead.
      self._shared_name = "table_%d" % (ops.uid(),)
    super(MutableHashTable, self).__init__(key_dtype, value_dtype)

    self._resource_handle = self._create_resource()
    if checkpoint:
      # Graph mode: register a SaveableObject so graph-based Savers include
      # the table contents in checkpoints. Eager mode relies on object-based
      # checkpointing via `_gather_saveables_for_checkpoint` instead.
      saveable = MutableHashTable._Saveable(self, name)
      if not context.executing_eagerly():
        ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)

  def _create_resource(self):
    """Creates the backing table op and returns its resource handle."""
    # The table must be shared if checkpointing is requested for multi-worker
    # training to work correctly. Use the node name if no shared_name has been
    # explicitly specified.
    use_node_name_sharing = self._checkpoint and self._shared_name is None
    if self._default_value.get_shape().ndims == 0:
      # Scalar values: each key maps to a single scalar.
      table_ref = gen_lookup_ops.mutable_hash_table_v2(
          shared_name=self._shared_name,
          use_node_name_sharing=use_node_name_sharing,
          key_dtype=self._key_dtype,
          value_dtype=self._value_dtype,
          name=self._name)
    else:
      # Non-scalar values: each key maps to a tensor of the default value's
      # shape.
      table_ref = gen_lookup_ops.mutable_hash_table_of_tensors_v2(
          shared_name=self._shared_name,
          use_node_name_sharing=use_node_name_sharing,
          key_dtype=self._key_dtype,
          value_dtype=self._value_dtype,
          value_shape=self._default_value.get_shape(),
          name=self._name)

    if context.executing_eagerly():
      # Op names are not meaningful when executing eagerly.
      self._table_name = None
    else:
      self._table_name = table_ref.op.name.split("/")[-1]
    return table_ref

  @property
  def name(self):
    # Short node name of the table op (None in eager mode).
    return self._table_name

  def size(self, name=None):
    """Compute the number of elements in this table.

    Args:
      name: A name for the operation (optional).

    Returns:
      A scalar tensor containing the number of elements in this table.
    """
    with ops.name_scope(name, "%s_Size" % self.name, [self.resource_handle]):
      with ops.colocate_with(self.resource_handle):
        return gen_lookup_ops.lookup_table_size_v2(self.resource_handle)

  def remove(self, keys, name=None):
    """Removes `keys` and its associated values from the table.

    If a key is not present in the table, it is silently ignored.

    Args:
      keys: Keys to remove. Can be a tensor of any shape. Must match the table's
        key type.
      name: A name for the operation (optional).

    Returns:
      The created Operation.

    Raises:
      TypeError: when `keys` do not match the table data types.
    """
    if keys.dtype != self._key_dtype:
      raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
                      (self._key_dtype, keys.dtype))

    # NOTE(review): unlike lookup/insert/size/export, this op is not wrapped
    # in `colocate_with(self.resource_handle)` — confirm whether that is
    # intentional.
    with ops.name_scope(name, "%s_lookup_table_remove" % self.name,
                        (self.resource_handle, keys, self._default_value)):
      op = gen_lookup_ops.lookup_table_remove_v2(self.resource_handle, keys)

    return op

  def lookup(self, keys, name=None):
    """Looks up `keys` in a table, outputs the corresponding values.

    The `default_value` is used for keys not present in the table.

    Args:
      keys: Keys to look up. Can be a tensor of any shape. Must match the
        table's key_dtype.
      name: A name for the operation (optional).

    Returns:
      A tensor containing the values in the same shape as `keys` using the
        table's value type.

    Raises:
      TypeError: when `keys` do not match the table data types.
    """
    with ops.name_scope(name, "%s_lookup_table_find" % self.name,
                        (self.resource_handle, keys, self._default_value)):
      keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys")
      # Run the find op on the device holding the table resource.
      with ops.colocate_with(self.resource_handle):
        values = gen_lookup_ops.lookup_table_find_v2(self.resource_handle, keys,
                                                     self._default_value)
    return values

  def insert(self, keys, values, name=None):
    """Associates `keys` with `values`.

    Args:
      keys: Keys to insert. Can be a tensor of any shape. Must match the table's
        key type.
      values: Values to be associated with keys. Must be a tensor of the same
        shape as `keys` and match the table's value type.
      name: A name for the operation (optional).

    Returns:
      The created Operation.

    Raises:
      TypeError: when `keys` or `values` doesn't match the table data
        types.
    """
    with ops.name_scope(name, "%s_lookup_table_insert" % self.name,
                        [self.resource_handle, keys, values]):
      keys = ops.convert_to_tensor(keys, self._key_dtype, name="keys")
      values = ops.convert_to_tensor(values, self._value_dtype, name="values")
      # Run the insert op on the device holding the table resource.
      with ops.colocate_with(self.resource_handle):
        # pylint: disable=protected-access
        op = gen_lookup_ops.lookup_table_insert_v2(self.resource_handle, keys,
                                                   values)
    return op

  def export(self, name=None):
    """Returns tensors of all keys and values in the table.

    Args:
      name: A name for the operation (optional).

    Returns:
      A pair of tensors with the first tensor containing all keys and the
        second tensors containing all values in the table.
    """
    with ops.name_scope(name, "%s_lookup_table_export_values" % self.name,
                        [self.resource_handle]):
      with ops.colocate_with(self.resource_handle):
        exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2(
            self.resource_handle, self._key_dtype, self._value_dtype)
    return exported_keys, exported_values

  def _gather_saveables_for_checkpoint(self):
    """For object-based checkpointing."""
    return {"table": functools.partial(MutableHashTable._Saveable, table=self)}

  class _Saveable(BaseSaverBuilder.SaveableObject):
    """SaveableObject implementation for MutableHashTable."""

    def __init__(self, table, name):
      # The table is checkpointed as two tensors: all keys and all values.
      tensors = table.export()
      specs = [
          BaseSaverBuilder.SaveSpec(tensors[0], "", name + "-keys"),
          BaseSaverBuilder.SaveSpec(tensors[1], "", name + "-values")
      ]
      # pylint: disable=protected-access
      super(MutableHashTable._Saveable, self).__init__(table, specs, name)

    def restore(self, restored_tensors, restored_shapes, name=None):
      """Re-imports the saved keys/values tensors into the live table."""
      del restored_shapes  # unused
      # pylint: disable=protected-access
      # `self.op` is the MutableHashTable instance passed to `__init__` above.
      with ops.name_scope(name, "%s_table_restore" % self.name):
        with ops.colocate_with(self.op.resource_handle):
          return gen_lookup_ops.lookup_table_import_v2(
              self.op.resource_handle, restored_tensors[0], restored_tensors[1])
   1751 
   1752 
@tf_export("lookup.experimental.DenseHashTable")
class DenseHashTable(LookupInterface):
  """A generic mutable hash table implementation using tensors as backing store.

  Data can be inserted by calling the insert method and removed by calling the
  remove method. It does not support initialization via the init method.

  It uses "open addressing" with quadratic reprobing to resolve collisions.
  Compared to `MutableHashTable` the insert, remove and lookup operations in a
  `DenseHashTable` are typically faster, but memory usage can be higher.
  However, `DenseHashTable` does not require additional memory for
  temporary tensors created during checkpointing and restore operations.

  Example usage:

  ```python
  table = tf.lookup.DenseHashTable(key_dtype=tf.int64,
                                   value_dtype=tf.int64,
                                   default_value=-1,
                                   empty_key=0,
                                   deleted_key=-1)

  sess.run(table.insert(keys, values))
  out = table.lookup(query_keys)
  print(out.eval())
  ```
  """

  # TODO(andreasst): consider extracting common code with MutableHashTable into
  # a common superclass.
  def __init__(self,
               key_dtype,
               value_dtype,
               default_value,
               empty_key,
               deleted_key,
               initial_num_buckets=None,
               name="MutableDenseHashTable",
               checkpoint=True):
    """Creates an empty `DenseHashTable` object.

    Creates a table, the type of its keys and values are specified by key_dtype
    and value_dtype, respectively.

    Args:
      key_dtype: the type of the key tensors.
      value_dtype: the type of the value tensors.
      default_value: The value to use if a key is missing in the table.
      empty_key: the key to use to represent empty buckets internally. Must not
        be used in insert, remove or lookup operations.
      deleted_key: the key to use to represent deleted buckets internally. Must
        not be used in insert, remove or lookup operations and be different from
        the empty_key.
      initial_num_buckets: the initial number of buckets.
      name: A name for the operation (optional).
      checkpoint: if True, the contents of the table are saved to and restored
        from checkpoints. If `shared_name` is empty for a checkpointed table, it
        is shared using the table node name.

    Returns:
      A `DenseHashTable` object.

    Raises:
      ValueError: If checkpoint is True and no name was specified.
    """
    self._default_value = ops.convert_to_tensor(
        default_value, dtype=value_dtype, name="default_value")
    self._key_dtype = key_dtype
    self._value_dtype = value_dtype
    self._initial_num_buckets = initial_num_buckets
    # The default value's shape doubles as the table's value shape.
    self._value_shape = self._default_value.get_shape()
    self._checkpoint = checkpoint
    self._name = name

    # Sentinel keys for the open-addressing backing store; callers must not
    # use these keys in insert/remove/lookup.
    self._empty_key = ops.convert_to_tensor(
        empty_key, dtype=key_dtype, name="empty_key")
    self._deleted_key = ops.convert_to_tensor(
        deleted_key, dtype=key_dtype, name="deleted_key")
    self._shared_name = None
    if context.executing_eagerly():
      # TODO(allenl): This will leak memory due to kernel caching by the
      # shared_name attribute value (but is better than the alternative of
      # sharing everything by default when executing eagerly; hopefully creating
      # tables in a loop is uncommon).
      # TODO(rohanj): Use context.shared_name() instead.
      self._shared_name = "table_%d" % (ops.uid(),)
    super(DenseHashTable, self).__init__(key_dtype, value_dtype)

    self._resource_handle = self._create_resource()
    if checkpoint:
      # Graph mode: register a SaveableObject so graph-based Savers include
      # the table contents in checkpoints. Eager mode relies on object-based
      # checkpointing via `_gather_saveables_for_checkpoint` instead.
      saveable = DenseHashTable._Saveable(self, name)
      if not context.executing_eagerly():
        ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)

  def _create_resource(self):
    """Creates the backing dense table op and returns its resource handle."""
    # The table must be shared if checkpointing is requested for multi-worker
    # training to work correctly. Use the node name if no shared_name has been
    # explicitly specified.
    use_node_name_sharing = self._checkpoint and self._shared_name is None
    table_ref = gen_lookup_ops.mutable_dense_hash_table_v2(
        empty_key=self._empty_key,
        deleted_key=self._deleted_key,
        shared_name=self._shared_name,
        use_node_name_sharing=use_node_name_sharing,
        value_dtype=self._value_dtype,
        value_shape=self._value_shape,
        initial_num_buckets=self._initial_num_buckets,
        name=self._name)
    if context.executing_eagerly():
      # Op names are not meaningful when executing eagerly.
      self._table_name = None
    else:
      self._table_name = table_ref.op.name.split("/")[-1]
    return table_ref

  @property
  def name(self):
    # Short node name of the table op (None in eager mode).
    return self._table_name

  def size(self, name=None):
    """Compute the number of elements in this table.

    Args:
      name: A name for the operation (optional).

    Returns:
      A scalar tensor containing the number of elements in this table.
    """
    with ops.name_scope(name, "%s_Size" % self.name, [self.resource_handle]):
      with ops.colocate_with(self.resource_handle):
        return gen_lookup_ops.lookup_table_size_v2(self.resource_handle)

  def lookup(self, keys, name=None):
    """Looks up `keys` in a table, outputs the corresponding values.

    The `default_value` is used for keys not present in the table.

    Args:
      keys: Keys to look up. Can be a tensor of any shape. Must match the
        table's key_dtype.
      name: A name for the operation (optional).

    Returns:
      A tensor containing the values in the same shape as `keys` using the
        table's value type.

    Raises:
      TypeError: when `keys` do not match the table data types.
    """
    with ops.name_scope(name, "%s_lookup_table_find" % self.name,
                        [self.resource_handle, keys]):
      keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys")
      # Run the find op on the device holding the table resource.
      with ops.colocate_with(self.resource_handle):
        values = gen_lookup_ops.lookup_table_find_v2(self.resource_handle, keys,
                                                     self._default_value)

    return values

  def insert_or_assign(self, keys, values, name=None):
    """Associates `keys` with `values`.

    Args:
      keys: Keys to insert. Can be a tensor of any shape. Must match the table's
        key type.
      values: Values to be associated with keys. Must be a tensor of the same
        shape as `keys` and match the table's value type.
      name: A name for the operation (optional).

    Returns:
      The created Operation.

    Raises:
      TypeError: when `keys` or `values` doesn't match the table data
        types.
    """
    with ops.name_scope(name, "%s_lookup_table_insert" % self.name,
                        [self.resource_handle, keys, values]):
      keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys")
      values = ops.convert_to_tensor(
          values, dtype=self._value_dtype, name="values")
      # Run the insert op on the device holding the table resource.
      with ops.colocate_with(self.resource_handle):
        op = gen_lookup_ops.lookup_table_insert_v2(self.resource_handle, keys,
                                                   values)
      return op

  def insert(self, keys, values, name=None):
    """Associates `keys` with `values`.

    Args:
      keys: Keys to insert. Can be a tensor of any shape. Must match the table's
        key type.
      values: Values to be associated with keys. Must be a tensor of the same
        shape as `keys` and match the table's value type.
      name: A name for the operation (optional).

    Returns:
      The created Operation.

    Raises:
      TypeError: when `keys` or `values` doesn't match the table data
        types.
    """
    # Alias for `insert_or_assign`.
    return self.insert_or_assign(keys, values, name)

  def erase(self, keys, name=None):
    """Removes `keys` and its associated values from the table.

    If a key is not present in the table, it is silently ignored.

    Args:
      keys: Keys to remove. Can be a tensor of any shape. Must match the table's
        key type.
      name: A name for the operation (optional).

    Returns:
      The created Operation.

    Raises:
      TypeError: when `keys` do not match the table data types.
    """
    if keys.dtype != self._key_dtype:
      raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
                      (self._key_dtype, keys.dtype))

    # NOTE(review): unlike lookup/insert_or_assign/size/export, this op is not
    # wrapped in `colocate_with(self.resource_handle)` — confirm whether that
    # is intentional.
    with ops.name_scope(name, "%s_lookup_table_remove" % self.name,
                        (self.resource_handle, keys, self._default_value)):
      # pylint: disable=protected-access
      op = gen_lookup_ops.lookup_table_remove_v2(self.resource_handle, keys)

    return op

  def remove(self, keys, name=None):
    """Removes `keys` and its associated values from the table.

    If a key is not present in the table, it is silently ignored.

    Args:
      keys: Keys to remove. Can be a tensor of any shape. Must match the table's
        key type.
      name: A name for the operation (optional).

    Returns:
      The created Operation.

    Raises:
      TypeError: when `keys` do not match the table data types.
    """
    # Alias for `erase`.
    return self.erase(keys, name)

  def export(self, name=None):
    """Returns tensors of all keys and values in the table.

    Args:
      name: A name for the operation (optional).

    Returns:
      A pair of tensors with the first tensor containing all keys and the
        second tensors containing all values in the table.
    """
    with ops.name_scope(name, "%s_lookup_table_export_values" % self.name,
                        [self.resource_handle]):
      with ops.colocate_with(self.resource_handle):
        exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2(
            self.resource_handle, self._key_dtype, self._value_dtype)

    return exported_keys, exported_values

  def _gather_saveables_for_checkpoint(self):
    """For object-based checkpointing."""
    return {"table": functools.partial(DenseHashTable._Saveable, table=self)}

  class _Saveable(BaseSaverBuilder.SaveableObject):
    """SaveableObject implementation for DenseHashTable."""

    def __init__(self, table, name):
      # The table is checkpointed as two tensors: all keys and all values.
      tensors = table.export()
      specs = [
          BaseSaverBuilder.SaveSpec(tensors[0], "", name + "-keys"),
          BaseSaverBuilder.SaveSpec(tensors[1], "", name + "-values")
      ]
      # pylint: disable=protected-access
      super(DenseHashTable._Saveable, self).__init__(table, specs, name)

    def restore(self, restored_tensors, restored_shapes, name=None):
      """Re-imports the saved keys/values tensors into the live table."""
      del restored_shapes  # unused
      # pylint: disable=protected-access
      # `self.op` is the DenseHashTable instance passed to `__init__` above.
      with ops.name_scope(name, "%s_table_restore" % self.name):
        with ops.colocate_with(self.op.resource_handle):
          return gen_lookup_ops.lookup_table_import_v2(
              self.op.resource_handle, restored_tensors[0], restored_tensors[1])
   2042 
   2043 
   2044 ops.NotDifferentiable("LookupTableFind")
   2045 ops.NotDifferentiable("LookupTableFindV2")
   2046 ops.NotDifferentiable("LookupTableInsert")
   2047 ops.NotDifferentiable("LookupTableInsertV2")
   2048 ops.NotDifferentiable("LookupTableSize")
   2049 ops.NotDifferentiable("LookupTableSizeV2")
   2050 ops.NotDifferentiable("HashTable")
   2051 ops.NotDifferentiable("HashTableV2")
   2052 ops.NotDifferentiable("InitializeTable")
   2053 ops.NotDifferentiable("InitializeTableV2")
   2054 ops.NotDifferentiable("InitializeTableFromTextFile")
   2055 ops.NotDifferentiable("InitializeTableFromTextFileV2")
   2056 ops.NotDifferentiable("MutableDenseHashTable")
   2057 ops.NotDifferentiable("MutableDenseHashTableV2")
   2058 ops.NotDifferentiable("MutableHashTable")
   2059 ops.NotDifferentiable("MutableHashTableV2")
   2060 ops.NotDifferentiable("MutableHashTableOfTensors")
   2061 ops.NotDifferentiable("MutableHashTableOfTensorsV2")
   2062