Home | History | Annotate | Download | only in lookup
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for tf.contrib.lookup.lookup."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import os
     21 import tempfile
     22 import numpy as np
     23 import six
     24 
     25 from tensorflow.contrib import lookup
     26 from tensorflow.python.client import session
     27 from tensorflow.python.framework import constant_op
     28 from tensorflow.python.framework import dtypes
     29 from tensorflow.python.framework import errors_impl
     30 from tensorflow.python.framework import ops
     31 from tensorflow.python.framework import sparse_tensor
     32 from tensorflow.python.framework import test_util
     33 from tensorflow.python.ops import array_ops
     34 from tensorflow.python.ops import lookup_ops
     35 from tensorflow.python.ops import variables
     36 from tensorflow.python.platform import test
     37 from tensorflow.python.training import saver
     38 from tensorflow.python.training import server_lib
     39 
     40 
     41 class HashTableOpTest(test.TestCase):
     42 
     43   def testHashTable(self):
     44     with self.test_session():
     45       default_val = -1
     46       keys = constant_op.constant(["brain", "salad", "surgery"])
     47       values = constant_op.constant([0, 1, 2], dtypes.int64)
     48       table = lookup.HashTable(
     49           lookup.KeyValueTensorInitializer(keys, values), default_val)
     50       table.init.run()
     51 
     52       self.assertAllEqual(3, table.size().eval())
     53 
     54       input_string = constant_op.constant(["brain", "salad", "tank"])
     55       output = table.lookup(input_string)
     56       self.assertAllEqual([3], output.get_shape())
     57 
     58       result = output.eval()
     59       self.assertAllEqual([0, 1, -1], result)
     60 
     61   def testHashTableFindHighRank(self):
     62     with self.test_session():
     63       default_val = -1
     64       keys = constant_op.constant(["brain", "salad", "surgery"])
     65       values = constant_op.constant([0, 1, 2], dtypes.int64)
     66       table = lookup.HashTable(
     67           lookup.KeyValueTensorInitializer(keys, values), default_val)
     68       table.init.run()
     69 
     70       self.assertAllEqual(3, table.size().eval())
     71 
     72       input_string = constant_op.constant(
     73           [["brain", "salad"], ["tank", "tarkus"]])
     74       output = table.lookup(input_string)
     75 
     76       result = output.eval()
     77       self.assertAllEqual([[0, 1], [-1, -1]], result)
     78 
     79   def testHashTableInitWithPythonArrays(self):
     80     with self.test_session():
     81       default_val = -1
     82       keys = ["brain", "salad", "surgery"]
     83       values = [0, 1, 2]
     84       table = lookup.HashTable(
     85           lookup.KeyValueTensorInitializer(
     86               keys, values, value_dtype=dtypes.int64),
     87           default_val)
     88       table.init.run()
     89 
     90       self.assertAllEqual(3, table.size().eval())
     91 
     92       input_string = constant_op.constant(["brain", "salad", "tank"])
     93       output = table.lookup(input_string)
     94 
     95       result = output.eval()
     96       self.assertAllEqual([0, 1, -1], result)
     97 
     98   def testHashTableInitWithNumPyArrays(self):
     99     with self.test_session():
    100       default_val = -1
    101       keys = np.array(["brain", "salad", "surgery"], dtype=np.str)
    102       values = np.array([0, 1, 2], dtype=np.int64)
    103       table = lookup.HashTable(
    104           lookup.KeyValueTensorInitializer(keys, values), default_val)
    105       table.init.run()
    106 
    107       self.assertAllEqual(3, table.size().eval())
    108 
    109       input_string = constant_op.constant(["brain", "salad", "tank"])
    110       output = table.lookup(input_string)
    111 
    112       result = output.eval()
    113       self.assertAllEqual([0, 1, -1], result)
    114 
    115   def testMultipleHashTables(self):
    116     with self.test_session() as sess:
    117       default_val = -1
    118       keys = constant_op.constant(["brain", "salad", "surgery"])
    119       values = constant_op.constant([0, 1, 2], dtypes.int64)
    120 
    121       table1 = lookup.HashTable(
    122           lookup.KeyValueTensorInitializer(keys, values), default_val)
    123       table2 = lookup.HashTable(
    124           lookup.KeyValueTensorInitializer(keys, values), default_val)
    125       table3 = lookup.HashTable(
    126           lookup.KeyValueTensorInitializer(keys, values), default_val)
    127 
    128       lookup_ops.tables_initializer().run()
    129       self.assertAllEqual(3, table1.size().eval())
    130       self.assertAllEqual(3, table2.size().eval())
    131       self.assertAllEqual(3, table3.size().eval())
    132 
    133       input_string = constant_op.constant(["brain", "salad", "tank"])
    134       output1 = table1.lookup(input_string)
    135       output2 = table2.lookup(input_string)
    136       output3 = table3.lookup(input_string)
    137 
    138       out1, out2, out3 = sess.run([output1, output2, output3])
    139       self.assertAllEqual([0, 1, -1], out1)
    140       self.assertAllEqual([0, 1, -1], out2)
    141       self.assertAllEqual([0, 1, -1], out3)
    142 
    143   def testHashTableWithTensorDefault(self):
    144     with self.test_session():
    145       default_val = constant_op.constant(-1, dtypes.int64)
    146       keys = constant_op.constant(["brain", "salad", "surgery"])
    147       values = constant_op.constant([0, 1, 2], dtypes.int64)
    148       table = lookup.HashTable(
    149           lookup.KeyValueTensorInitializer(keys, values), default_val)
    150       table.init.run()
    151 
    152       input_string = constant_op.constant(["brain", "salad", "tank"])
    153       output = table.lookup(input_string)
    154 
    155       result = output.eval()
    156       self.assertAllEqual([0, 1, -1], result)
    157 
    158   def testHashTableWithSparseTensorInput(self):
    159     with self.test_session() as sess:
    160       default_val = constant_op.constant(-1, dtypes.int64)
    161       keys = constant_op.constant(["brain", "salad", "surgery"])
    162       values = constant_op.constant([0, 1, 2], dtypes.int64)
    163       table = lookup.HashTable(
    164           lookup.KeyValueTensorInitializer(keys, values), default_val)
    165       table.init.run()
    166 
    167       sp_indices = [[0, 0], [0, 1], [1, 0]]
    168       sp_shape = [2, 2]
    169       input_tensor = sparse_tensor.SparseTensor(
    170           constant_op.constant(sp_indices, dtypes.int64),
    171           constant_op.constant(["brain", "salad", "tank"]),
    172           constant_op.constant(sp_shape, dtypes.int64))
    173       output = table.lookup(input_tensor)
    174 
    175       out_indices, out_values, out_shape = sess.run(output)
    176 
    177       self.assertAllEqual([0, 1, -1], out_values)
    178       self.assertAllEqual(sp_indices, out_indices)
    179       self.assertAllEqual(sp_shape, out_shape)
    180 
    181   def testSignatureMismatch(self):
    182     with self.test_session():
    183       default_val = -1
    184       keys = constant_op.constant(["brain", "salad", "surgery"])
    185       values = constant_op.constant([0, 1, 2], dtypes.int64)
    186       table = lookup.HashTable(
    187           lookup.KeyValueTensorInitializer(keys, values), default_val)
    188       table.init.run()
    189 
    190       # Ref types do not produce a lookup signature mismatch.
    191       input_string_ref = variables.Variable("brain")
    192       variables.global_variables_initializer().run()
    193       self.assertEqual(0, table.lookup(input_string_ref).eval())
    194 
    195       input_string = constant_op.constant([1, 2, 3], dtypes.int64)
    196       with self.assertRaises(TypeError):
    197         table.lookup(input_string)
    198 
    199       with self.assertRaises(TypeError):
    200         lookup.HashTable(
    201             lookup.KeyValueTensorInitializer(keys, values), "UNK")
    202 
    203   def testDTypes(self):
    204     with self.test_session():
    205       default_val = -1
    206       with self.assertRaises(TypeError):
    207         lookup.HashTable(
    208             lookup.KeyValueTensorInitializer(["a"], [1], [dtypes.string],
    209                                              dtypes.int64), default_val)
    210 
    211   def testNotInitialized(self):
    212     with self.test_session():
    213       default_val = -1
    214       table = lookup.HashTable(
    215           lookup.KeyValueTensorInitializer(
    216               ["a"], [1], value_dtype=dtypes.int64),
    217           default_val)
    218 
    219       input_string = constant_op.constant(["brain", "salad", "surgery"])
    220       output = table.lookup(input_string)
    221 
    222       with self.assertRaisesOpError("Table not initialized"):
    223         output.eval()
    224 
    225   def testInitializeTwice(self):
    226     with self.test_session():
    227       default_val = -1
    228       keys = constant_op.constant(["brain", "salad", "surgery"])
    229       values = constant_op.constant([0, 1, 2], dtypes.int64)
    230       table = lookup.HashTable(
    231           lookup.KeyValueTensorInitializer(keys, values), default_val)
    232       table.init.run()
    233 
    234       with self.assertRaisesOpError("Table already initialized"):
    235         table.init.run()
    236 
    237   def testInitializationWithInvalidDimensions(self):
    238     with self.test_session():
    239       default_val = -1
    240       keys = constant_op.constant(["brain", "salad", "surgery"])
    241       values = constant_op.constant([0, 1, 2, 3, 4], dtypes.int64)
    242 
    243       with self.assertRaises(ValueError):
    244         lookup.HashTable(
    245             lookup.KeyValueTensorInitializer(keys, values), default_val)
    246 
    247   def testMultipleSessions(self):
    248     # Start a server
    249     server = server_lib.Server(
    250         {
    251             "local0": ["localhost:0"]
    252         }, protocol="grpc", start=True)
    253     # Create two sessions sharing the same state
    254     session1 = session.Session(server.target)
    255     session2 = session.Session(server.target)
    256 
    257     default_val = -1
    258     keys = constant_op.constant(["brain", "salad", "surgery"])
    259     values = constant_op.constant([0, 1, 2], dtypes.int64)
    260     table = lookup.HashTable(
    261         lookup.KeyValueTensorInitializer(keys, values),
    262         default_val,
    263         name="t1")
    264 
    265     # Init the table in the first session.
    266     with session1:
    267       table.init.run()
    268       self.assertAllEqual(3, table.size().eval())
    269 
    270     # Init the table in the second session and verify that we do not get a
    271     # "Table already initialized" error.
    272     with session2:
    273       table.init.run()
    274       self.assertAllEqual(3, table.size().eval())
    275 
    276 
    277 class MutableHashTableOpTest(test.TestCase):
    278 
    279   def testMutableHashTable(self):
    280     with self.test_session():
    281       default_val = -1
    282       keys = constant_op.constant(["brain", "salad", "surgery"])
    283       values = constant_op.constant([0, 1, 2], dtypes.int64)
    284       table = lookup.MutableHashTable(dtypes.string, dtypes.int64,
    285                                       default_val)
    286       self.assertAllEqual(0, table.size().eval())
    287 
    288       table.insert(keys, values).run()
    289       self.assertAllEqual(3, table.size().eval())
    290 
    291       input_string = constant_op.constant(["brain", "salad", "tank"])
    292       output = table.lookup(input_string)
    293       self.assertAllEqual([3], output.get_shape())
    294 
    295       result = output.eval()
    296       self.assertAllEqual([0, 1, -1], result)
    297 
    298       exported_keys, exported_values = table.export()
    299       self.assertAllEqual([None], exported_keys.get_shape().as_list())
    300       self.assertAllEqual([None], exported_values.get_shape().as_list())
    301 
    302       # exported data is in the order of the internal map, i.e. undefined
    303       sorted_keys = np.sort(exported_keys.eval())
    304       sorted_values = np.sort(exported_values.eval())
    305       self.assertAllEqual([b"brain", b"salad", b"surgery"], sorted_keys)
    306       self.assertAllEqual([0, 1, 2], sorted_values)
    307 
    308   def testSaveRestore(self):
    309     save_dir = os.path.join(self.get_temp_dir(), "save_restore")
    310     save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
    311 
    312     with self.test_session(graph=ops.Graph()) as sess:
    313       v0 = variables.Variable(10.0, name="v0")
    314       v1 = variables.Variable(20.0, name="v1")
    315 
    316       default_val = -1
    317       keys = constant_op.constant(["b", "c", "d"], dtypes.string)
    318       values = constant_op.constant([0, 1, 2], dtypes.int64)
    319       table = lookup.MutableHashTable(
    320           dtypes.string, dtypes.int64, default_val, name="t1", checkpoint=True)
    321 
    322       save = saver.Saver()
    323       variables.global_variables_initializer().run()
    324 
    325       # Check that the parameter nodes have been initialized.
    326       self.assertEqual(10.0, v0.eval())
    327       self.assertEqual(20.0, v1.eval())
    328 
    329       self.assertAllEqual(0, table.size().eval())
    330       table.insert(keys, values).run()
    331       self.assertAllEqual(3, table.size().eval())
    332 
    333       val = save.save(sess, save_path)
    334       self.assertTrue(isinstance(val, six.string_types))
    335       self.assertEqual(save_path, val)
    336 
    337     with self.test_session(graph=ops.Graph()) as sess:
    338       v0 = variables.Variable(-1.0, name="v0")
    339       v1 = variables.Variable(-1.0, name="v1")
    340       default_val = -1
    341       table = lookup.MutableHashTable(
    342           dtypes.string, dtypes.int64, default_val, name="t1", checkpoint=True)
    343       table.insert(
    344           constant_op.constant(["a", "c"], dtypes.string),
    345           constant_op.constant([12, 24], dtypes.int64)).run()
    346       self.assertAllEqual(2, table.size().eval())
    347 
    348       save = saver.Saver()
    349 
    350       # Restore the saved values in the parameter nodes.
    351       save.restore(sess, save_path)
    352       # Check that the parameter nodes have been restored.
    353       self.assertEqual(10.0, v0.eval())
    354       self.assertEqual(20.0, v1.eval())
    355 
    356       self.assertAllEqual(3, table.size().eval())
    357 
    358       input_string = constant_op.constant(["a", "b", "c", "d", "e"],
    359                                           dtypes.string)
    360       output = table.lookup(input_string)
    361       self.assertAllEqual([-1, 0, 1, 2, -1], output.eval())
    362 
    363   def testSharing(self):
    364     # Start a server to store the table state
    365     server = server_lib.Server(
    366         {
    367             "local0": ["localhost:0"]
    368         }, protocol="grpc", start=True)
    369     # Create two sessions sharing the same state
    370     session1 = session.Session(server.target)
    371     session2 = session.Session(server.target)
    372 
    373     table = lookup.MutableHashTable(
    374         dtypes.int64, dtypes.string, "-", name="t1")
    375 
    376     # Populate the table in the first session
    377     with session1:
    378       self.assertAllEqual(0, table.size().eval())
    379 
    380       keys = constant_op.constant([11, 12], dtypes.int64)
    381       values = constant_op.constant(["a", "b"])
    382       table.insert(keys, values).run()
    383       self.assertAllEqual(2, table.size().eval())
    384 
    385       output = table.lookup(constant_op.constant([11, 12, 13], dtypes.int64))
    386       self.assertAllEqual([b"a", b"b", b"-"], output.eval())
    387 
    388     # Verify that we can access the shared data from the second session
    389     with session2:
    390       self.assertAllEqual(2, table.size().eval())
    391 
    392       output = table.lookup(constant_op.constant([10, 11, 12], dtypes.int64))
    393       self.assertAllEqual([b"-", b"a", b"b"], output.eval())
    394 
    395   def testMutableHashTableOfTensors(self):
    396     with self.test_session():
    397       default_val = constant_op.constant([-1, -1], dtypes.int64)
    398       keys = constant_op.constant(["brain", "salad", "surgery"])
    399       values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64)
    400       table = lookup.MutableHashTable(dtypes.string, dtypes.int64,
    401                                       default_val)
    402       self.assertAllEqual(0, table.size().eval())
    403 
    404       table.insert(keys, values).run()
    405       self.assertAllEqual(3, table.size().eval())
    406 
    407       input_string = constant_op.constant(["brain", "salad", "tank"])
    408       output = table.lookup(input_string)
    409       self.assertAllEqual([3, 2], output.get_shape())
    410 
    411       result = output.eval()
    412       self.assertAllEqual([[0, 1], [2, 3], [-1, -1]], result)
    413 
    414       exported_keys, exported_values = table.export()
    415       self.assertAllEqual([None], exported_keys.get_shape().as_list())
    416       self.assertAllEqual([None, 2], exported_values.get_shape().as_list())
    417       # exported data is in the order of the internal map, i.e. undefined
    418       sorted_keys = np.sort(exported_keys.eval())
    419       sorted_values = np.sort(exported_values.eval())
    420       self.assertAllEqual([b"brain", b"salad", b"surgery"], sorted_keys)
    421       self.assertAllEqual([[4, 5], [2, 3], [0, 1]], sorted_values)
    422 
    423   def testMutableHashTableExportInsert(self):
    424     with self.test_session():
    425       default_val = constant_op.constant([-1, -1], dtypes.int64)
    426       keys = constant_op.constant(["brain", "salad", "surgery"])
    427       values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64)
    428       table1 = lookup.MutableHashTable(dtypes.string, dtypes.int64,
    429                                        default_val)
    430       self.assertAllEqual(0, table1.size().eval())
    431       table1.insert(keys, values).run()
    432       self.assertAllEqual(3, table1.size().eval())
    433 
    434       input_string = constant_op.constant(["brain", "salad", "tank"])
    435       expected_output = [[0, 1], [2, 3], [-1, -1]]
    436       output1 = table1.lookup(input_string)
    437       self.assertAllEqual(expected_output, output1.eval())
    438 
    439       exported_keys, exported_values = table1.export()
    440       self.assertAllEqual(3, exported_keys.eval().size)
    441       self.assertAllEqual(6, exported_values.eval().size)
    442 
    443       # Populate a second table from the exported data
    444       table2 = lookup.MutableHashTable(dtypes.string, dtypes.int64,
    445                                        default_val)
    446       self.assertAllEqual(0, table2.size().eval())
    447       table2.insert(exported_keys, exported_values).run()
    448       self.assertAllEqual(3, table2.size().eval())
    449 
    450       # Verify lookup result is still the same
    451       output2 = table2.lookup(input_string)
    452       self.assertAllEqual(expected_output, output2.eval())
    453 
    454   def testMutableHashTableOfTensorsInvalidShape(self):
    455     with self.test_session():
    456       default_val = constant_op.constant([-1, -1], dtypes.int64)
    457       keys = constant_op.constant(["brain", "salad", "surgery"])
    458       table = lookup.MutableHashTable(dtypes.string, dtypes.int64,
    459                                       default_val)
    460 
    461       # Shape [6] instead of [3, 2]
    462       values = constant_op.constant([0, 1, 2, 3, 4, 5], dtypes.int64)
    463       with self.assertRaisesOpError("Expected shape"):
    464         table.insert(keys, values).run()
    465 
    466       # Shape [2,3] instead of [3, 2]
    467       values = constant_op.constant([[0, 1, 2], [3, 4, 5]], dtypes.int64)
    468       with self.assertRaisesOpError("Expected shape"):
    469         table.insert(keys, values).run()
    470 
    471       # Shape [2, 2] instead of [3, 2]
    472       values = constant_op.constant([[0, 1], [2, 3]], dtypes.int64)
    473       with self.assertRaisesOpError("Expected shape"):
    474         table.insert(keys, values).run()
    475 
    476       # Shape [3, 1] instead of [3, 2]
    477       values = constant_op.constant([[0], [2], [4]], dtypes.int64)
    478       with self.assertRaisesOpError("Expected shape"):
    479         table.insert(keys, values).run()
    480 
    481       # Valid Insert
    482       values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64)
    483       table.insert(keys, values).run()
    484       self.assertAllEqual(3, table.size().eval())
    485 
    486   def testMutableHashTableInvalidDefaultValue(self):
    487     with self.test_session():
    488       default_val = constant_op.constant([[-1, -1]], dtypes.int64)
    489       table = lookup.MutableHashTable(dtypes.string, dtypes.int64,
    490                                       default_val)
    491       with self.assertRaisesOpError("Default value must be a vector"):
    492         self.assertAllEqual(0, table.size().eval())
    493 
    494   def testMutableHashTableDuplicateInsert(self):
    495     with self.test_session():
    496       default_val = -1
    497       keys = constant_op.constant(["brain", "salad", "surgery", "brain"])
    498       values = constant_op.constant([0, 1, 2, 3], dtypes.int64)
    499       table = lookup.MutableHashTable(dtypes.string, dtypes.int64,
    500                                       default_val)
    501       self.assertAllEqual(0, table.size().eval())
    502 
    503       table.insert(keys, values).run()
    504       self.assertAllEqual(3, table.size().eval())
    505 
    506       input_string = constant_op.constant(["brain", "salad", "tank"])
    507       output = table.lookup(input_string)
    508 
    509       result = output.eval()
    510       self.assertAllEqual([3, 1, -1], result)
    511 
    512   def testMutableHashTableFindHighRank(self):
    513     with self.test_session():
    514       default_val = -1
    515       keys = constant_op.constant(["brain", "salad", "surgery"])
    516       values = constant_op.constant([0, 1, 2], dtypes.int64)
    517       table = lookup.MutableHashTable(dtypes.string, dtypes.int64,
    518                                       default_val)
    519 
    520       table.insert(keys, values).run()
    521       self.assertAllEqual(3, table.size().eval())
    522 
    523       input_string = constant_op.constant(
    524           [["brain", "salad"], ["tank", "tarkus"]])
    525       output = table.lookup(input_string)
    526       self.assertAllEqual([2, 2], output.get_shape())
    527 
    528       result = output.eval()
    529       self.assertAllEqual([[0, 1], [-1, -1]], result)
    530 
    531   def testMutableHashTableInsertHighRank(self):
    532     with self.test_session():
    533       default_val = -1
    534       keys = constant_op.constant([["brain", "salad"], ["surgery", "tank"]])
    535       values = constant_op.constant([[0, 1], [2, 3]], dtypes.int64)
    536       table = lookup.MutableHashTable(dtypes.string, dtypes.int64,
    537                                       default_val)
    538 
    539       table.insert(keys, values).run()
    540       self.assertAllEqual(4, table.size().eval())
    541 
    542       input_string = constant_op.constant(["brain", "salad", "tank", "tarkus"])
    543       output = table.lookup(input_string)
    544 
    545       result = output.eval()
    546       self.assertAllEqual([0, 1, 3, -1], result)
    547 
    548   def testMutableHashTableOfTensorsFindHighRank(self):
    549     with self.test_session():
    550       default_val = constant_op.constant([-1, -1, -1], dtypes.int64)
    551       keys = constant_op.constant(["brain", "salad", "surgery"])
    552       values = constant_op.constant([[0, 1, 2], [2, 3, 4], [4, 5, 6]],
    553                                     dtypes.int64)
    554       table = lookup.MutableHashTable(dtypes.string, dtypes.int64,
    555                                       default_val)
    556 
    557       table.insert(keys, values).run()
    558       self.assertAllEqual(3, table.size().eval())
    559 
    560       input_string = constant_op.constant(
    561           [["brain", "salad"], ["tank", "tarkus"]])
    562       output = table.lookup(input_string)
    563       self.assertAllEqual([2, 2, 3], output.get_shape())
    564 
    565       result = output.eval()
    566       self.assertAllEqual(
    567           [[[0, 1, 2], [2, 3, 4]], [[-1, -1, -1], [-1, -1, -1]]], result)
    568 
    569   def testMultipleMutableHashTables(self):
    570     with self.test_session() as sess:
    571       default_val = -1
    572       keys = constant_op.constant(["brain", "salad", "surgery"])
    573       values = constant_op.constant([0, 1, 2], dtypes.int64)
    574 
    575       table1 = lookup.MutableHashTable(dtypes.string, dtypes.int64,
    576                                        default_val)
    577       table2 = lookup.MutableHashTable(dtypes.string, dtypes.int64,
    578                                        default_val)
    579       table3 = lookup.MutableHashTable(dtypes.string, dtypes.int64,
    580                                        default_val)
    581       table1.insert(keys, values).run()
    582       table2.insert(keys, values).run()
    583       table3.insert(keys, values).run()
    584 
    585       self.assertAllEqual(3, table1.size().eval())
    586       self.assertAllEqual(3, table2.size().eval())
    587       self.assertAllEqual(3, table3.size().eval())
    588 
    589       input_string = constant_op.constant(["brain", "salad", "tank"])
    590       output1 = table1.lookup(input_string)
    591       output2 = table2.lookup(input_string)
    592       output3 = table3.lookup(input_string)
    593 
    594       out1, out2, out3 = sess.run([output1, output2, output3])
    595       self.assertAllEqual([0, 1, -1], out1)
    596       self.assertAllEqual([0, 1, -1], out2)
    597       self.assertAllEqual([0, 1, -1], out3)
    598 
    599   def testMutableHashTableWithTensorDefault(self):
    600     with self.test_session():
    601       default_val = constant_op.constant(-1, dtypes.int64)
    602       keys = constant_op.constant(["brain", "salad", "surgery"])
    603       values = constant_op.constant([0, 1, 2], dtypes.int64)
    604       table = lookup.MutableHashTable(dtypes.string, dtypes.int64,
    605                                       default_val)
    606 
    607       table.insert(keys, values).run()
    608       self.assertAllEqual(3, table.size().eval())
    609 
    610       input_string = constant_op.constant(["brain", "salad", "tank"])
    611       output = table.lookup(input_string)
    612 
    613       result = output.eval()
    614       self.assertAllEqual([0, 1, -1], result)
    615 
    616   def testSignatureMismatch(self):
    617     with self.test_session():
    618       default_val = -1
    619       keys = constant_op.constant(["brain", "salad", "surgery"])
    620       values = constant_op.constant([0, 1, 2], dtypes.int64)
    621       table = lookup.MutableHashTable(dtypes.string, dtypes.int64,
    622                                       default_val)
    623 
    624       # insert with keys of the wrong type
    625       with self.assertRaises(TypeError):
    626         table.insert(constant_op.constant([4, 5, 6]), values).run()
    627 
    628       # insert with values of the wrong type
    629       with self.assertRaises(TypeError):
    630         table.insert(keys, constant_op.constant(["a", "b", "c"])).run()
    631 
    632       self.assertAllEqual(0, table.size().eval())
    633 
    634       table.insert(keys, values).run()
    635       self.assertAllEqual(3, table.size().eval())
    636 
    637       input_string_ref = variables.Variable("brain")
    638       input_int64_ref = variables.Variable(-1, dtype=dtypes.int64)
    639       variables.global_variables_initializer().run()
    640 
    641       # Ref types do not produce an insert signature mismatch.
    642       table.insert(input_string_ref, input_int64_ref).run()
    643       self.assertAllEqual(3, table.size().eval())
    644 
    645       # Ref types do not produce a lookup signature mismatch.
    646       self.assertEqual(-1, table.lookup(input_string_ref).eval())
    647 
    648       # lookup with keys of the wrong type
    649       input_string = constant_op.constant([1, 2, 3], dtypes.int64)
    650       with self.assertRaises(TypeError):
    651         table.lookup(input_string).eval()
    652 
    653       # default value of the wrong type
    654       with self.assertRaises(TypeError):
    655         lookup.MutableHashTable(dtypes.string, dtypes.int64, "UNK")
    656 
    657   def testMutableHashTableStringFloat(self):
    658     with self.test_session():
    659       default_val = -1.5
    660       keys = constant_op.constant(["brain", "salad", "surgery"])
    661       values = constant_op.constant([0, 1.1, 2.2], dtypes.float32)
    662       table = lookup.MutableHashTable(dtypes.string, dtypes.float32,
    663                                       default_val)
    664       self.assertAllEqual(0, table.size().eval())
    665 
    666       table.insert(keys, values).run()
    667       self.assertAllEqual(3, table.size().eval())
    668 
    669       input_string = constant_op.constant(["brain", "salad", "tank"])
    670       output = table.lookup(input_string)
    671 
    672       result = output.eval()
    673       self.assertAllClose([0, 1.1, default_val], result)
    674 
    675   def testMutableHashTableIntFloat(self):
    676     with self.test_session():
    677       default_val = -1.0
    678       keys = constant_op.constant([3, 7, 0], dtypes.int64)
    679       values = constant_op.constant([7.5, -1.2, 9.9], dtypes.float32)
    680       table = lookup.MutableHashTable(dtypes.int64, dtypes.float32,
    681                                       default_val)
    682       self.assertAllEqual(0, table.size().eval())
    683 
    684       table.insert(keys, values).run()
    685       self.assertAllEqual(3, table.size().eval())
    686 
    687       input_string = constant_op.constant([7, 0, 11], dtypes.int64)
    688       output = table.lookup(input_string)
    689 
    690       result = output.eval()
    691       self.assertAllClose([-1.2, 9.9, default_val], result)
    692 
    693   def testMutableHashTableInt64String(self):
    694     with self.test_session():
    695       default_val = "n/a"
    696       keys = constant_op.constant([0, 1, 2], dtypes.int64)
    697       values = constant_op.constant(["brain", "salad", "surgery"])
    698       table = lookup.MutableHashTable(dtypes.int64, dtypes.string,
    699                                       default_val)
    700       self.assertAllEqual(0, table.size().eval())
    701 
    702       table.insert(keys, values).run()
    703       self.assertAllEqual(3, table.size().eval())
    704 
    705       input_string = constant_op.constant([0, 1, 3], dtypes.int64)
    706       output = table.lookup(input_string)
    707 
    708       result = output.eval()
    709       self.assertAllEqual((b"brain", b"salad", b"n/a"), result)
    710 
    711 
    712 class MutableDenseHashTableOpTest(test.TestCase):
    713 
    714   def testBasic(self):
    715     with self.test_session():
    716       keys = constant_op.constant([11, 12, 13], dtypes.int64)
    717       values = constant_op.constant([0, 1, 2], dtypes.int64)
    718       table = lookup.MutableDenseHashTable(
    719           dtypes.int64, dtypes.int64, default_value=-1, empty_key=0)
    720       self.assertAllEqual(0, table.size().eval())
    721 
    722       table.insert(keys, values).run()
    723       self.assertAllEqual(3, table.size().eval())
    724 
    725       input_string = constant_op.constant([11, 12, 15], dtypes.int64)
    726       output = table.lookup(input_string)
    727       self.assertAllEqual([3], output.get_shape())
    728 
    729       result = output.eval()
    730       self.assertAllEqual([0, 1, -1], result)
    731 
    732   def testBasicBool(self):
    733     with self.test_session():
    734       keys = constant_op.constant([11, 12, 13], dtypes.int64)
    735       values = constant_op.constant([True, True, True], dtypes.bool)
    736       table = lookup.MutableDenseHashTable(
    737           dtypes.int64, dtypes.bool, default_value=False, empty_key=0)
    738       self.assertAllEqual(0, table.size().eval())
    739 
    740       table.insert(keys, values).run()
    741       self.assertAllEqual(3, table.size().eval())
    742 
    743       input_string = constant_op.constant([11, 12, 15], dtypes.int64)
    744       output = table.lookup(input_string)
    745       self.assertAllEqual([3], output.get_shape())
    746 
    747       result = output.eval()
    748       self.assertAllEqual([True, True, False], result)
    749 
    750   def testLookupUnknownShape(self):
    751     with self.test_session():
    752       keys = constant_op.constant([11, 12, 13], dtypes.int64)
    753       values = constant_op.constant([0, 1, 2], dtypes.int64)
    754       table = lookup.MutableDenseHashTable(
    755           dtypes.int64, dtypes.int64, default_value=-1, empty_key=0)
    756 
    757       table.insert(keys, values).run()
    758       self.assertAllEqual(3, table.size().eval())
    759 
    760       placeholder_keys = array_ops.placeholder(dtypes.int64)
    761       output = table.lookup(placeholder_keys)
    762       self.assertAllEqual(None, output.get_shape())
    763       result = output.eval({placeholder_keys: [11, 12, 15]})
    764       self.assertAllEqual([0, 1, -1], result)
    765 
    766   def testMapStringToFloat(self):
    767     with self.test_session():
    768       keys = constant_op.constant(["a", "b", "c"], dtypes.string)
    769       values = constant_op.constant([0.0, 1.1, 2.2], dtypes.float32)
    770       default_value = constant_op.constant(-1.5, dtypes.float32)
    771       table = lookup.MutableDenseHashTable(
    772           dtypes.string,
    773           dtypes.float32,
    774           default_value=default_value,
    775           empty_key="")
    776       self.assertAllEqual(0, table.size().eval())
    777 
    778       table.insert(keys, values).run()
    779       self.assertAllEqual(3, table.size().eval())
    780 
    781       input_string = constant_op.constant(["a", "b", "d"], dtypes.string)
    782       output = table.lookup(input_string)
    783       self.assertAllEqual([3], output.get_shape())
    784 
    785       result = output.eval()
    786       self.assertAllClose([0, 1.1, -1.5], result)
    787 
    788   def testMapInt64ToFloat(self):
    789     for float_dtype in [dtypes.float32, dtypes.float64]:
    790       with self.test_session():
    791         keys = constant_op.constant([11, 12, 13], dtypes.int64)
    792         values = constant_op.constant([0.0, 1.1, 2.2], float_dtype)
    793         default_value = constant_op.constant(-1.5, float_dtype)
    794         table = lookup.MutableDenseHashTable(
    795             dtypes.int64, float_dtype, default_value=default_value, empty_key=0)
    796         self.assertAllEqual(0, table.size().eval())
    797 
    798         table.insert(keys, values).run()
    799         self.assertAllEqual(3, table.size().eval())
    800 
    801         input_string = constant_op.constant([11, 12, 15], dtypes.int64)
    802         output = table.lookup(input_string)
    803         self.assertAllEqual([3], output.get_shape())
    804 
    805         result = output.eval()
    806         self.assertAllClose([0, 1.1, -1.5], result)
    807 
    808   def testVectorValues(self):
    809     with self.test_session():
    810       keys = constant_op.constant([11, 12, 13], dtypes.int64)
    811       values = constant_op.constant([[0, 1, 2, 3], [3, 4, 5, 6], [6, 7, 8, 9]],
    812                                     dtypes.int64)
    813       default_value = constant_op.constant([-1, -2, -3, -4], dtypes.int64)
    814       table = lookup.MutableDenseHashTable(
    815           dtypes.int64,
    816           dtypes.int64,
    817           default_value=default_value,
    818           empty_key=0,
    819           initial_num_buckets=4)
    820       self.assertAllEqual(0, table.size().eval())
    821 
    822       table.insert(keys, values).run()
    823       self.assertAllEqual(3, table.size().eval())
    824       self.assertAllEqual(4, len(table.export()[0].eval()))
    825 
    826       table.insert(
    827           constant_op.constant([14], dtypes.int64),
    828           constant_op.constant([[2, 3, 4, 5]], dtypes.int64)).run()
    829       self.assertAllEqual(4, table.size().eval())
    830       self.assertAllEqual(8, len(table.export()[0].eval()))
    831 
    832       input_string = constant_op.constant([11, 12, 15], dtypes.int64)
    833       output = table.lookup(input_string)
    834       self.assertAllEqual([3, 4], output.get_shape())
    835 
    836       result = output.eval()
    837       self.assertAllEqual([[0, 1, 2, 3], [3, 4, 5, 6], [-1, -2, -3, -4]],
    838                           result)
    839 
    840   def testVectorKeys(self):
    841     with self.test_session():
    842       keys = constant_op.constant([[0, 1], [1, 2], [1, 3]], dtypes.int64)
    843       values = constant_op.constant([10, 11, 12], dtypes.int64)
    844       empty_key = constant_op.constant([0, 3], dtypes.int64)
    845       default_value = constant_op.constant(-1, dtypes.int64)
    846       table = lookup.MutableDenseHashTable(
    847           dtypes.int64,
    848           dtypes.int64,
    849           default_value=default_value,
    850           empty_key=empty_key,
    851           initial_num_buckets=8)
    852       self.assertAllEqual(0, table.size().eval())
    853 
    854       table.insert(keys, values).run()
    855       self.assertAllEqual(3, table.size().eval())
    856 
    857       table.insert(
    858           constant_op.constant([[0, 0]], dtypes.int64),
    859           constant_op.constant([13], dtypes.int64)).run()
    860       self.assertAllEqual(4, table.size().eval())
    861       self.assertAllEqual(8, len(table.export()[0].eval()))
    862 
    863       input_string = constant_op.constant([[0, 1], [1, 2], [0, 2]],
    864                                           dtypes.int64)
    865       output = table.lookup(input_string)
    866       self.assertAllEqual([3], output.get_shape())
    867 
    868       result = output.eval()
    869       self.assertAllEqual([10, 11, -1], result)
    870 
    871   def testResize(self):
    872     with self.test_session():
    873       keys = constant_op.constant([11, 12, 13], dtypes.int64)
    874       values = constant_op.constant([0, 1, 2], dtypes.int64)
    875       table = lookup.MutableDenseHashTable(
    876           dtypes.int64,
    877           dtypes.int64,
    878           default_value=-1,
    879           empty_key=0,
    880           initial_num_buckets=4)
    881       self.assertAllEqual(0, table.size().eval())
    882 
    883       table.insert(keys, values).run()
    884       self.assertAllEqual(3, table.size().eval())
    885       self.assertAllEqual(4, len(table.export()[0].eval()))
    886 
    887       keys2 = constant_op.constant([13, 14, 15, 16, 17], dtypes.int64)
    888       values2 = constant_op.constant([3, 4, 5, 6, 7], dtypes.int64)
    889 
    890       table.insert(keys2, values2).run()
    891       self.assertAllEqual(7, table.size().eval())
    892       self.assertAllEqual(16, len(table.export()[0].eval()))
    893 
    894       keys3 = constant_op.constant([10, 11, 12, 13, 14, 15, 16, 17, 18],
    895                                    dtypes.int64)
    896       output = table.lookup(keys3)
    897       self.assertAllEqual([-1, 0, 1, 3, 4, 5, 6, 7, -1], output.eval())
    898 
    899   def testExport(self):
    900     with self.test_session():
    901       keys = constant_op.constant([11, 12, 13], dtypes.int64)
    902       values = constant_op.constant([1, 2, 3], dtypes.int64)
    903       table = lookup.MutableDenseHashTable(
    904           dtypes.int64,
    905           dtypes.int64,
    906           default_value=-1,
    907           empty_key=100,
    908           initial_num_buckets=8)
    909       self.assertAllEqual(0, table.size().eval())
    910 
    911       table.insert(keys, values).run()
    912       self.assertAllEqual(3, table.size().eval())
    913 
    914       exported_keys, exported_values = table.export()
    915       self.assertAllEqual([None], exported_keys.get_shape().as_list())
    916       self.assertAllEqual([None], exported_values.get_shape().as_list())
    917 
    918       np_keys = exported_keys.eval()
    919       np_values = exported_values.eval()
    920 
    921       self.assertAllEqual(8, len(np_keys))
    922       self.assertAllEqual(8, len(np_values))
    923 
    924       # pair up keys and values, drop extra added dimension
    925       pairs = np.dstack((np_keys.flatten(), np_values.flatten()))[0]
    926       # sort by key
    927       pairs = pairs[pairs[:, 0].argsort()]
    928       self.assertAllEqual([[11, 1], [12, 2], [13, 3], [100, 0], [100, 0],
    929                            [100, 0], [100, 0], [100, 0]], pairs)
    930 
    931   def testSaveRestore(self):
    932     save_dir = os.path.join(self.get_temp_dir(), "save_restore")
    933     save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
    934 
    935     with self.test_session(graph=ops.Graph()) as sess:
    936       default_value = -1
    937       empty_key = 0
    938       keys = constant_op.constant([11, 12, 13], dtypes.int64)
    939       values = constant_op.constant([0, 1, 2], dtypes.int64)
    940       table = lookup.MutableDenseHashTable(
    941           dtypes.int64,
    942           dtypes.int64,
    943           default_value=default_value,
    944           empty_key=empty_key,
    945           name="t1",
    946           checkpoint=True,
    947           initial_num_buckets=32)
    948 
    949       save = saver.Saver()
    950 
    951       self.assertAllEqual(0, table.size().eval())
    952       table.insert(keys, values).run()
    953       self.assertAllEqual(3, table.size().eval())
    954       self.assertAllEqual(32, len(table.export()[0].eval()))
    955 
    956       val = save.save(sess, save_path)
    957       self.assertTrue(isinstance(val, six.string_types))
    958       self.assertEqual(save_path, val)
    959 
    960     with self.test_session(graph=ops.Graph()) as sess:
    961       table = lookup.MutableDenseHashTable(
    962           dtypes.int64,
    963           dtypes.int64,
    964           default_value=default_value,
    965           empty_key=empty_key,
    966           name="t1",
    967           checkpoint=True,
    968           initial_num_buckets=64)
    969       table.insert(
    970           constant_op.constant([11, 14], dtypes.int64),
    971           constant_op.constant([12, 24], dtypes.int64)).run()
    972       self.assertAllEqual(2, table.size().eval())
    973       self.assertAllEqual(64, len(table.export()[0].eval()))
    974 
    975       save = saver.Saver()
    976 
    977       # Restore the saved values in the parameter nodes.
    978       save.restore(sess, save_path)
    979 
    980       self.assertAllEqual(3, table.size().eval())
    981       self.assertAllEqual(32, len(table.export()[0].eval()))
    982 
    983       input_string = constant_op.constant([10, 11, 12, 13, 14], dtypes.int64)
    984       output = table.lookup(input_string)
    985       self.assertAllEqual([-1, 0, 1, 2, -1], output.eval())
    986 
    987   def testVectorSaveRestore(self):
    988     save_dir = os.path.join(self.get_temp_dir(), "vector_save_restore")
    989     save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
    990 
    991     with self.test_session(graph=ops.Graph()) as sess:
    992       empty_key = constant_op.constant([11, 13], dtypes.int64)
    993       default_value = constant_op.constant([-1, -2], dtypes.int64)
    994       keys = constant_op.constant([[11, 12], [11, 14], [13, 14]], dtypes.int64)
    995       values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64)
    996       table = lookup.MutableDenseHashTable(
    997           dtypes.int64,
    998           dtypes.int64,
    999           default_value=default_value,
   1000           empty_key=empty_key,
   1001           name="t1",
   1002           checkpoint=True,
   1003           initial_num_buckets=32)
   1004 
   1005       save = saver.Saver()
   1006 
   1007       self.assertAllEqual(0, table.size().eval())
   1008       table.insert(keys, values).run()
   1009       self.assertAllEqual(3, table.size().eval())
   1010       self.assertAllEqual(32, len(table.export()[0].eval()))
   1011 
   1012       val = save.save(sess, save_path)
   1013       self.assertTrue(isinstance(val, six.string_types))
   1014       self.assertEqual(save_path, val)
   1015 
   1016     with self.test_session(graph=ops.Graph()) as sess:
   1017       empty_key = constant_op.constant([11, 13], dtypes.int64)
   1018       default_value = constant_op.constant([-1, -2], dtypes.int64)
   1019       table = lookup.MutableDenseHashTable(
   1020           dtypes.int64,
   1021           dtypes.int64,
   1022           default_value=default_value,
   1023           empty_key=empty_key,
   1024           name="t1",
   1025           checkpoint=True,
   1026           initial_num_buckets=64)
   1027       table.insert(
   1028           constant_op.constant([[11, 12], [13, 15]], dtypes.int64),
   1029           constant_op.constant([[21, 22], [23, 24]], dtypes.int64)).run()
   1030       self.assertAllEqual(2, table.size().eval())
   1031       self.assertAllEqual(64, len(table.export()[0].eval()))
   1032 
   1033       save = saver.Saver()
   1034 
   1035       # Restore the saved values in the parameter nodes.
   1036       save.restore(sess, save_path)
   1037 
   1038       self.assertAllEqual(3, table.size().eval())
   1039       self.assertAllEqual(32, len(table.export()[0].eval()))
   1040 
   1041       input_string = constant_op.constant(
   1042           [[11, 12], [11, 14], [11, 15], [13, 14], [13, 15]], dtypes.int64)
   1043       output = table.lookup(input_string)
   1044       self.assertAllEqual([[0, 1], [2, 3], [-1, -2], [4, 5], [-1, -2]],
   1045                           output.eval())
   1046 
   1047   def testVectorScalarSaveRestore(self):
   1048     save_dir = os.path.join(self.get_temp_dir(), "vector_scalar_save_restore")
   1049     save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
   1050 
   1051     with self.test_session(graph=ops.Graph()) as sess:
   1052       empty_key = constant_op.constant([11, 13], dtypes.int64)
   1053       default_value = constant_op.constant(-1, dtypes.int64)
   1054       keys = constant_op.constant([[11, 12], [11, 14], [13, 14]], dtypes.int64)
   1055       values = constant_op.constant([0, 1, 2], dtypes.int64)
   1056       table = lookup.MutableDenseHashTable(
   1057           dtypes.int64,
   1058           dtypes.int64,
   1059           default_value=default_value,
   1060           empty_key=empty_key,
   1061           name="t2",
   1062           checkpoint=True,
   1063           initial_num_buckets=32)
   1064 
   1065       save = saver.Saver()
   1066 
   1067       self.assertAllEqual(0, table.size().eval())
   1068       table.insert(keys, values).run()
   1069       self.assertAllEqual(3, table.size().eval())
   1070       self.assertAllEqual(32, len(table.export()[0].eval()))
   1071 
   1072       val = save.save(sess, save_path)
   1073       self.assertTrue(isinstance(val, six.string_types))
   1074       self.assertEqual(save_path, val)
   1075 
   1076     with self.test_session(graph=ops.Graph()) as sess:
   1077       empty_key = constant_op.constant([11, 13], dtypes.int64)
   1078       default_value = constant_op.constant(-1, dtypes.int64)
   1079       table = lookup.MutableDenseHashTable(
   1080           dtypes.int64,
   1081           dtypes.int64,
   1082           default_value=default_value,
   1083           empty_key=empty_key,
   1084           name="t2",
   1085           checkpoint=True,
   1086           initial_num_buckets=64)
   1087       table.insert(
   1088           constant_op.constant([[11, 12], [13, 15]], dtypes.int64),
   1089           constant_op.constant([3, 4], dtypes.int64)).run()
   1090       self.assertAllEqual(2, table.size().eval())
   1091       self.assertAllEqual(64, len(table.export()[0].eval()))
   1092 
   1093       save = saver.Saver()
   1094 
   1095       # Restore the saved values in the parameter nodes.
   1096       save.restore(sess, save_path)
   1097 
   1098       self.assertAllEqual(3, table.size().eval())
   1099       self.assertAllEqual(32, len(table.export()[0].eval()))
   1100 
   1101       input_string = constant_op.constant(
   1102           [[11, 12], [11, 14], [11, 15], [13, 14], [13, 15]], dtypes.int64)
   1103       output = table.lookup(input_string)
   1104       self.assertAllEqual([0, 1, -1, 2, -1], output.eval())
   1105 
   1106   def testReprobe(self):
   1107     with self.test_session():
   1108       # Insert 6 keys into a table with 8 buckets.
   1109       # The values are chosen to make sure collisions occur when using GCC STL
   1110       keys = constant_op.constant([11, 12, 13, 19, 20, 21], dtypes.int64)
   1111       values = constant_op.constant([51, 52, 53, 54, 55, 56], dtypes.int64)
   1112       table = lookup.MutableDenseHashTable(
   1113           dtypes.int64,
   1114           dtypes.int64,
   1115           default_value=-1,
   1116           empty_key=0,
   1117           initial_num_buckets=8)
   1118       self.assertAllEqual(0, table.size().eval())
   1119 
   1120       table.insert(keys, values).run()
   1121       self.assertAllEqual(6, table.size().eval())
   1122 
   1123       input_string = constant_op.constant([10, 11, 12, 13, 14, 19, 20, 21, 22],
   1124                                           dtypes.int64)
   1125       output = table.lookup(input_string)
   1126       self.assertAllEqual([9], output.get_shape())
   1127 
   1128       result = output.eval()
   1129       self.assertAllEqual([-1, 51, 52, 53, -1, 54, 55, 56, -1], result)
   1130 
   1131   def testCustomEmptyKey(self):
   1132     with self.test_session():
   1133       keys = constant_op.constant([11, 0, 13], dtypes.int64)
   1134       values = constant_op.constant([0, 1, 2], dtypes.int64)
   1135       table = lookup.MutableDenseHashTable(
   1136           dtypes.int64, dtypes.int64, default_value=-1, empty_key=12)
   1137       self.assertAllEqual(0, table.size().eval())
   1138 
   1139       table.insert(keys, values).run()
   1140       self.assertAllEqual(3, table.size().eval())
   1141 
   1142       input_string = constant_op.constant([11, 0, 15], dtypes.int64)
   1143       output = table.lookup(input_string)
   1144       self.assertAllEqual([3], output.get_shape())
   1145 
   1146       result = output.eval()
   1147       self.assertAllEqual([0, 1, -1], result)
   1148 
   1149   def testErrors(self):
   1150     with self.test_session():
   1151       table = lookup.MutableDenseHashTable(
   1152           dtypes.int64, dtypes.int64, default_value=-1, empty_key=0)
   1153 
   1154       # Inserting the empty key returns an error
   1155       keys = constant_op.constant([11, 0], dtypes.int64)
   1156       values = constant_op.constant([0, 1], dtypes.int64)
   1157       with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
   1158                                    "empty_key"):
   1159         table.insert(keys, values).run()
   1160 
   1161       # Looking up the empty key returns an error
   1162       with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
   1163                                    "empty_key"):
   1164         table.lookup(keys).eval()
   1165 
   1166       # Arbitrary tensors of keys are not supported
   1167       keys = constant_op.constant([[11, 0], [12, 1]], dtypes.int64)
   1168       values = constant_op.constant([[11, 0], [12, 1]], dtypes.int64)
   1169       with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
   1170                                    "Expected key shape"):
   1171         table.lookup(keys).eval()
   1172       with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
   1173                                    "Expected key shape"):
   1174         table.insert(keys, values).run()
   1175 
   1176       table2 = lookup.MutableDenseHashTable(
   1177           dtypes.int64,
   1178           dtypes.int64,
   1179           default_value=-1,
   1180           empty_key=17,
   1181           initial_num_buckets=12)
   1182       with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
   1183                                    "Number of buckets must be"):
   1184         self.assertAllEqual(0, table2.size().eval())
   1185 
   1186 
   1187 class IndexTableFromFile(test.TestCase):
   1188 
   1189   def _createVocabFile(self, basename, values=("brain", "salad", "surgery")):
   1190     vocabulary_file = os.path.join(self.get_temp_dir(), basename)
   1191     with open(vocabulary_file, "w") as f:
   1192       f.write("\n".join(values) + "\n")
   1193     return vocabulary_file
   1194 
   1195   def test_string_index_table_from_file(self):
   1196     vocabulary_file = self._createVocabFile("f2i_vocab1.txt")
   1197     with self.test_session():
   1198       table = lookup.index_table_from_file(
   1199           vocabulary_file=vocabulary_file, num_oov_buckets=1)
   1200       ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
   1201 
   1202       self.assertRaises(errors_impl.OpError, ids.eval)
   1203       lookup_ops.tables_initializer().run()
   1204       self.assertAllEqual((1, 2, 3), ids.eval())
   1205 
   1206   def test_string_index_table_from_file_tensor_filename(self):
   1207     vocabulary_file = self._createVocabFile("f2i_vocab1.txt")
   1208     with self.test_session():
   1209       vocabulary_file = constant_op.constant(vocabulary_file)
   1210       table = lookup.index_table_from_file(
   1211           vocabulary_file=vocabulary_file, num_oov_buckets=1)
   1212       ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
   1213 
   1214       self.assertRaises(errors_impl.OpError, ids.eval)
   1215       lookup_ops.tables_initializer().run()
   1216       self.assertAllEqual((1, 2, 3), ids.eval())
   1217       self.assertEqual(1,
   1218                        len(ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS)))
   1219 
   1220   def test_string_index_table_from_file_placeholder_filename(self):
   1221     vocabulary_file = self._createVocabFile("f2i_vocab1.txt")
   1222     with self.test_session():
   1223       vocabulary_placeholder = array_ops.placeholder(dtypes.string, [])
   1224       table = lookup.index_table_from_file(
   1225           vocabulary_file=vocabulary_placeholder, num_oov_buckets=1)
   1226       ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
   1227 
   1228       self.assertRaises(errors_impl.OpError, ids.eval)
   1229 
   1230       feed_dict = {vocabulary_placeholder.name: vocabulary_file}
   1231       lookup_ops.tables_initializer().run(feed_dict=feed_dict)
   1232       self.assertAllEqual((1, 2, 3), ids.eval())
   1233       self.assertEqual(0,
   1234                        len(ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS)))
   1235 
   1236   def test_int32_index_table_from_file(self):
   1237     vocabulary_file = self._createVocabFile(
   1238         "f2i_vocab2.txt", values=("42", "1", "-1000"))
   1239     with self.test_session():
   1240       table = lookup.index_table_from_file(
   1241           vocabulary_file=vocabulary_file, num_oov_buckets=1,
   1242           key_dtype=dtypes.int32)
   1243       ids = table.lookup(
   1244           constant_op.constant((1, -1000, 11), dtype=dtypes.int32))
   1245 
   1246       self.assertRaises(errors_impl.OpError, ids.eval)
   1247       lookup_ops.tables_initializer().run()
   1248       self.assertAllEqual((1, 2, 3), ids.eval())
   1249 
   1250   def test_int64_index_table_from_file(self):
   1251     vocabulary_file = self._createVocabFile(
   1252         "f2i_vocab3.txt", values=("42", "1", "-1000"))
   1253     with self.test_session():
   1254       table = lookup.index_table_from_file(
   1255           vocabulary_file=vocabulary_file, num_oov_buckets=1,
   1256           key_dtype=dtypes.int64)
   1257       ids = table.lookup(
   1258           constant_op.constant((1, -1000, 11), dtype=dtypes.int64))
   1259 
   1260       self.assertRaises(errors_impl.OpError, ids.eval)
   1261       lookup_ops.tables_initializer().run()
   1262       self.assertAllEqual((1, 2, 3), ids.eval())
   1263 
   1264   def test_index_table_from_file_with_default_value(self):
   1265     default_value = -42
   1266     vocabulary_file = self._createVocabFile("f2i_vocab4.txt")
   1267     with self.test_session():
   1268       table = lookup.index_table_from_file(
   1269           vocabulary_file=vocabulary_file, default_value=default_value)
   1270       ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
   1271 
   1272       self.assertRaises(errors_impl.OpError, ids.eval)
   1273       lookup_ops.tables_initializer().run()
   1274       self.assertAllEqual((1, 2, default_value), ids.eval())
   1275 
   1276   def test_index_table_from_file_with_oov_buckets(self):
   1277     vocabulary_file = self._createVocabFile("f2i_vocab5.txt")
   1278     with self.test_session():
   1279       table = lookup.index_table_from_file(
   1280           vocabulary_file=vocabulary_file, num_oov_buckets=1000)
   1281       ids = table.lookup(
   1282           constant_op.constant(["salad", "surgery", "tarkus", "toccata"]))
   1283 
   1284       self.assertRaises(errors_impl.OpError, ids.eval)
   1285       lookup_ops.tables_initializer().run()
   1286       self.assertAllEqual(
   1287           (
   1288               1,  # From vocabulary file.
   1289               2,  # From vocabulary file.
   1290               867,  # 3 + fingerprint("tarkus") mod 300.
   1291               860),  # 3 + fingerprint("toccata") mod 300.
   1292           ids.eval())
   1293 
   1294   def test_index_table_from_file_fails_with_empty_vocabulary_file_name(self):
   1295     self.assertRaises(
   1296         ValueError,
   1297         lookup.index_table_from_file,
   1298         vocabulary_file="")
   1299 
   1300   def test_index_table_from_file_fails_with_empty_vocabulary(self):
   1301     self.assertRaises(
   1302         ValueError,
   1303         lookup.index_table_from_file,
   1304         vocabulary_file=None)
   1305 
   1306   def test_index_table_from_file_with_vocab_size_too_small(self):
   1307     vocabulary_file = self._createVocabFile("f2i_vocab6.txt")
   1308     with self.test_session():
   1309       table = lookup.index_table_from_file(
   1310           vocabulary_file=vocabulary_file, vocab_size=2)
   1311       ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
   1312 
   1313       self.assertRaises(errors_impl.OpError, ids.eval)
   1314       lookup_ops.tables_initializer().run()
   1315       self.assertAllEqual((1, -1, -1), ids.eval())
   1316       self.assertEqual(2, table.size().eval())
   1317 
   1318   def test_index_table_from_file_with_vocab_size_too_large(self):
   1319     vocabulary_file = self._createVocabFile("f2i_vocab7.txt")
   1320     with self.test_session():
   1321       table = lookup.index_table_from_file(
   1322           vocabulary_file=vocabulary_file, vocab_size=4)
   1323       self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
   1324                               "Invalid vocab_size", table.init.run)
   1325 
   1326   def test_index_table_from_file_with_vocab_size(self):
   1327     vocabulary_file = self._createVocabFile("f2i_vocab8.txt")
   1328 
   1329     self.assertRaises(
   1330         ValueError,
   1331         lookup.index_table_from_file,
   1332         vocabulary_file=vocabulary_file,
   1333         vocab_size=0)
   1334 
   1335     with self.test_session():
   1336       table = lookup.index_table_from_file(
   1337           vocabulary_file=vocabulary_file, vocab_size=3)
   1338       ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
   1339 
   1340       self.assertRaises(errors_impl.OpError, ids.eval)
   1341       lookup_ops.tables_initializer().run()
   1342       self.assertAllEqual((1, 2, -1), ids.eval())
   1343       self.assertEqual(3, table.size().eval())
   1344 
   1345   def test_index_table_from_file_with_invalid_hashers(self):
   1346     vocabulary_file = self._createVocabFile("invalid_hasher.txt")
   1347     with self.test_session():
   1348       with self.assertRaises(TypeError):
   1349         lookup.index_table_from_file(
   1350             vocabulary_file=vocabulary_file,
   1351             vocab_size=3,
   1352             num_oov_buckets=1,
   1353             hasher_spec=1)
   1354 
   1355       table = lookup.index_table_from_file(
   1356           vocabulary_file=vocabulary_file,
   1357           vocab_size=3,
   1358           num_oov_buckets=1,
   1359           hasher_spec=lookup.HasherSpec("my-awesome-hash", None))
   1360 
   1361       self.assertRaises(ValueError, table.lookup,
   1362                         constant_op.constant(["salad", "surgery", "tarkus"]))
   1363 
   1364 
   1365 class KeyValueTensorInitializerTest(test.TestCase):
   1366 
   1367   def test_string(self):
   1368     with ops.Graph().as_default(), self.test_session():
   1369       init = lookup.KeyValueTensorInitializer(
   1370           ("brain", "salad", "surgery"), (0, 1, 2), dtypes.string, dtypes.int64)
   1371       table = lookup.HashTable(init, default_value=-1)
   1372       table.init.run()
   1373 
   1374   def test_int64(self):
   1375     with ops.Graph().as_default(), self.test_session():
   1376       init = lookup.KeyValueTensorInitializer(
   1377           (42, 1, -1000), (0, 1, 2), dtypes.int64, dtypes.int64)
   1378       table = lookup.HashTable(init, default_value=-1)
   1379       table.init.run()
   1380 
   1381   def test_int32(self):
   1382     with ops.Graph().as_default(), self.test_session():
   1383       init = lookup.KeyValueTensorInitializer(
   1384           (42, 1, -1000), (0, 1, 2), dtypes.int32, dtypes.int64)
   1385       table = lookup.HashTable(init, default_value=-1)
   1386       with self.assertRaisesRegexp(
   1387           errors_impl.OpError, "No OpKernel was registered"):
   1388         table.init.run()
   1389 
   1390 
   1391 class IndexTableFromTensor(test.TestCase):
   1392 
   1393   def test_index_table_from_tensor_with_tensor_init(self):
   1394     with self.test_session():
   1395       table = lookup.index_table_from_tensor(
   1396           mapping=("brain", "salad", "surgery"), num_oov_buckets=1)
   1397       ids = table.lookup(constant_op.constant(("salad", "surgery", "tarkus")))
   1398 
   1399       self.assertRaises(errors_impl.OpError, ids.eval)
   1400       lookup_ops.tables_initializer().run()
   1401       self.assertAllEqual((1, 2, 3), ids.eval())
   1402 
   1403   def test_int32_index_table_from_tensor_with_tensor_init(self):
   1404     with self.test_session():
   1405       table = lookup.index_table_from_tensor(
   1406           mapping=(42, 1, -1000), num_oov_buckets=1, dtype=dtypes.int32)
   1407       ids = table.lookup(
   1408           constant_op.constant((1, -1000, 11), dtype=dtypes.int32))
   1409 
   1410       self.assertRaises(errors_impl.OpError, ids.eval)
   1411       lookup_ops.tables_initializer().run()
   1412       self.assertAllEqual((1, 2, 3), ids.eval())
   1413 
   1414   def test_int64_index_table_from_tensor_with_tensor_init(self):
   1415     with self.test_session():
   1416       table = lookup.index_table_from_tensor(
   1417           mapping=(42, 1, -1000), num_oov_buckets=1, dtype=dtypes.int64)
   1418       ids = table.lookup(
   1419           constant_op.constant((1, -1000, 11), dtype=dtypes.int64))
   1420 
   1421       self.assertRaises(errors_impl.OpError, ids.eval)
   1422       lookup_ops.tables_initializer().run()
   1423       self.assertAllEqual((1, 2, 3), ids.eval())
   1424 
   1425   def test_index_table_from_tensor_with_default_value(self):
   1426     default_value = -42
   1427     with self.test_session():
   1428       table = lookup.index_table_from_tensor(
   1429           mapping=["brain", "salad", "surgery"], default_value=default_value)
   1430       ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
   1431 
   1432       self.assertRaises(errors_impl.OpError, ids.eval)
   1433       lookup_ops.tables_initializer().run()
   1434       self.assertAllEqual((1, 2, default_value), ids.eval())
   1435 
   1436   def test_index_table_from_tensor_missing_mapping(self):
   1437     with self.test_session():
   1438       with self.assertRaisesRegexp(ValueError, "mapping must be specified"):
   1439         lookup.index_table_from_tensor(mapping=None, num_oov_buckets=1)
   1440 
   1441   def test_index_table_from_tensor_empty_mapping(self):
   1442     with self.test_session():
   1443       table = lookup.index_table_from_tensor(
   1444           mapping=np.array([], dtype=np.str_), num_oov_buckets=1)
   1445       ids = table.lookup(constant_op.constant(["salad", "surgery", "brain"]))
   1446       self.assertRaises(errors_impl.OpError, ids.eval)
   1447       with self.assertRaisesRegexp(
   1448           errors_impl.OpError, "keys and values cannot be empty"):
   1449         lookup_ops.tables_initializer().run()
   1450 
   1451   def test_index_table_from_tensor_with_invalid_hashers(self):
   1452     with self.test_session():
   1453       with self.assertRaises(TypeError):
   1454         lookup.index_table_from_tensor(
   1455             mapping=["brain", "salad", "surgery"],
   1456             num_oov_buckets=1,
   1457             hasher_spec=1)
   1458 
   1459       table = lookup.index_table_from_tensor(
   1460           mapping=["brain", "salad", "surgery"],
   1461           num_oov_buckets=1,
   1462           hasher_spec=lookup.HasherSpec("my-awesome-hash", None))
   1463 
   1464       self.assertRaises(ValueError, table.lookup,
   1465                         constant_op.constant(["salad", "surgery", "tarkus"]))
   1466 
   1467 
   1468 class StringToIndexTest(test.TestCase):
   1469 
   1470   def test_string_to_index(self):
   1471     with self.test_session():
   1472       mapping_strings = constant_op.constant(["brain", "salad", "surgery"])
   1473       feats = constant_op.constant(["salad", "surgery", "tarkus"])
   1474       indices = lookup.string_to_index(feats, mapping=mapping_strings)
   1475 
   1476       self.assertRaises(errors_impl.OpError, indices.eval)
   1477       lookup_ops.tables_initializer().run()
   1478 
   1479       self.assertAllEqual((1, 2, -1), indices.eval())
   1480 
   1481   def test_duplicate_entries(self):
   1482     with self.test_session():
   1483       mapping_strings = constant_op.constant(["hello", "hello"])
   1484       feats = constant_op.constant(["hello", "hola"])
   1485       _ = lookup.string_to_index(feats, mapping=mapping_strings)
   1486 
   1487       self.assertRaises(errors_impl.OpError,
   1488                         lookup_ops.tables_initializer().run)
   1489 
   1490   def test_string_to_index_with_default_value(self):
   1491     default_value = -42
   1492     with self.test_session():
   1493       mapping_strings = constant_op.constant(["brain", "salad", "surgery"])
   1494       feats = constant_op.constant(["salad", "surgery", "tarkus"])
   1495       indices = lookup.string_to_index(
   1496           feats, mapping=mapping_strings, default_value=default_value)
   1497       self.assertRaises(errors_impl.OpError, indices.eval)
   1498 
   1499       lookup_ops.tables_initializer().run()
   1500       self.assertAllEqual((1, 2, default_value), indices.eval())
   1501 
   1502 
   1503 class IndexToStringTableFromFileTest(test.TestCase):
   1504 
   1505   def _createVocabFile(self, basename):
   1506     vocabulary_file = os.path.join(self.get_temp_dir(), basename)
   1507     with open(vocabulary_file, "w") as f:
   1508       f.write("\n".join(["brain", "salad", "surgery"]) + "\n")
   1509     return vocabulary_file
   1510 
   1511   def test_index_to_string_table(self):
   1512     vocabulary_file = self._createVocabFile("i2f_vocab1.txt")
   1513     with self.test_session():
   1514       table = lookup.index_to_string_table_from_file(
   1515           vocabulary_file=vocabulary_file)
   1516       features = table.lookup(constant_op.constant([0, 1, 2, 3], dtypes.int64))
   1517       self.assertRaises(errors_impl.OpError, features.eval)
   1518       lookup_ops.tables_initializer().run()
   1519       self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"),
   1520                           features.eval())
   1521 
   1522   def test_index_to_string_table_with_default_value(self):
   1523     default_value = b"NONE"
   1524     vocabulary_file = self._createVocabFile("f2i_vocab2.txt")
   1525     with self.test_session():
   1526       table = lookup.index_to_string_table_from_file(
   1527           vocabulary_file=vocabulary_file, default_value=default_value)
   1528       features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
   1529       self.assertRaises(errors_impl.OpError, features.eval)
   1530       lookup_ops.tables_initializer().run()
   1531       self.assertAllEqual((b"salad", b"surgery", default_value),
   1532                           features.eval())
   1533 
   1534   def test_index_to_string_table_with_vocab_size_too_small(self):
   1535     default_value = b"NONE"
   1536     vocabulary_file = self._createVocabFile("f2i_vocab2.txt")
   1537     with self.test_session():
   1538       table = lookup.index_to_string_table_from_file(
   1539           vocabulary_file=vocabulary_file,
   1540           vocab_size=2,
   1541           default_value=default_value)
   1542       features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
   1543       self.assertRaises(errors_impl.OpError, features.eval)
   1544       lookup_ops.tables_initializer().run()
   1545       self.assertAllEqual((b"salad", default_value, default_value),
   1546                           features.eval())
   1547 
   1548   def test_index_to_string_table_with_vocab_size_too_large(self):
   1549     vocabulary_file = self._createVocabFile("f2i_vocab6.txt")
   1550     with self.test_session():
   1551       table = lookup.index_to_string_table_from_file(
   1552           vocabulary_file=vocabulary_file, vocab_size=4)
   1553       features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
   1554 
   1555       self.assertRaises(errors_impl.OpError, features.eval)
   1556       init = lookup_ops.tables_initializer()
   1557       self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
   1558                               "Invalid vocab_size", init.run)
   1559 
   1560   def test_index_to_string_table_with_vocab_size(self):
   1561     vocabulary_file = self._createVocabFile("f2i_vocab7.txt")
   1562     with self.test_session():
   1563       table = lookup.index_to_string_table_from_file(
   1564           vocabulary_file=vocabulary_file, vocab_size=3)
   1565       features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
   1566 
   1567       self.assertRaises(errors_impl.OpError, features.eval)
   1568       lookup_ops.tables_initializer().run()
   1569       self.assertAllEqual((b"salad", b"surgery", b"UNK"), features.eval())
   1570 
   1571 
   1572 class IndexToStringTableFromTensorTest(test.TestCase):
   1573 
   1574   def test_index_to_string_table_from_tensor(self):
   1575     with self.test_session():
   1576       mapping_strings = constant_op.constant(["brain", "salad", "surgery"])
   1577       table = lookup.index_to_string_table_from_tensor(
   1578           mapping=mapping_strings)
   1579 
   1580       indices = constant_op.constant([0, 1, 2, 3], dtypes.int64)
   1581       features = table.lookup(indices)
   1582       self.assertRaises(errors_impl.OpError, features.eval)
   1583       lookup_ops.tables_initializer().run()
   1584 
   1585       self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"),
   1586                           features.eval())
   1587 
   1588   def test_duplicate_entries(self):
   1589     with self.test_session():
   1590       mapping_strings = constant_op.constant(["hello", "hello"])
   1591       table = lookup.index_to_string_table_from_tensor(
   1592           mapping=mapping_strings)
   1593       indices = constant_op.constant([0, 1, 4], dtypes.int64)
   1594       features = table.lookup(indices)
   1595       lookup_ops.tables_initializer().run()
   1596       self.assertAllEqual((b"hello", b"hello", b"UNK"), features.eval())
   1597 
   1598   def test_index_to_string_with_default_value(self):
   1599     default_value = b"NONE"
   1600     with self.test_session():
   1601       mapping_strings = constant_op.constant(["brain", "salad", "surgery"])
   1602       table = lookup.index_to_string_table_from_tensor(
   1603           mapping=mapping_strings, default_value=default_value)
   1604       indices = constant_op.constant([1, 2, 4], dtypes.int64)
   1605       features = table.lookup(indices)
   1606       self.assertRaises(errors_impl.OpError, features.eval)
   1607 
   1608       lookup_ops.tables_initializer().run()
   1609       self.assertAllEqual((b"salad", b"surgery", default_value),
   1610                           features.eval())
   1611 
   1612 
   1613 class IndexToStringTest(test.TestCase):
   1614 
   1615   def test_index_to_string(self):
   1616     with self.test_session():
   1617       mapping_strings = constant_op.constant(["brain", "salad", "surgery"])
   1618       indices = constant_op.constant([0, 1, 2, 3], dtypes.int64)
   1619       feats = lookup.index_to_string(indices, mapping=mapping_strings)
   1620 
   1621       self.assertRaises(errors_impl.OpError, feats.eval)
   1622       lookup_ops.tables_initializer().run()
   1623 
   1624       self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"),
   1625                           feats.eval())
   1626 
   1627   def test_duplicate_entries(self):
   1628     with self.test_session():
   1629       mapping_strings = constant_op.constant(["hello", "hello"])
   1630       indices = constant_op.constant([0, 1, 4], dtypes.int64)
   1631       feats = lookup.index_to_string(indices, mapping=mapping_strings)
   1632       lookup_ops.tables_initializer().run()
   1633       self.assertAllEqual((b"hello", b"hello", b"UNK"), feats.eval())
   1634 
   1635       self.assertRaises(errors_impl.OpError,
   1636                         lookup_ops.tables_initializer().run)
   1637 
   1638   def test_index_to_string_with_default_value(self):
   1639     default_value = b"NONE"
   1640     with self.test_session():
   1641       mapping_strings = constant_op.constant(["brain", "salad", "surgery"])
   1642       indices = constant_op.constant([1, 2, 4], dtypes.int64)
   1643       feats = lookup.index_to_string(
   1644           indices, mapping=mapping_strings, default_value=default_value)
   1645       self.assertRaises(errors_impl.OpError, feats.eval)
   1646 
   1647       lookup_ops.tables_initializer().run()
   1648       self.assertAllEqual((b"salad", b"surgery", default_value), feats.eval())
   1649 
   1650 
   1651 class InitializeTableFromFileOpTest(test.TestCase):
   1652 
   1653   def _createVocabFile(self, basename, values=("brain", "salad", "surgery")):
   1654     vocabulary_file = os.path.join(self.get_temp_dir(), basename)
   1655     with open(vocabulary_file, "w") as f:
   1656       f.write("\n".join(values) + "\n")
   1657     return vocabulary_file
   1658 
   1659   @test_util.run_in_graph_and_eager_modes()
   1660   def testInitializeStringTable(self):
   1661     vocabulary_file = self._createVocabFile("one_column_1.txt")
   1662     default_value = -1
   1663     table = lookup.HashTable(
   1664         lookup.TextFileInitializer(vocabulary_file, dtypes.string,
   1665                                    lookup.TextFileIndex.WHOLE_LINE,
   1666                                    dtypes.int64,
   1667                                    lookup.TextFileIndex.LINE_NUMBER),
   1668         default_value)
   1669     self.evaluate(table.init)
   1670 
   1671     output = table.lookup(constant_op.constant(["brain", "salad", "tank"]))
   1672 
   1673     result = self.evaluate(output)
   1674     self.assertAllEqual([0, 1, -1], result)
   1675 
   1676   def testInitializeInt64Table(self):
   1677     vocabulary_file = self._createVocabFile(
   1678         "one_column_int64.txt", values=("42", "1", "-1000"))
   1679 
   1680     with self.test_session():
   1681       default_value = -1
   1682       table = lookup.HashTable(
   1683           lookup.TextFileInitializer(vocabulary_file, dtypes.int64,
   1684                                      lookup.TextFileIndex.WHOLE_LINE,
   1685                                      dtypes.int64,
   1686                                      lookup.TextFileIndex.LINE_NUMBER),
   1687           default_value)
   1688       table.init.run()
   1689 
   1690       output = table.lookup(
   1691           constant_op.constant((42, 1, 11), dtype=dtypes.int64))
   1692 
   1693       result = output.eval()
   1694       self.assertAllEqual([0, 1, -1], result)
   1695 
   1696   def testInitializeIndexTable(self):
   1697     vocabulary_file = self._createVocabFile("one_column_2.txt")
   1698 
   1699     with self.test_session():
   1700       default_value = "UNK"
   1701       key_index = lookup.TextFileIndex.LINE_NUMBER
   1702       value_index = lookup.TextFileIndex.WHOLE_LINE
   1703       table = lookup.HashTable(
   1704           lookup.TextFileInitializer(vocabulary_file, dtypes.int64,
   1705                                      key_index, dtypes.string, value_index),
   1706           default_value)
   1707       table.init.run()
   1708 
   1709       input_values = constant_op.constant([0, 1, 2, 3], dtypes.int64)
   1710       output = table.lookup(input_values)
   1711 
   1712       result = output.eval()
   1713       self.assertAllEqual([b"brain", b"salad", b"surgery", b"UNK"], result)
   1714 
   1715   def testMultiColumn(self):
   1716     vocabulary_file = os.path.join(self.get_temp_dir(), "three_columns.txt")
   1717     with open(vocabulary_file, "w") as f:
   1718       f.write("\n".join(["0\tbrain\t1", "1\tsalad\t5", "2\tsurgery\t6"]) + "\n")
   1719 
   1720     with self.test_session():
   1721       default_value = -1
   1722       key_index = 1
   1723       value_index = 2
   1724 
   1725       table = lookup.HashTable(
   1726           lookup.TextFileInitializer(vocabulary_file, dtypes.string,
   1727                                      key_index, dtypes.int64, value_index),
   1728           default_value)
   1729       table.init.run()
   1730 
   1731       input_string = constant_op.constant(["brain", "salad", "surgery"])
   1732       output = table.lookup(input_string)
   1733 
   1734       result = output.eval()
   1735       self.assertAllEqual([1, 5, 6], result)
   1736 
   1737   def testInvalidDataTypeInMultiColumn(self):
   1738     vocabulary_file = os.path.join(self.get_temp_dir(), "three_columns.txt")
   1739     with open(vocabulary_file, "w") as f:
   1740       f.write("\n".join(["0\tbrain\t1", "1\tsalad\t5", "2\tsurgery\t6"]) + "\n")
   1741 
   1742     with self.test_session():
   1743       default_value = -1
   1744       key_index = 2
   1745       value_index = 1
   1746       table = lookup.HashTable(
   1747           lookup.TextFileInitializer(vocabulary_file, dtypes.string,
   1748                                      key_index, dtypes.int64, value_index),
   1749           default_value)
   1750       with self.assertRaisesOpError("is not a valid"):
   1751         table.init.run()
   1752 
   1753   def testInvalidDataType(self):
   1754     vocabulary_file = self._createVocabFile("one_column_3.txt")
   1755 
   1756     with self.test_session():
   1757       default_value = "UNK"
   1758       key_index = lookup.TextFileIndex.WHOLE_LINE
   1759       value_index = lookup.TextFileIndex.LINE_NUMBER
   1760 
   1761       with self.assertRaises(ValueError):
   1762         lookup.HashTable(
   1763             lookup.TextFileInitializer(vocabulary_file, dtypes.int64,
   1764                                        key_index, dtypes.string,
   1765                                        value_index), default_value)
   1766 
   1767   def testInvalidIndex(self):
   1768     vocabulary_file = self._createVocabFile("one_column_4.txt")
   1769     with self.test_session():
   1770       default_value = -1
   1771       key_index = 1  # second column of the line
   1772       value_index = lookup.TextFileIndex.LINE_NUMBER
   1773       table = lookup.HashTable(
   1774           lookup.TextFileInitializer(vocabulary_file, dtypes.string,
   1775                                      key_index, dtypes.int64, value_index),
   1776           default_value)
   1777 
   1778       with self.assertRaisesOpError("Invalid number of columns"):
   1779         table.init.run()
   1780 
   1781   def testInitializeSameTableWithMultipleNodes(self):
   1782     vocabulary_file = self._createVocabFile("one_column_5.txt")
   1783 
   1784     with self.test_session() as sess:
   1785       shared_name = "shared-one-columm"
   1786       default_value = -1
   1787       table1 = lookup.HashTable(
   1788           lookup.TextFileInitializer(vocabulary_file, dtypes.string,
   1789                                      lookup.TextFileIndex.WHOLE_LINE,
   1790                                      dtypes.int64,
   1791                                      lookup.TextFileIndex.LINE_NUMBER),
   1792           default_value,
   1793           shared_name=shared_name)
   1794       table2 = lookup.HashTable(
   1795           lookup.TextFileInitializer(vocabulary_file, dtypes.string,
   1796                                      lookup.TextFileIndex.WHOLE_LINE,
   1797                                      dtypes.int64,
   1798                                      lookup.TextFileIndex.LINE_NUMBER),
   1799           default_value,
   1800           shared_name=shared_name)
   1801       table3 = lookup.HashTable(
   1802           lookup.TextFileInitializer(vocabulary_file, dtypes.string,
   1803                                      lookup.TextFileIndex.WHOLE_LINE,
   1804                                      dtypes.int64,
   1805                                      lookup.TextFileIndex.LINE_NUMBER),
   1806           default_value,
   1807           shared_name=shared_name)
   1808 
   1809       lookup_ops.tables_initializer().run()
   1810 
   1811       input_string = constant_op.constant(["brain", "salad", "tank"])
   1812 
   1813       output1 = table1.lookup(input_string)
   1814       output2 = table2.lookup(input_string)
   1815       output3 = table3.lookup(input_string)
   1816 
   1817       out1, out2, out3 = sess.run([output1, output2, output3])
   1818       self.assertAllEqual([0, 1, -1], out1)
   1819       self.assertAllEqual([0, 1, -1], out2)
   1820       self.assertAllEqual([0, 1, -1], out3)
   1821 
   1822   def testInitializeTableWithNoFilename(self):
   1823     with self.test_session():
   1824       default_value = -1
   1825       with self.assertRaises(ValueError):
   1826         lookup.HashTable(
   1827             lookup.TextFileInitializer(
   1828                 "", dtypes.string, lookup.TextFileIndex.WHOLE_LINE,
   1829                 dtypes.int64, lookup.TextFileIndex.LINE_NUMBER),
   1830             default_value)
   1831 
   1832   def testInitializeWithVocabSize(self):
   1833     with self.test_session():
   1834       default_value = -1
   1835       vocab_size = 3
   1836       vocabulary_file1 = self._createVocabFile("one_column6.txt")
   1837       table1 = lookup.HashTable(
   1838           lookup.TextFileInitializer(
   1839               vocabulary_file1,
   1840               dtypes.string,
   1841               lookup.TextFileIndex.WHOLE_LINE,
   1842               dtypes.int64,
   1843               lookup.TextFileIndex.LINE_NUMBER,
   1844               vocab_size=vocab_size),
   1845           default_value)
   1846 
   1847       # Initialize from file.
   1848       table1.init.run()
   1849       self.assertEquals(vocab_size, table1.size().eval())
   1850 
   1851       vocabulary_file2 = self._createVocabFile("one_column7.txt")
   1852       vocab_size = 5
   1853       table2 = lookup.HashTable(
   1854           lookup.TextFileInitializer(
   1855               vocabulary_file2,
   1856               dtypes.string,
   1857               lookup.TextFileIndex.WHOLE_LINE,
   1858               dtypes.int64,
   1859               lookup.TextFileIndex.LINE_NUMBER,
   1860               vocab_size=vocab_size),
   1861           default_value)
   1862       with self.assertRaisesOpError("Invalid vocab_size"):
   1863         table2.init.run()
   1864 
   1865       vocab_size = 1
   1866       vocabulary_file3 = self._createVocabFile("one_column3.txt")
   1867       table3 = lookup.HashTable(
   1868           lookup.TextFileInitializer(
   1869               vocabulary_file3,
   1870               dtypes.string,
   1871               lookup.TextFileIndex.WHOLE_LINE,
   1872               dtypes.int64,
   1873               lookup.TextFileIndex.LINE_NUMBER,
   1874               vocab_size=vocab_size),
   1875           default_value)
   1876 
   1877       # Smaller vocab size reads only vocab_size records.
   1878       table3.init.run()
   1879       self.assertEquals(vocab_size, table3.size().eval())
   1880 
   1881   def testFeedVocabularyName(self):
   1882     vocabulary_file = self._createVocabFile("feed_vocabulary.txt")
   1883 
   1884     with self.test_session():
   1885       default_value = -1
   1886       table = lookup.HashTable(
   1887           lookup.TextFileInitializer("old_file.txt", dtypes.string,
   1888                                      lookup.TextFileIndex.WHOLE_LINE,
   1889                                      dtypes.int64,
   1890                                      lookup.TextFileIndex.LINE_NUMBER),
   1891           default_value)
   1892 
   1893       # Initialize with non existing file (old_file.txt) should fail.
   1894       # TODO(yleon): Update message, which might change per FileSystem.
   1895       with self.assertRaisesOpError("old_file.txt"):
   1896         table.init.run()
   1897 
   1898       # Initialize the model feeding the vocabulary file.
   1899       filenames = ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS)
   1900       table.init.run(feed_dict={filenames[0]: vocabulary_file})
   1901 
   1902       input_string = constant_op.constant(["brain", "salad", "tank"])
   1903       output = table.lookup(input_string)
   1904 
   1905       result = output.eval()
   1906       self.assertAllEqual([0, 1, -1], result)
   1907 
   1908   def testInvalidFilenames(self):
   1909     vocabulary_file = self._createVocabFile("filename_shape.txt")
   1910 
   1911     with self.test_session():
   1912       default_value = -1
   1913 
   1914       # Invalid data type
   1915       other_type = constant_op.constant(1)
   1916       with self.assertRaises(ValueError):
   1917         lookup.HashTable(
   1918             lookup.TextFileInitializer(
   1919                 other_type, dtypes.string, lookup.TextFileIndex.WHOLE_LINE,
   1920                 dtypes.int64, lookup.TextFileIndex.LINE_NUMBER),
   1921             default_value)
   1922 
   1923       # Non-scalar filename
   1924       filenames = constant_op.constant([vocabulary_file, vocabulary_file])
   1925       with self.assertRaises(ValueError):
   1926         lookup.HashTable(
   1927             lookup.TextFileInitializer(
   1928                 filenames, dtypes.string, lookup.TextFileIndex.WHOLE_LINE,
   1929                 dtypes.int64, lookup.TextFileIndex.LINE_NUMBER),
   1930             default_value)
   1931 
   1932   def testIdToStringTable(self):
   1933     vocab_file = self._createVocabFile("feat_to_id_1.txt")
   1934     with self.test_session():
   1935       default_value = "UNK"
   1936       vocab_size = 3
   1937       table = lookup.HashTable(
   1938           lookup.TextFileStringTableInitializer(
   1939               vocab_file, vocab_size=vocab_size),
   1940           default_value)
   1941 
   1942       table.init.run()
   1943 
   1944       input_values = constant_op.constant([0, 1, 2, 3], dtypes.int64)
   1945 
   1946       out = table.lookup(input_values)
   1947       self.assertAllEqual([b"brain", b"salad", b"surgery", b"UNK"], out.eval())
   1948       self.assertEquals(vocab_size, table.size().eval())
   1949 
   1950   def testStringToIdTable(self):
   1951     vocab_file = self._createVocabFile("feat_to_id_2.txt")
   1952     with self.test_session():
   1953       default_value = -1
   1954       vocab_size = 3
   1955       table = lookup.HashTable(
   1956           lookup.TextFileIdTableInitializer(
   1957               vocab_file, vocab_size=vocab_size),
   1958           default_value)
   1959       table.init.run()
   1960 
   1961       input_string = constant_op.constant(["brain", "salad", "surgery", "UNK"])
   1962 
   1963       out = table.lookup(input_string)
   1964       self.assertAllEqual([0, 1, 2, -1], out.eval())
   1965       self.assertEquals(vocab_size, table.size().eval())
   1966 
   1967   def testInt64ToIdTable(self):
   1968     vocab_file = self._createVocabFile(
   1969         "feat_to_id_3.txt", values=("42", "1", "-1000"))
   1970     with self.test_session():
   1971       default_value = -1
   1972       vocab_size = 3
   1973       table = lookup.HashTable(
   1974           lookup.TextFileIdTableInitializer(
   1975               vocab_file, vocab_size=vocab_size, key_dtype=dtypes.int64),
   1976           default_value)
   1977       table.init.run()
   1978 
   1979       out = table.lookup(
   1980           constant_op.constant((42, 1, -1000, 11), dtype=dtypes.int64))
   1981       self.assertAllEqual((0, 1, 2, -1), out.eval())
   1982       self.assertEquals(vocab_size, table.size().eval())
   1983 
   1984 
   1985 class IdTableWithHashBucketsTest(test.TestCase):
   1986 
   1987   def _createVocabFile(self, basename, values=("brain", "salad", "surgery")):
   1988     vocabulary_file = os.path.join(self.get_temp_dir(), basename)
   1989     with open(vocabulary_file, "w") as f:
   1990       f.write("\n".join(values) + "\n")
   1991     return vocabulary_file
   1992 
   1993   def testStringIdTableWithHashBuckets(self):
   1994     vocab_file = self._createVocabFile("feat_to_id_1.txt")
   1995     with self.test_session():
   1996       default_value = -1
   1997       vocab_size = 3
   1998       oov_buckets = 1
   1999       table = lookup.IdTableWithHashBuckets(
   2000           lookup.HashTable(
   2001               lookup.TextFileIdTableInitializer(
   2002                   vocab_file, vocab_size=vocab_size),
   2003               default_value),
   2004           oov_buckets)
   2005 
   2006       table.init.run()
   2007 
   2008       input_string = constant_op.constant(["brain", "salad", "surgery", "UNK"])
   2009 
   2010       out = table.lookup(input_string)
   2011       self.assertAllEqual([0, 1, 2, 3], out.eval())
   2012       self.assertEquals(vocab_size + oov_buckets, table.size().eval())
   2013 
   2014   def testInt32IdTableWithHashBuckets(self):
   2015     vocab_file = self._createVocabFile("feat_to_id_2.txt", ("42", "1", "-1000"))
   2016     with self.test_session():
   2017       default_value = -1
   2018       vocab_size = 3
   2019       oov_buckets = 1
   2020       table = lookup.IdTableWithHashBuckets(
   2021           lookup.HashTable(
   2022               lookup.TextFileIdTableInitializer(
   2023                   vocab_file, vocab_size=vocab_size, key_dtype=dtypes.int64),
   2024               default_value),
   2025           oov_buckets,
   2026           key_dtype=dtypes.int32)
   2027 
   2028       table.init.run()
   2029 
   2030       values = constant_op.constant((42, 1, -1000, 11), dtype=dtypes.int32)
   2031 
   2032       out = table.lookup(values)
   2033       self.assertAllEqual([0, 1, 2, 3], out.eval())
   2034       self.assertEquals(vocab_size + oov_buckets, table.size().eval())
   2035 
   2036   def testInt64IdTableWithHashBuckets(self):
   2037     vocab_file = self._createVocabFile("feat_to_id_3.txt", ("42", "1", "-1000"))
   2038     with self.test_session():
   2039       default_value = -1
   2040       vocab_size = 3
   2041       oov_buckets = 1
   2042       table = lookup.IdTableWithHashBuckets(
   2043           lookup.HashTable(
   2044               lookup.TextFileIdTableInitializer(
   2045                   vocab_file, vocab_size=vocab_size, key_dtype=dtypes.int64),
   2046               default_value),
   2047           oov_buckets)
   2048 
   2049       table.init.run()
   2050 
   2051       values = constant_op.constant((42, 1, -1000, 11), dtype=dtypes.int64)
   2052 
   2053       out = table.lookup(values)
   2054       self.assertAllEqual([0, 1, 2, 3], out.eval())
   2055       self.assertEquals(vocab_size + oov_buckets, table.size().eval())
   2056 
   2057   def testStringIdTableWithOnlyHashBucket(self):
   2058     with self.test_session():
   2059       oov_buckets = 5
   2060 
   2061       # Set a table that only uses hash buckets, for each input value returns
   2062       # an id calculated by fingerprint("input") mod oov_buckets.
   2063       table = lookup.IdTableWithHashBuckets(None, oov_buckets)
   2064       table.init.run()
   2065 
   2066       values = constant_op.constant(("brain", "salad", "surgery"))
   2067 
   2068       out = table.lookup(values)
   2069       self.assertAllEqual(
   2070           [
   2071               3,  # fingerprint("brain") mod 5.
   2072               1,  # fingerprint("salad") mod 5.
   2073               4  # fingerprint("surgery") mod 5
   2074           ],
   2075           out.eval())
   2076       self.assertEquals(oov_buckets, table.size().eval())
   2077 
   2078   def testInt32IdTableWithOnlyHashBucket(self):
   2079     with self.test_session():
   2080       oov_buckets = 5
   2081 
   2082       # Set a table that only uses hash buckets, for each input value returns
   2083       # an id calculated by fingerprint("input") mod oov_buckets.
   2084       table = lookup.IdTableWithHashBuckets(
   2085           None, oov_buckets, key_dtype=dtypes.int32)
   2086       table.init.run()
   2087 
   2088       input_string = constant_op.constant([42, 1, -1000], dtype=dtypes.int32)
   2089 
   2090       out = table.lookup(input_string)
   2091       self.assertAllEqual(
   2092           [
   2093               1,  # fingerprint("42") mod 5.
   2094               4,  # fingerprint("1") mod 5.
   2095               2  # fingerprint("-1000") mod 5
   2096           ],
   2097           out.eval())
   2098       self.assertEquals(oov_buckets, table.size().eval())
   2099 
   2100   def testFloat64IdTableWithOnlyHashBucket(self):
   2101     with self.test_session():
   2102       with self.assertRaisesRegexp(TypeError, "Invalid key_dtype"):
   2103         lookup.IdTableWithHashBuckets(
   2104             None, num_oov_buckets=5, key_dtype=dtypes.float64)
   2105 
   2106   def testBoolIdTableWithOnlyHashBucket(self):
   2107     with self.test_session():
   2108       with self.assertRaisesRegexp(TypeError, "Invalid key_dtype"):
   2109         lookup.IdTableWithHashBuckets(
   2110             None, num_oov_buckets=5, key_dtype=dtypes.bool)
   2111 
   2112   def testIdTableWithHashBucketsWithMultipleInitializers(self):
   2113     vocab_file = self._createVocabFile("feat_to_id_4.txt")
   2114     with self.test_session() as sess:
   2115       default_value = -1
   2116       vocab_size = 3
   2117       oov_buckets = 3
   2118 
   2119       vocab_table = lookup.HashTable(
   2120           lookup.TextFileIdTableInitializer(
   2121               vocab_file, vocab_size=vocab_size),
   2122           default_value)
   2123       table1 = lookup.IdTableWithHashBuckets(
   2124           vocab_table,
   2125           oov_buckets,
   2126           hasher_spec=lookup.FastHashSpec,
   2127           name="table1")
   2128 
   2129       table2 = lookup.IdTableWithHashBuckets(
   2130           vocab_table,
   2131           oov_buckets,
   2132           hasher_spec=lookup.StrongHashSpec((1, 2)),
   2133           name="table2")
   2134 
   2135       lookup_ops.tables_initializer().run()
   2136 
   2137       input_string = constant_op.constant(
   2138           ["fruit", "brain", "salad", "surgery", "UNK"])
   2139 
   2140       out1 = table1.lookup(input_string)
   2141       out2 = table2.lookup(input_string)
   2142 
   2143       out1, out2 = sess.run([out1, out2])
   2144       self.assertAllEqual([5, 0, 1, 2, 5], out1)
   2145       self.assertAllEqual([5, 0, 1, 2, 3], out2)
   2146       self.assertEquals(vocab_size + oov_buckets, table1.size().eval())
   2147       self.assertEquals(vocab_size + oov_buckets, table2.size().eval())
   2148       test_util.assert_ops_in_graph({
   2149           "table1_Lookup/hash_bucket": "StringToHashBucketFast",
   2150           "table2_Lookup/hash_bucket": "StringToHashBucketStrong",
   2151       }, sess.graph)
   2152 
   2153   def testIdTableWithHashBucketsInitializationAcrossSessions(self):
   2154     vocab_file = self._createVocabFile("feat_to_id_5.txt")
   2155     shared_name = "across-sessions"
   2156     with self.test_session():
   2157       default_value = -1
   2158       vocab_size = 3
   2159       oov_buckets = 1
   2160       table1 = lookup.IdTableWithHashBuckets(
   2161           lookup.HashTable(
   2162               lookup.TextFileIdTableInitializer(
   2163                   vocab_file, vocab_size=vocab_size),
   2164               default_value,
   2165               shared_name=shared_name),
   2166           oov_buckets)
   2167 
   2168       table1.init.run()
   2169 
   2170       input_string_1 = constant_op.constant(
   2171           ["brain", "salad", "surgery", "UNK"])
   2172 
   2173       out1 = table1.lookup(input_string_1)
   2174 
   2175       self.assertAllEqual([0, 1, 2, 3], out1.eval())
   2176       self.assertEquals(vocab_size + oov_buckets, table1.size().eval())
   2177 
   2178     with self.test_session():
   2179       default_value = -1
   2180       vocab_size = 3
   2181       oov_buckets = 1
   2182 
   2183       # Underlying lookup table already initialized in previous session.
   2184       # No need to call table2.init.run()
   2185       table2 = lookup.IdTableWithHashBuckets(
   2186           lookup.HashTable(
   2187               lookup.TextFileIdTableInitializer(
   2188                   vocab_file, vocab_size=vocab_size),
   2189               default_value,
   2190               shared_name=shared_name),
   2191           oov_buckets)
   2192 
   2193       input_string_2 = constant_op.constant(["fruit", "salad", "UNK"])
   2194 
   2195       out2 = table2.lookup(input_string_2)
   2196 
   2197       self.assertAllEqual([3, 1, 3], out2.eval())
   2198       self.assertEquals(vocab_size + oov_buckets, table2.size().eval())
   2199 
   2200   def testIdTableWithHashBucketsWithMultipleInitializersDifferentDefault(self):
   2201     vocab_file = self._createVocabFile("feat_to_id_6.txt")
   2202     with self.test_session() as sess:
   2203       default_value1 = -1
   2204       vocab_size = 3
   2205       oov_buckets = 0
   2206       table1 = lookup.IdTableWithHashBuckets(
   2207           lookup.HashTable(
   2208               lookup.TextFileIdTableInitializer(
   2209                   vocab_file, vocab_size=vocab_size),
   2210               default_value1),
   2211           oov_buckets)
   2212 
   2213       default_value2 = -2
   2214       table2 = lookup.IdTableWithHashBuckets(
   2215           lookup.HashTable(
   2216               lookup.TextFileIdTableInitializer(
   2217                   vocab_file, vocab_size=vocab_size),
   2218               default_value2),
   2219           oov_buckets)
   2220 
   2221       lookup_ops.tables_initializer().run()
   2222 
   2223       input_string_1 = constant_op.constant(
   2224           ["brain", "salad", "surgery", "UNK"])
   2225       input_string_2 = constant_op.constant(["fruit", "salad", "UNK"])
   2226 
   2227       out1 = table1.lookup(input_string_1)
   2228       out2 = table2.lookup(input_string_2)
   2229 
   2230       out1, out2 = sess.run([out1, out2])
   2231       self.assertAllEqual([0, 1, 2, -1], out1)
   2232       self.assertAllEqual([-2, 1, -2], out2)
   2233       self.assertEquals(vocab_size + oov_buckets, table1.size().eval())
   2234       self.assertEquals(vocab_size + oov_buckets, table2.size().eval())
   2235 
   2236   def testSparseTensor(self):
   2237     vocab_file = self._createVocabFile("feat_to_id_7.txt")
   2238     input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
   2239     input_shape = [4, 4]
   2240     with self.test_session() as sess:
   2241       sp_features = sparse_tensor.SparseTensor(
   2242           constant_op.constant(input_indices, dtypes.int64),
   2243           constant_op.constant(["brain", "salad", "brain", "surgery", "tarkus"],
   2244                                dtypes.string),
   2245           constant_op.constant(input_shape, dtypes.int64))
   2246 
   2247       table = lookup.IdTableWithHashBuckets(
   2248           lookup.HashTable(
   2249               lookup.TextFileIdTableInitializer(
   2250                   vocab_file, vocab_size=3),
   2251               -1),
   2252           1)
   2253       table.init.run()
   2254 
   2255       sp_ids = table.lookup(sp_features)
   2256 
   2257       self.assertAllEqual([5], sp_ids.values._shape_as_list())
   2258 
   2259       sp_ids_ind, sp_ids_val, sp_ids_shape = sess.run(
   2260           [sp_ids.indices, sp_ids.values, sp_ids.dense_shape])
   2261 
   2262       self.assertAllEqual(input_indices, sp_ids_ind)
   2263       self.assertAllEqual([0, 1, 0, 2, 3], sp_ids_val)
   2264       self.assertAllEqual(input_shape, sp_ids_shape)
   2265 
   2266   def testInt32SparseTensor(self):
   2267     input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
   2268     input_shape = [4, 4]
   2269     with self.test_session() as sess:
   2270       sp_features = sparse_tensor.SparseTensor(
   2271           constant_op.constant(input_indices, dtypes.int64),
   2272           constant_op.constant([42, 1, 42, -1000, 11], dtypes.int32),
   2273           constant_op.constant(input_shape, dtypes.int64))
   2274 
   2275       table = lookup.IdTableWithHashBuckets(
   2276           lookup.HashTable(
   2277               lookup.KeyValueTensorInitializer(
   2278                   (42, 1, -1000), (0, 1, 2), dtypes.int64, dtypes.int64),
   2279               -1),
   2280           1,
   2281           key_dtype=dtypes.int32)
   2282       table.init.run()
   2283 
   2284       sp_ids = table.lookup(sp_features)
   2285 
   2286       self.assertAllEqual([5], sp_ids.values._shape_as_list())
   2287 
   2288       sp_ids_ind, sp_ids_val, sp_ids_shape = sess.run(
   2289           [sp_ids.indices, sp_ids.values, sp_ids.dense_shape])
   2290 
   2291       self.assertAllEqual(input_indices, sp_ids_ind)
   2292       self.assertAllEqual([0, 1, 0, 2, 3], sp_ids_val)
   2293       self.assertAllEqual(input_shape, sp_ids_shape)
   2294 
   2295   def testInt64SparseTensor(self):
   2296     input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
   2297     input_shape = [4, 4]
   2298     with self.test_session() as sess:
   2299       sp_features = sparse_tensor.SparseTensor(
   2300           constant_op.constant(input_indices, dtypes.int64),
   2301           constant_op.constant([42, 1, 42, -1000, 11], dtypes.int64),
   2302           constant_op.constant(input_shape, dtypes.int64))
   2303 
   2304       table = lookup.IdTableWithHashBuckets(
   2305           lookup.HashTable(
   2306               lookup.KeyValueTensorInitializer(
   2307                   (42, 1, -1000), (0, 1, 2), dtypes.int64, dtypes.int64),
   2308               -1),
   2309           1,
   2310           key_dtype=dtypes.int64)
   2311       table.init.run()
   2312 
   2313       sp_ids = table.lookup(sp_features)
   2314 
   2315       self.assertAllEqual([5], sp_ids.values._shape_as_list())
   2316 
   2317       sp_ids_ind, sp_ids_val, sp_ids_shape = sess.run(
   2318           [sp_ids.indices, sp_ids.values, sp_ids.dense_shape])
   2319 
   2320       self.assertAllEqual(input_indices, sp_ids_ind)
   2321       self.assertAllEqual([0, 1, 0, 2, 3], sp_ids_val)
   2322       self.assertAllEqual(input_shape, sp_ids_shape)
   2323 
   2324   def testIdTableWithHashBucketsWithInvalidHashers(self):
   2325     vocab_file = self._createVocabFile("feat_to_id_4.txt")
   2326     with self.test_session():
   2327       default_value = -1
   2328       vocab_size = 3
   2329       oov_buckets = 1
   2330       lookup_table = lookup.HashTable(
   2331           lookup.TextFileIdTableInitializer(
   2332               vocab_file, vocab_size=vocab_size),
   2333           default_value)
   2334 
   2335       with self.assertRaises(TypeError):
   2336         lookup.IdTableWithHashBuckets(
   2337             lookup_table, oov_buckets, hasher_spec=1)
   2338 
   2339       table = lookup.IdTableWithHashBuckets(
   2340           lookup_table,
   2341           oov_buckets,
   2342           hasher_spec=lookup.HasherSpec("my-awesome-hash", None))
   2343 
   2344       input_string = constant_op.constant(["brain", "salad", "surgery", "UNK"])
   2345 
   2346       with self.assertRaises(ValueError):
   2347         table.lookup(input_string)
   2348 
   2349       with self.assertRaises(ValueError):
   2350         table = lookup.IdTableWithHashBuckets(
   2351             lookup_table,
   2352             oov_buckets,
   2353             hasher_spec=lookup.StrongHashSpec([]))
   2354 
   2355       with self.assertRaises(ValueError):
   2356         table = lookup.IdTableWithHashBuckets(
   2357             lookup_table,
   2358             oov_buckets,
   2359             hasher_spec=lookup.StrongHashSpec([1, 2, 3]))
   2360 
   2361       with self.assertRaises(TypeError):
   2362         table = lookup.IdTableWithHashBuckets(
   2363             lookup_table,
   2364             oov_buckets,
   2365             hasher_spec=lookup.StrongHashSpec([None, 2]))
   2366 
   2367 
# Script entry point: when this file is executed directly (not imported),
# hand control to `test.main()` from tensorflow.python.platform.
if __name__ == "__main__":
  test.main()
   2370