1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for tf.contrib.lookup.lookup.""" 16 from __future__ import absolute_import 17 from __future__ import division 18 from __future__ import print_function 19 20 import os 21 import tempfile 22 import numpy as np 23 import six 24 25 from tensorflow.contrib import lookup 26 from tensorflow.python.client import session 27 from tensorflow.python.framework import constant_op 28 from tensorflow.python.framework import dtypes 29 from tensorflow.python.framework import errors_impl 30 from tensorflow.python.framework import ops 31 from tensorflow.python.framework import sparse_tensor 32 from tensorflow.python.framework import test_util 33 from tensorflow.python.ops import array_ops 34 from tensorflow.python.ops import lookup_ops 35 from tensorflow.python.ops import variables 36 from tensorflow.python.platform import test 37 from tensorflow.python.training import saver 38 from tensorflow.python.training import server_lib 39 40 41 class HashTableOpTest(test.TestCase): 42 43 def testHashTable(self): 44 with self.test_session(): 45 default_val = -1 46 keys = constant_op.constant(["brain", "salad", "surgery"]) 47 values = constant_op.constant([0, 1, 2], dtypes.int64) 48 table = lookup.HashTable( 49 lookup.KeyValueTensorInitializer(keys, values), default_val) 50 
table.init.run() 51 52 self.assertAllEqual(3, table.size().eval()) 53 54 input_string = constant_op.constant(["brain", "salad", "tank"]) 55 output = table.lookup(input_string) 56 self.assertAllEqual([3], output.get_shape()) 57 58 result = output.eval() 59 self.assertAllEqual([0, 1, -1], result) 60 61 def testHashTableFindHighRank(self): 62 with self.test_session(): 63 default_val = -1 64 keys = constant_op.constant(["brain", "salad", "surgery"]) 65 values = constant_op.constant([0, 1, 2], dtypes.int64) 66 table = lookup.HashTable( 67 lookup.KeyValueTensorInitializer(keys, values), default_val) 68 table.init.run() 69 70 self.assertAllEqual(3, table.size().eval()) 71 72 input_string = constant_op.constant( 73 [["brain", "salad"], ["tank", "tarkus"]]) 74 output = table.lookup(input_string) 75 76 result = output.eval() 77 self.assertAllEqual([[0, 1], [-1, -1]], result) 78 79 def testHashTableInitWithPythonArrays(self): 80 with self.test_session(): 81 default_val = -1 82 keys = ["brain", "salad", "surgery"] 83 values = [0, 1, 2] 84 table = lookup.HashTable( 85 lookup.KeyValueTensorInitializer( 86 keys, values, value_dtype=dtypes.int64), 87 default_val) 88 table.init.run() 89 90 self.assertAllEqual(3, table.size().eval()) 91 92 input_string = constant_op.constant(["brain", "salad", "tank"]) 93 output = table.lookup(input_string) 94 95 result = output.eval() 96 self.assertAllEqual([0, 1, -1], result) 97 98 def testHashTableInitWithNumPyArrays(self): 99 with self.test_session(): 100 default_val = -1 101 keys = np.array(["brain", "salad", "surgery"], dtype=np.str) 102 values = np.array([0, 1, 2], dtype=np.int64) 103 table = lookup.HashTable( 104 lookup.KeyValueTensorInitializer(keys, values), default_val) 105 table.init.run() 106 107 self.assertAllEqual(3, table.size().eval()) 108 109 input_string = constant_op.constant(["brain", "salad", "tank"]) 110 output = table.lookup(input_string) 111 112 result = output.eval() 113 self.assertAllEqual([0, 1, -1], result) 114 115 def 
testMultipleHashTables(self): 116 with self.test_session() as sess: 117 default_val = -1 118 keys = constant_op.constant(["brain", "salad", "surgery"]) 119 values = constant_op.constant([0, 1, 2], dtypes.int64) 120 121 table1 = lookup.HashTable( 122 lookup.KeyValueTensorInitializer(keys, values), default_val) 123 table2 = lookup.HashTable( 124 lookup.KeyValueTensorInitializer(keys, values), default_val) 125 table3 = lookup.HashTable( 126 lookup.KeyValueTensorInitializer(keys, values), default_val) 127 128 lookup_ops.tables_initializer().run() 129 self.assertAllEqual(3, table1.size().eval()) 130 self.assertAllEqual(3, table2.size().eval()) 131 self.assertAllEqual(3, table3.size().eval()) 132 133 input_string = constant_op.constant(["brain", "salad", "tank"]) 134 output1 = table1.lookup(input_string) 135 output2 = table2.lookup(input_string) 136 output3 = table3.lookup(input_string) 137 138 out1, out2, out3 = sess.run([output1, output2, output3]) 139 self.assertAllEqual([0, 1, -1], out1) 140 self.assertAllEqual([0, 1, -1], out2) 141 self.assertAllEqual([0, 1, -1], out3) 142 143 def testHashTableWithTensorDefault(self): 144 with self.test_session(): 145 default_val = constant_op.constant(-1, dtypes.int64) 146 keys = constant_op.constant(["brain", "salad", "surgery"]) 147 values = constant_op.constant([0, 1, 2], dtypes.int64) 148 table = lookup.HashTable( 149 lookup.KeyValueTensorInitializer(keys, values), default_val) 150 table.init.run() 151 152 input_string = constant_op.constant(["brain", "salad", "tank"]) 153 output = table.lookup(input_string) 154 155 result = output.eval() 156 self.assertAllEqual([0, 1, -1], result) 157 158 def testHashTableWithSparseTensorInput(self): 159 with self.test_session() as sess: 160 default_val = constant_op.constant(-1, dtypes.int64) 161 keys = constant_op.constant(["brain", "salad", "surgery"]) 162 values = constant_op.constant([0, 1, 2], dtypes.int64) 163 table = lookup.HashTable( 164 lookup.KeyValueTensorInitializer(keys, values), 
default_val) 165 table.init.run() 166 167 sp_indices = [[0, 0], [0, 1], [1, 0]] 168 sp_shape = [2, 2] 169 input_tensor = sparse_tensor.SparseTensor( 170 constant_op.constant(sp_indices, dtypes.int64), 171 constant_op.constant(["brain", "salad", "tank"]), 172 constant_op.constant(sp_shape, dtypes.int64)) 173 output = table.lookup(input_tensor) 174 175 out_indices, out_values, out_shape = sess.run(output) 176 177 self.assertAllEqual([0, 1, -1], out_values) 178 self.assertAllEqual(sp_indices, out_indices) 179 self.assertAllEqual(sp_shape, out_shape) 180 181 def testSignatureMismatch(self): 182 with self.test_session(): 183 default_val = -1 184 keys = constant_op.constant(["brain", "salad", "surgery"]) 185 values = constant_op.constant([0, 1, 2], dtypes.int64) 186 table = lookup.HashTable( 187 lookup.KeyValueTensorInitializer(keys, values), default_val) 188 table.init.run() 189 190 # Ref types do not produce a lookup signature mismatch. 191 input_string_ref = variables.Variable("brain") 192 variables.global_variables_initializer().run() 193 self.assertEqual(0, table.lookup(input_string_ref).eval()) 194 195 input_string = constant_op.constant([1, 2, 3], dtypes.int64) 196 with self.assertRaises(TypeError): 197 table.lookup(input_string) 198 199 with self.assertRaises(TypeError): 200 lookup.HashTable( 201 lookup.KeyValueTensorInitializer(keys, values), "UNK") 202 203 def testDTypes(self): 204 with self.test_session(): 205 default_val = -1 206 with self.assertRaises(TypeError): 207 lookup.HashTable( 208 lookup.KeyValueTensorInitializer(["a"], [1], [dtypes.string], 209 dtypes.int64), default_val) 210 211 def testNotInitialized(self): 212 with self.test_session(): 213 default_val = -1 214 table = lookup.HashTable( 215 lookup.KeyValueTensorInitializer( 216 ["a"], [1], value_dtype=dtypes.int64), 217 default_val) 218 219 input_string = constant_op.constant(["brain", "salad", "surgery"]) 220 output = table.lookup(input_string) 221 222 with self.assertRaisesOpError("Table not 
initialized"): 223 output.eval() 224 225 def testInitializeTwice(self): 226 with self.test_session(): 227 default_val = -1 228 keys = constant_op.constant(["brain", "salad", "surgery"]) 229 values = constant_op.constant([0, 1, 2], dtypes.int64) 230 table = lookup.HashTable( 231 lookup.KeyValueTensorInitializer(keys, values), default_val) 232 table.init.run() 233 234 with self.assertRaisesOpError("Table already initialized"): 235 table.init.run() 236 237 def testInitializationWithInvalidDimensions(self): 238 with self.test_session(): 239 default_val = -1 240 keys = constant_op.constant(["brain", "salad", "surgery"]) 241 values = constant_op.constant([0, 1, 2, 3, 4], dtypes.int64) 242 243 with self.assertRaises(ValueError): 244 lookup.HashTable( 245 lookup.KeyValueTensorInitializer(keys, values), default_val) 246 247 def testMultipleSessions(self): 248 # Start a server 249 server = server_lib.Server( 250 { 251 "local0": ["localhost:0"] 252 }, protocol="grpc", start=True) 253 # Create two sessions sharing the same state 254 session1 = session.Session(server.target) 255 session2 = session.Session(server.target) 256 257 default_val = -1 258 keys = constant_op.constant(["brain", "salad", "surgery"]) 259 values = constant_op.constant([0, 1, 2], dtypes.int64) 260 table = lookup.HashTable( 261 lookup.KeyValueTensorInitializer(keys, values), 262 default_val, 263 name="t1") 264 265 # Init the table in the first session. 266 with session1: 267 table.init.run() 268 self.assertAllEqual(3, table.size().eval()) 269 270 # Init the table in the second session and verify that we do not get a 271 # "Table already initialized" error. 
    with session2:
      table.init.run()
      self.assertAllEqual(3, table.size().eval())


class MutableHashTableOpTest(test.TestCase):
  """Tests for lookup.MutableHashTable."""

  def testMutableHashTable(self):
    """Insert/size/lookup/export round trip on a string->int64 table."""
    with self.test_session():
      default_val = -1
      keys = constant_op.constant(["brain", "salad", "surgery"])
      values = constant_op.constant([0, 1, 2], dtypes.int64)
      table = lookup.MutableHashTable(dtypes.string, dtypes.int64,
                                      default_val)
      self.assertAllEqual(0, table.size().eval())

      table.insert(keys, values).run()
      self.assertAllEqual(3, table.size().eval())

      input_string = constant_op.constant(["brain", "salad", "tank"])
      output = table.lookup(input_string)
      self.assertAllEqual([3], output.get_shape())

      result = output.eval()
      self.assertAllEqual([0, 1, -1], result)

      exported_keys, exported_values = table.export()
      # Export size is only known at run time, hence shape [None].
      self.assertAllEqual([None], exported_keys.get_shape().as_list())
      self.assertAllEqual([None], exported_values.get_shape().as_list())

      # exported data is in the order of the internal map, i.e. undefined
      sorted_keys = np.sort(exported_keys.eval())
      sorted_values = np.sort(exported_values.eval())
      self.assertAllEqual([b"brain", b"salad", b"surgery"], sorted_keys)
      self.assertAllEqual([0, 1, 2], sorted_values)

  def testSaveRestore(self):
    """Checkpointed table contents survive a save/restore cycle."""
    save_dir = os.path.join(self.get_temp_dir(), "save_restore")
    save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")

    with self.test_session(graph=ops.Graph()) as sess:
      v0 = variables.Variable(10.0, name="v0")
      v1 = variables.Variable(20.0, name="v1")

      default_val = -1
      keys = constant_op.constant(["b", "c", "d"], dtypes.string)
      values = constant_op.constant([0, 1, 2], dtypes.int64)
      table = lookup.MutableHashTable(
          dtypes.string, dtypes.int64, default_val, name="t1", checkpoint=True)

      save = saver.Saver()
      variables.global_variables_initializer().run()

      # Check that the parameter nodes have been initialized.
      self.assertEqual(10.0, v0.eval())
      self.assertEqual(20.0, v1.eval())

      self.assertAllEqual(0, table.size().eval())
      table.insert(keys, values).run()
      self.assertAllEqual(3, table.size().eval())

      val = save.save(sess, save_path)
      self.assertTrue(isinstance(val, six.string_types))
      self.assertEqual(save_path, val)

    with self.test_session(graph=ops.Graph()) as sess:
      v0 = variables.Variable(-1.0, name="v0")
      v1 = variables.Variable(-1.0, name="v1")
      default_val = -1
      table = lookup.MutableHashTable(
          dtypes.string, dtypes.int64, default_val, name="t1", checkpoint=True)
      # Pre-populate with entries the restore is expected to replace.
      table.insert(
          constant_op.constant(["a", "c"], dtypes.string),
          constant_op.constant([12, 24], dtypes.int64)).run()
      self.assertAllEqual(2, table.size().eval())

      save = saver.Saver()

      # Restore the saved values in the parameter nodes.
      save.restore(sess, save_path)
      # Check that the parameter nodes have been restored.
      self.assertEqual(10.0, v0.eval())
      self.assertEqual(20.0, v1.eval())

      self.assertAllEqual(3, table.size().eval())

      input_string = constant_op.constant(["a", "b", "c", "d", "e"],
                                          dtypes.string)
      output = table.lookup(input_string)
      self.assertAllEqual([-1, 0, 1, 2, -1], output.eval())

  def testSharing(self):
    """Two sessions on one server see the same named mutable table."""
    # Start a server to store the table state
    server = server_lib.Server(
        {
            "local0": ["localhost:0"]
        }, protocol="grpc", start=True)
    # Create two sessions sharing the same state
    session1 = session.Session(server.target)
    session2 = session.Session(server.target)

    table = lookup.MutableHashTable(
        dtypes.int64, dtypes.string, "-", name="t1")

    # Populate the table in the first session
    with session1:
      self.assertAllEqual(0, table.size().eval())

      keys = constant_op.constant([11, 12], dtypes.int64)
      values = constant_op.constant(["a", "b"])
      table.insert(keys, values).run()
      self.assertAllEqual(2, table.size().eval())

      output = table.lookup(constant_op.constant([11, 12, 13], dtypes.int64))
      self.assertAllEqual([b"a", b"b", b"-"], output.eval())

    # Verify that we can access the shared data from the second session
    with session2:
      self.assertAllEqual(2, table.size().eval())

      output = table.lookup(constant_op.constant([10, 11, 12], dtypes.int64))
      self.assertAllEqual([b"-", b"a", b"b"], output.eval())

  def testMutableHashTableOfTensors(self):
    """Values may be vectors; lookup/export shapes carry the value dim."""
    with self.test_session():
      default_val = constant_op.constant([-1, -1], dtypes.int64)
      keys = constant_op.constant(["brain", "salad", "surgery"])
      values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64)
      table = lookup.MutableHashTable(dtypes.string, dtypes.int64,
                                      default_val)
      self.assertAllEqual(0, table.size().eval())

      table.insert(keys, values).run()
      self.assertAllEqual(3, table.size().eval())

      input_string = constant_op.constant(["brain", "salad", "tank"])
      output = table.lookup(input_string)
      self.assertAllEqual([3, 2], output.get_shape())

      result = output.eval()
      self.assertAllEqual([[0, 1], [2, 3], [-1, -1]], result)

      exported_keys, exported_values = table.export()
      self.assertAllEqual([None], exported_keys.get_shape().as_list())
      self.assertAllEqual([None, 2], exported_values.get_shape().as_list())
      # exported data is in the order of the internal map, i.e. undefined
      sorted_keys = np.sort(exported_keys.eval())
      sorted_values = np.sort(exported_values.eval())
      self.assertAllEqual([b"brain", b"salad", b"surgery"], sorted_keys)
      self.assertAllEqual([[4, 5], [2, 3], [0, 1]], sorted_values)

  def testMutableHashTableExportInsert(self):
    """Exported keys/values can populate a second identical table."""
    with self.test_session():
      default_val = constant_op.constant([-1, -1], dtypes.int64)
      keys = constant_op.constant(["brain", "salad", "surgery"])
      values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64)
      table1 = lookup.MutableHashTable(dtypes.string, dtypes.int64,
                                       default_val)
      self.assertAllEqual(0, table1.size().eval())
      table1.insert(keys, values).run()
      self.assertAllEqual(3, table1.size().eval())

      input_string = constant_op.constant(["brain", "salad", "tank"])
      expected_output = [[0, 1], [2, 3], [-1, -1]]
      output1 = table1.lookup(input_string)
      self.assertAllEqual(expected_output, output1.eval())

      exported_keys, exported_values = table1.export()
      self.assertAllEqual(3, exported_keys.eval().size)
      self.assertAllEqual(6, exported_values.eval().size)

      # Populate a second table from the exported data
      table2 = lookup.MutableHashTable(dtypes.string, dtypes.int64,
                                       default_val)
      self.assertAllEqual(0, table2.size().eval())
      table2.insert(exported_keys, exported_values).run()
      self.assertAllEqual(3, table2.size().eval())

      # Verify lookup result is still the same
      output2 = table2.lookup(input_string)
      self.assertAllEqual(expected_output, output2.eval())

  def testMutableHashTableOfTensorsInvalidShape(self):
    """Inserts whose value shape disagrees with the default value fail."""
    with self.test_session():
      default_val = constant_op.constant([-1, -1], dtypes.int64)
      keys = constant_op.constant(["brain", "salad", "surgery"])
      table = lookup.MutableHashTable(dtypes.string, dtypes.int64,
                                      default_val)

      # Shape [6] instead of [3, 2]
      values = constant_op.constant([0, 1, 2, 3, 4, 5], dtypes.int64)
      with self.assertRaisesOpError("Expected shape"):
        table.insert(keys, values).run()

      # Shape [2,3] instead of [3, 2]
      values = constant_op.constant([[0, 1, 2], [3, 4, 5]], dtypes.int64)
      with self.assertRaisesOpError("Expected shape"):
        table.insert(keys, values).run()

      # Shape [2, 2] instead of [3, 2]
      values = constant_op.constant([[0, 1], [2, 3]], dtypes.int64)
      with self.assertRaisesOpError("Expected shape"):
        table.insert(keys, values).run()

      # Shape [3, 1] instead of [3, 2]
      values = constant_op.constant([[0], [2], [4]], dtypes.int64)
      with self.assertRaisesOpError("Expected shape"):
        table.insert(keys, values).run()

      # Valid Insert
      values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64)
      table.insert(keys, values).run()
      self.assertAllEqual(3, table.size().eval())

  def testMutableHashTableInvalidDefaultValue(self):
    """A rank-2 default value is rejected (must be scalar or vector)."""
    with self.test_session():
      default_val = constant_op.constant([[-1, -1]], dtypes.int64)
      table = lookup.MutableHashTable(dtypes.string, dtypes.int64,
                                      default_val)
      with self.assertRaisesOpError("Default value must be a vector"):
        self.assertAllEqual(0, table.size().eval())

  def testMutableHashTableDuplicateInsert(self):
    """Later inserts of a duplicate key overwrite the earlier value."""
    with self.test_session():
      default_val = -1
      keys = constant_op.constant(["brain", "salad", "surgery", "brain"])
      values = constant_op.constant([0, 1, 2, 3], dtypes.int64)
      table = lookup.MutableHashTable(dtypes.string, dtypes.int64,
                                      default_val)
      self.assertAllEqual(0, table.size().eval())

      table.insert(keys, values).run()
      self.assertAllEqual(3, table.size().eval())

      input_string = constant_op.constant(["brain", "salad", "tank"])
      output = table.lookup(input_string)

      result = output.eval()
      # "brain" was inserted twice; the second value (3) wins.
      self.assertAllEqual([3, 1, -1], result)

  def testMutableHashTableFindHighRank(self):
    """Lookup with a rank-2 key tensor returns a rank-2 result."""
    with self.test_session():
      default_val = -1
      keys = constant_op.constant(["brain", "salad", "surgery"])
      values = constant_op.constant([0, 1, 2], dtypes.int64)
      table = lookup.MutableHashTable(dtypes.string, dtypes.int64,
                                      default_val)

      table.insert(keys, values).run()
      self.assertAllEqual(3, table.size().eval())

      input_string = constant_op.constant(
          [["brain", "salad"], ["tank", "tarkus"]])
      output = table.lookup(input_string)
      self.assertAllEqual([2, 2], output.get_shape())

      result = output.eval()
      self.assertAllEqual([[0, 1], [-1, -1]], result)

  def testMutableHashTableInsertHighRank(self):
    """Insert accepts rank-2 key/value tensors."""
    with self.test_session():
      default_val = -1
      keys = constant_op.constant([["brain", "salad"], ["surgery", "tank"]])
      values = constant_op.constant([[0, 1], [2, 3]], dtypes.int64)
      table = lookup.MutableHashTable(dtypes.string, dtypes.int64,
                                      default_val)

      table.insert(keys, values).run()
      self.assertAllEqual(4, table.size().eval())

      input_string = constant_op.constant(["brain", "salad", "tank", "tarkus"])
      output = table.lookup(input_string)

      result = output.eval()
      self.assertAllEqual([0, 1, 3, -1], result)

  def testMutableHashTableOfTensorsFindHighRank(self):
    """Rank-2 key lookup with vector values yields a rank-3 result."""
    with self.test_session():
      default_val = constant_op.constant([-1, -1, -1], dtypes.int64)
      keys = constant_op.constant(["brain", "salad", "surgery"])
      values = constant_op.constant([[0, 1, 2], [2, 3, 4], [4, 5, 6]],
                                    dtypes.int64)
      table = lookup.MutableHashTable(dtypes.string, dtypes.int64,
                                      default_val)

      table.insert(keys, values).run()
      self.assertAllEqual(3, table.size().eval())

      input_string = constant_op.constant(
          [["brain", "salad"], ["tank", "tarkus"]])
      output = table.lookup(input_string)
      self.assertAllEqual([2, 2, 3], output.get_shape())

      result = output.eval()
      self.assertAllEqual(
          [[[0, 1, 2], [2, 3, 4]], [[-1, -1, -1], [-1, -1, -1]]], result)

  def testMultipleMutableHashTables(self):
    """Several independent mutable tables coexist in one graph."""
    with self.test_session() as sess:
      default_val = -1
      keys = constant_op.constant(["brain", "salad", "surgery"])
      values = constant_op.constant([0, 1, 2], dtypes.int64)

      table1 = lookup.MutableHashTable(dtypes.string, dtypes.int64,
                                       default_val)
      table2 = lookup.MutableHashTable(dtypes.string, dtypes.int64,
                                       default_val)
      table3 = lookup.MutableHashTable(dtypes.string, dtypes.int64,
                                       default_val)
      table1.insert(keys, values).run()
      table2.insert(keys, values).run()
      table3.insert(keys, values).run()

      self.assertAllEqual(3, table1.size().eval())
      self.assertAllEqual(3, table2.size().eval())
      self.assertAllEqual(3, table3.size().eval())

      input_string = constant_op.constant(["brain", "salad", "tank"])
      output1 = table1.lookup(input_string)
      output2 = table2.lookup(input_string)
      output3 = table3.lookup(input_string)

      out1, out2, out3 = sess.run([output1, output2, output3])
      self.assertAllEqual([0, 1, -1], out1)
      self.assertAllEqual([0, 1, -1], out2)
      self.assertAllEqual([0, 1, -1], out3)

  def testMutableHashTableWithTensorDefault(self):
    """Default value may be a Tensor rather than a Python scalar."""
    with self.test_session():
      default_val = constant_op.constant(-1, dtypes.int64)
      keys = constant_op.constant(["brain", "salad", "surgery"])
      values = constant_op.constant([0, 1, 2], dtypes.int64)
      table = lookup.MutableHashTable(dtypes.string, dtypes.int64,
                                      default_val)

      table.insert(keys, values).run()
      self.assertAllEqual(3, table.size().eval())

      input_string = constant_op.constant(["brain", "salad", "tank"])
      output = table.lookup(input_string)

      result = output.eval()
      self.assertAllEqual([0, 1, -1], result)

  def testSignatureMismatch(self):
    """Wrong insert/lookup/default dtypes raise TypeError; refs allowed."""
    with self.test_session():
      default_val = -1
      keys = constant_op.constant(["brain", "salad", "surgery"])
      values = constant_op.constant([0, 1, 2], dtypes.int64)
      table = lookup.MutableHashTable(dtypes.string, dtypes.int64,
                                      default_val)

      # insert with keys of the wrong type
      with self.assertRaises(TypeError):
        table.insert(constant_op.constant([4, 5, 6]), values).run()

      # insert with values of the wrong type
      with self.assertRaises(TypeError):
        table.insert(keys, constant_op.constant(["a", "b", "c"])).run()

      self.assertAllEqual(0, table.size().eval())

      table.insert(keys, values).run()
      self.assertAllEqual(3, table.size().eval())

      input_string_ref = variables.Variable("brain")
      input_int64_ref = variables.Variable(-1, dtype=dtypes.int64)
      variables.global_variables_initializer().run()

      # Ref types do not produce an insert signature mismatch.
      table.insert(input_string_ref, input_int64_ref).run()
      self.assertAllEqual(3, table.size().eval())

      # Ref types do not produce a lookup signature mismatch.
      self.assertEqual(-1, table.lookup(input_string_ref).eval())

      # lookup with keys of the wrong type
      input_string = constant_op.constant([1, 2, 3], dtypes.int64)
      with self.assertRaises(TypeError):
        table.lookup(input_string).eval()

      # default value of the wrong type
      with self.assertRaises(TypeError):
        lookup.MutableHashTable(dtypes.string, dtypes.int64, "UNK")

  def testMutableHashTableStringFloat(self):
    """string->float32 tables work with a float default value."""
    with self.test_session():
      default_val = -1.5
      keys = constant_op.constant(["brain", "salad", "surgery"])
      values = constant_op.constant([0, 1.1, 2.2], dtypes.float32)
      table = lookup.MutableHashTable(dtypes.string, dtypes.float32,
                                      default_val)
      self.assertAllEqual(0, table.size().eval())

      table.insert(keys, values).run()
      self.assertAllEqual(3, table.size().eval())

      input_string = constant_op.constant(["brain", "salad", "tank"])
      output = table.lookup(input_string)

      result = output.eval()
      self.assertAllClose([0, 1.1, default_val], result)

  def testMutableHashTableIntFloat(self):
    """int64->float32 tables work with a float default value."""
    with self.test_session():
      default_val = -1.0
      keys = constant_op.constant([3, 7, 0], dtypes.int64)
      values = constant_op.constant([7.5, -1.2, 9.9], dtypes.float32)
      table = lookup.MutableHashTable(dtypes.int64, dtypes.float32,
                                      default_val)
      self.assertAllEqual(0, table.size().eval())

      table.insert(keys, values).run()
      self.assertAllEqual(3, table.size().eval())

      input_string = constant_op.constant([7, 0, 11], dtypes.int64)
      output = table.lookup(input_string)

      result = output.eval()
      self.assertAllClose([-1.2, 9.9, default_val], result)

  def testMutableHashTableInt64String(self):
    """int64->string tables work with a string default value."""
    with self.test_session():
      default_val = "n/a"
      keys = constant_op.constant([0, 1, 2], dtypes.int64)
      values = constant_op.constant(["brain", "salad", "surgery"])
      table = lookup.MutableHashTable(dtypes.int64, dtypes.string,
                                      default_val)
      self.assertAllEqual(0, table.size().eval())

      table.insert(keys, values).run()
      self.assertAllEqual(3, table.size().eval())

      input_string = constant_op.constant([0, 1, 3], dtypes.int64)
      output = table.lookup(input_string)

      result = output.eval()
      self.assertAllEqual((b"brain", b"salad", b"n/a"), result)


class MutableDenseHashTableOpTest(test.TestCase):
  """Tests for lookup.MutableDenseHashTable."""

  def testBasic(self):
    """Insert/size/lookup round trip on an int64->int64 dense table."""
    with self.test_session():
      keys = constant_op.constant([11, 12, 13], dtypes.int64)
      values = constant_op.constant([0, 1, 2], dtypes.int64)
      table = lookup.MutableDenseHashTable(
          dtypes.int64, dtypes.int64, default_value=-1, empty_key=0)
      self.assertAllEqual(0, table.size().eval())

      table.insert(keys, values).run()
      self.assertAllEqual(3, table.size().eval())

      input_string = constant_op.constant([11, 12, 15], dtypes.int64)
      output = table.lookup(input_string)
      self.assertAllEqual([3], output.get_shape())

      result = output.eval()
      self.assertAllEqual([0, 1, -1], result)

  def testBasicBool(self):
    """bool-valued dense tables return the False default for misses."""
    with self.test_session():
      keys = constant_op.constant([11, 12, 13], dtypes.int64)
      values = constant_op.constant([True, True, True], dtypes.bool)
      table = lookup.MutableDenseHashTable(
          dtypes.int64, dtypes.bool, default_value=False, empty_key=0)
      self.assertAllEqual(0, table.size().eval())

      table.insert(keys, values).run()
      self.assertAllEqual(3, table.size().eval())

      input_string = constant_op.constant([11, 12, 15], dtypes.int64)
      output = table.lookup(input_string)
      self.assertAllEqual([3], output.get_shape())

      result = output.eval()
      self.assertAllEqual([True, True, False], result)

  def testLookupUnknownShape(self):
    """Lookup works when the key tensor shape is unknown at graph time."""
    with self.test_session():
      keys = constant_op.constant([11, 12, 13], dtypes.int64)
      values = constant_op.constant([0, 1, 2], dtypes.int64)
      table = lookup.MutableDenseHashTable(
          dtypes.int64, dtypes.int64, default_value=-1, empty_key=0)

      table.insert(keys, values).run()
      self.assertAllEqual(3, table.size().eval())

      placeholder_keys = array_ops.placeholder(dtypes.int64)
      output = table.lookup(placeholder_keys)
      self.assertAllEqual(None, output.get_shape())
      result = output.eval({placeholder_keys: [11, 12, 15]})
      self.assertAllEqual([0, 1, -1], result)

  def testMapStringToFloat(self):
    """string keys with the empty string as the reserved empty_key."""
    with self.test_session():
      keys = constant_op.constant(["a", "b", "c"], dtypes.string)
      values = constant_op.constant([0.0, 1.1, 2.2], dtypes.float32)
      default_value = constant_op.constant(-1.5, dtypes.float32)
      table = lookup.MutableDenseHashTable(
          dtypes.string,
          dtypes.float32,
          default_value=default_value,
          empty_key="")
      self.assertAllEqual(0, table.size().eval())

      table.insert(keys, values).run()
      self.assertAllEqual(3, table.size().eval())

      input_string = constant_op.constant(["a", "b", "d"], dtypes.string)
      output = table.lookup(input_string)
      self.assertAllEqual([3], output.get_shape())

      result = output.eval()
      self.assertAllClose([0, 1.1, -1.5], result)

  def testMapInt64ToFloat(self):
    """Both float32 and float64 value dtypes are supported."""
    for float_dtype in [dtypes.float32, dtypes.float64]:
      with self.test_session():
        keys = constant_op.constant([11, 12, 13], dtypes.int64)
        values = constant_op.constant([0.0, 1.1, 2.2], float_dtype)
        default_value = constant_op.constant(-1.5, float_dtype)
        table = lookup.MutableDenseHashTable(
            dtypes.int64, float_dtype, default_value=default_value, empty_key=0)
        self.assertAllEqual(0, table.size().eval())

        table.insert(keys, values).run()
        self.assertAllEqual(3, table.size().eval())

        input_string = constant_op.constant([11, 12, 15], dtypes.int64)
        output = table.lookup(input_string)
        self.assertAllEqual([3], output.get_shape())

        result = output.eval()
        self.assertAllClose([0, 1.1, -1.5], result)

  def testVectorValues(self):
    """Vector values; bucket count grows as the table fills."""
    with self.test_session():
      keys = constant_op.constant([11, 12, 13], dtypes.int64)
      values = constant_op.constant([[0, 1, 2, 3], [3, 4, 5, 6], [6, 7, 8, 9]],
                                    dtypes.int64)
      default_value = constant_op.constant([-1, -2, -3, -4], dtypes.int64)
      table = lookup.MutableDenseHashTable(
          dtypes.int64,
          dtypes.int64,
          default_value=default_value,
          empty_key=0,
          initial_num_buckets=4)
      self.assertAllEqual(0, table.size().eval())

      table.insert(keys, values).run()
      self.assertAllEqual(3, table.size().eval())
      self.assertAllEqual(4, len(table.export()[0].eval()))

      table.insert(
          constant_op.constant([14], dtypes.int64),
          constant_op.constant([[2, 3, 4, 5]], dtypes.int64)).run()
      self.assertAllEqual(4, table.size().eval())
      # The fourth insert pushed the table past its load factor: 8 buckets.
      self.assertAllEqual(8, len(table.export()[0].eval()))

      input_string = constant_op.constant([11, 12, 15], dtypes.int64)
      output = table.lookup(input_string)
      self.assertAllEqual([3, 4], output.get_shape())

      result = output.eval()
      self.assertAllEqual([[0, 1, 2, 3], [3, 4, 5, 6], [-1, -2, -3, -4]],
                          result)

  def testVectorKeys(self):
    """Composite (vector) keys with a vector empty_key sentinel."""
    with self.test_session():
      keys = constant_op.constant([[0, 1], [1, 2], [1, 3]], dtypes.int64)
      values = constant_op.constant([10, 11, 12], dtypes.int64)
      empty_key = constant_op.constant([0, 3], dtypes.int64)
      default_value = constant_op.constant(-1, dtypes.int64)
      table = lookup.MutableDenseHashTable(
          dtypes.int64,
          dtypes.int64,
          default_value=default_value,
          empty_key=empty_key,
          initial_num_buckets=8)
      self.assertAllEqual(0, table.size().eval())

      table.insert(keys, values).run()
      self.assertAllEqual(3, table.size().eval())

      table.insert(
          constant_op.constant([[0, 0]], dtypes.int64),
          constant_op.constant([13], dtypes.int64)).run()
      self.assertAllEqual(4, table.size().eval())
      self.assertAllEqual(8, len(table.export()[0].eval()))

      input_string = constant_op.constant([[0, 1], [1, 2], [0, 2]],
                                          dtypes.int64)
      output = table.lookup(input_string)
      self.assertAllEqual([3], output.get_shape())

      result = output.eval()
      self.assertAllEqual([10, 11, -1], result)

  def testResize(self):
    """Bucket count doubles as inserts exceed the load factor."""
    with self.test_session():
      keys = constant_op.constant([11, 12, 13], dtypes.int64)
      values = constant_op.constant([0, 1, 2], dtypes.int64)
      table = lookup.MutableDenseHashTable(
          dtypes.int64,
          dtypes.int64,
          default_value=-1,
          empty_key=0,
          initial_num_buckets=4)
      self.assertAllEqual(0, table.size().eval())

      table.insert(keys, values).run()
      self.assertAllEqual(3, table.size().eval())
      self.assertAllEqual(4, len(table.export()[0].eval()))

      keys2 = constant_op.constant([13, 14, 15, 16, 17], dtypes.int64)
      values2 = constant_op.constant([3, 4, 5, 6, 7], dtypes.int64)

      table.insert(keys2, values2).run()
      self.assertAllEqual(7, table.size().eval())
      # 7 entries pushed the 4-bucket table through two doublings to 16.
      self.assertAllEqual(16, len(table.export()[0].eval()))

      keys3 = constant_op.constant([10, 11, 12, 13, 14, 15, 16, 17, 18],
                                   dtypes.int64)
      output = table.lookup(keys3)
      self.assertAllEqual([-1, 0, 1, 3, 4, 5, 6, 7, -1], output.eval())

  def testExport(self):
    """Export returns all buckets; empty buckets hold empty_key and 0."""
    with self.test_session():
      keys = constant_op.constant([11, 12, 13], dtypes.int64)
      values = constant_op.constant([1, 2, 3], dtypes.int64)
      table = lookup.MutableDenseHashTable(
          dtypes.int64,
          dtypes.int64,
          default_value=-1,
          empty_key=100,
          initial_num_buckets=8)
      self.assertAllEqual(0, table.size().eval())

      table.insert(keys, values).run()
      self.assertAllEqual(3, table.size().eval())

      exported_keys, exported_values = table.export()
      self.assertAllEqual([None], exported_keys.get_shape().as_list())
      self.assertAllEqual([None], exported_values.get_shape().as_list())

      np_keys = exported_keys.eval()
      np_values = exported_values.eval()

self.assertAllEqual(8, len(np_keys)) 922 self.assertAllEqual(8, len(np_values)) 923 924 # pair up keys and values, drop extra added dimension 925 pairs = np.dstack((np_keys.flatten(), np_values.flatten()))[0] 926 # sort by key 927 pairs = pairs[pairs[:, 0].argsort()] 928 self.assertAllEqual([[11, 1], [12, 2], [13, 3], [100, 0], [100, 0], 929 [100, 0], [100, 0], [100, 0]], pairs) 930 931 def testSaveRestore(self): 932 save_dir = os.path.join(self.get_temp_dir(), "save_restore") 933 save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash") 934 935 with self.test_session(graph=ops.Graph()) as sess: 936 default_value = -1 937 empty_key = 0 938 keys = constant_op.constant([11, 12, 13], dtypes.int64) 939 values = constant_op.constant([0, 1, 2], dtypes.int64) 940 table = lookup.MutableDenseHashTable( 941 dtypes.int64, 942 dtypes.int64, 943 default_value=default_value, 944 empty_key=empty_key, 945 name="t1", 946 checkpoint=True, 947 initial_num_buckets=32) 948 949 save = saver.Saver() 950 951 self.assertAllEqual(0, table.size().eval()) 952 table.insert(keys, values).run() 953 self.assertAllEqual(3, table.size().eval()) 954 self.assertAllEqual(32, len(table.export()[0].eval())) 955 956 val = save.save(sess, save_path) 957 self.assertTrue(isinstance(val, six.string_types)) 958 self.assertEqual(save_path, val) 959 960 with self.test_session(graph=ops.Graph()) as sess: 961 table = lookup.MutableDenseHashTable( 962 dtypes.int64, 963 dtypes.int64, 964 default_value=default_value, 965 empty_key=empty_key, 966 name="t1", 967 checkpoint=True, 968 initial_num_buckets=64) 969 table.insert( 970 constant_op.constant([11, 14], dtypes.int64), 971 constant_op.constant([12, 24], dtypes.int64)).run() 972 self.assertAllEqual(2, table.size().eval()) 973 self.assertAllEqual(64, len(table.export()[0].eval())) 974 975 save = saver.Saver() 976 977 # Restore the saved values in the parameter nodes. 
978 save.restore(sess, save_path) 979 980 self.assertAllEqual(3, table.size().eval()) 981 self.assertAllEqual(32, len(table.export()[0].eval())) 982 983 input_string = constant_op.constant([10, 11, 12, 13, 14], dtypes.int64) 984 output = table.lookup(input_string) 985 self.assertAllEqual([-1, 0, 1, 2, -1], output.eval()) 986 987 def testVectorSaveRestore(self): 988 save_dir = os.path.join(self.get_temp_dir(), "vector_save_restore") 989 save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash") 990 991 with self.test_session(graph=ops.Graph()) as sess: 992 empty_key = constant_op.constant([11, 13], dtypes.int64) 993 default_value = constant_op.constant([-1, -2], dtypes.int64) 994 keys = constant_op.constant([[11, 12], [11, 14], [13, 14]], dtypes.int64) 995 values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64) 996 table = lookup.MutableDenseHashTable( 997 dtypes.int64, 998 dtypes.int64, 999 default_value=default_value, 1000 empty_key=empty_key, 1001 name="t1", 1002 checkpoint=True, 1003 initial_num_buckets=32) 1004 1005 save = saver.Saver() 1006 1007 self.assertAllEqual(0, table.size().eval()) 1008 table.insert(keys, values).run() 1009 self.assertAllEqual(3, table.size().eval()) 1010 self.assertAllEqual(32, len(table.export()[0].eval())) 1011 1012 val = save.save(sess, save_path) 1013 self.assertTrue(isinstance(val, six.string_types)) 1014 self.assertEqual(save_path, val) 1015 1016 with self.test_session(graph=ops.Graph()) as sess: 1017 empty_key = constant_op.constant([11, 13], dtypes.int64) 1018 default_value = constant_op.constant([-1, -2], dtypes.int64) 1019 table = lookup.MutableDenseHashTable( 1020 dtypes.int64, 1021 dtypes.int64, 1022 default_value=default_value, 1023 empty_key=empty_key, 1024 name="t1", 1025 checkpoint=True, 1026 initial_num_buckets=64) 1027 table.insert( 1028 constant_op.constant([[11, 12], [13, 15]], dtypes.int64), 1029 constant_op.constant([[21, 22], [23, 24]], dtypes.int64)).run() 1030 self.assertAllEqual(2, 
table.size().eval()) 1031 self.assertAllEqual(64, len(table.export()[0].eval())) 1032 1033 save = saver.Saver() 1034 1035 # Restore the saved values in the parameter nodes. 1036 save.restore(sess, save_path) 1037 1038 self.assertAllEqual(3, table.size().eval()) 1039 self.assertAllEqual(32, len(table.export()[0].eval())) 1040 1041 input_string = constant_op.constant( 1042 [[11, 12], [11, 14], [11, 15], [13, 14], [13, 15]], dtypes.int64) 1043 output = table.lookup(input_string) 1044 self.assertAllEqual([[0, 1], [2, 3], [-1, -2], [4, 5], [-1, -2]], 1045 output.eval()) 1046 1047 def testVectorScalarSaveRestore(self): 1048 save_dir = os.path.join(self.get_temp_dir(), "vector_scalar_save_restore") 1049 save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash") 1050 1051 with self.test_session(graph=ops.Graph()) as sess: 1052 empty_key = constant_op.constant([11, 13], dtypes.int64) 1053 default_value = constant_op.constant(-1, dtypes.int64) 1054 keys = constant_op.constant([[11, 12], [11, 14], [13, 14]], dtypes.int64) 1055 values = constant_op.constant([0, 1, 2], dtypes.int64) 1056 table = lookup.MutableDenseHashTable( 1057 dtypes.int64, 1058 dtypes.int64, 1059 default_value=default_value, 1060 empty_key=empty_key, 1061 name="t2", 1062 checkpoint=True, 1063 initial_num_buckets=32) 1064 1065 save = saver.Saver() 1066 1067 self.assertAllEqual(0, table.size().eval()) 1068 table.insert(keys, values).run() 1069 self.assertAllEqual(3, table.size().eval()) 1070 self.assertAllEqual(32, len(table.export()[0].eval())) 1071 1072 val = save.save(sess, save_path) 1073 self.assertTrue(isinstance(val, six.string_types)) 1074 self.assertEqual(save_path, val) 1075 1076 with self.test_session(graph=ops.Graph()) as sess: 1077 empty_key = constant_op.constant([11, 13], dtypes.int64) 1078 default_value = constant_op.constant(-1, dtypes.int64) 1079 table = lookup.MutableDenseHashTable( 1080 dtypes.int64, 1081 dtypes.int64, 1082 default_value=default_value, 1083 empty_key=empty_key, 
1084 name="t2", 1085 checkpoint=True, 1086 initial_num_buckets=64) 1087 table.insert( 1088 constant_op.constant([[11, 12], [13, 15]], dtypes.int64), 1089 constant_op.constant([3, 4], dtypes.int64)).run() 1090 self.assertAllEqual(2, table.size().eval()) 1091 self.assertAllEqual(64, len(table.export()[0].eval())) 1092 1093 save = saver.Saver() 1094 1095 # Restore the saved values in the parameter nodes. 1096 save.restore(sess, save_path) 1097 1098 self.assertAllEqual(3, table.size().eval()) 1099 self.assertAllEqual(32, len(table.export()[0].eval())) 1100 1101 input_string = constant_op.constant( 1102 [[11, 12], [11, 14], [11, 15], [13, 14], [13, 15]], dtypes.int64) 1103 output = table.lookup(input_string) 1104 self.assertAllEqual([0, 1, -1, 2, -1], output.eval()) 1105 1106 def testReprobe(self): 1107 with self.test_session(): 1108 # Insert 6 keys into a table with 8 buckets. 1109 # The values are chosen to make sure collisions occur when using GCC STL 1110 keys = constant_op.constant([11, 12, 13, 19, 20, 21], dtypes.int64) 1111 values = constant_op.constant([51, 52, 53, 54, 55, 56], dtypes.int64) 1112 table = lookup.MutableDenseHashTable( 1113 dtypes.int64, 1114 dtypes.int64, 1115 default_value=-1, 1116 empty_key=0, 1117 initial_num_buckets=8) 1118 self.assertAllEqual(0, table.size().eval()) 1119 1120 table.insert(keys, values).run() 1121 self.assertAllEqual(6, table.size().eval()) 1122 1123 input_string = constant_op.constant([10, 11, 12, 13, 14, 19, 20, 21, 22], 1124 dtypes.int64) 1125 output = table.lookup(input_string) 1126 self.assertAllEqual([9], output.get_shape()) 1127 1128 result = output.eval() 1129 self.assertAllEqual([-1, 51, 52, 53, -1, 54, 55, 56, -1], result) 1130 1131 def testCustomEmptyKey(self): 1132 with self.test_session(): 1133 keys = constant_op.constant([11, 0, 13], dtypes.int64) 1134 values = constant_op.constant([0, 1, 2], dtypes.int64) 1135 table = lookup.MutableDenseHashTable( 1136 dtypes.int64, dtypes.int64, default_value=-1, empty_key=12) 
1137 self.assertAllEqual(0, table.size().eval()) 1138 1139 table.insert(keys, values).run() 1140 self.assertAllEqual(3, table.size().eval()) 1141 1142 input_string = constant_op.constant([11, 0, 15], dtypes.int64) 1143 output = table.lookup(input_string) 1144 self.assertAllEqual([3], output.get_shape()) 1145 1146 result = output.eval() 1147 self.assertAllEqual([0, 1, -1], result) 1148 1149 def testErrors(self): 1150 with self.test_session(): 1151 table = lookup.MutableDenseHashTable( 1152 dtypes.int64, dtypes.int64, default_value=-1, empty_key=0) 1153 1154 # Inserting the empty key returns an error 1155 keys = constant_op.constant([11, 0], dtypes.int64) 1156 values = constant_op.constant([0, 1], dtypes.int64) 1157 with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, 1158 "empty_key"): 1159 table.insert(keys, values).run() 1160 1161 # Looking up the empty key returns an error 1162 with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, 1163 "empty_key"): 1164 table.lookup(keys).eval() 1165 1166 # Arbitrary tensors of keys are not supported 1167 keys = constant_op.constant([[11, 0], [12, 1]], dtypes.int64) 1168 values = constant_op.constant([[11, 0], [12, 1]], dtypes.int64) 1169 with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, 1170 "Expected key shape"): 1171 table.lookup(keys).eval() 1172 with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, 1173 "Expected key shape"): 1174 table.insert(keys, values).run() 1175 1176 table2 = lookup.MutableDenseHashTable( 1177 dtypes.int64, 1178 dtypes.int64, 1179 default_value=-1, 1180 empty_key=17, 1181 initial_num_buckets=12) 1182 with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, 1183 "Number of buckets must be"): 1184 self.assertAllEqual(0, table2.size().eval()) 1185 1186 1187 class IndexTableFromFile(test.TestCase): 1188 1189 def _createVocabFile(self, basename, values=("brain", "salad", "surgery")): 1190 vocabulary_file = os.path.join(self.get_temp_dir(), basename) 1191 with 
open(vocabulary_file, "w") as f: 1192 f.write("\n".join(values) + "\n") 1193 return vocabulary_file 1194 1195 def test_string_index_table_from_file(self): 1196 vocabulary_file = self._createVocabFile("f2i_vocab1.txt") 1197 with self.test_session(): 1198 table = lookup.index_table_from_file( 1199 vocabulary_file=vocabulary_file, num_oov_buckets=1) 1200 ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) 1201 1202 self.assertRaises(errors_impl.OpError, ids.eval) 1203 lookup_ops.tables_initializer().run() 1204 self.assertAllEqual((1, 2, 3), ids.eval()) 1205 1206 def test_string_index_table_from_file_tensor_filename(self): 1207 vocabulary_file = self._createVocabFile("f2i_vocab1.txt") 1208 with self.test_session(): 1209 vocabulary_file = constant_op.constant(vocabulary_file) 1210 table = lookup.index_table_from_file( 1211 vocabulary_file=vocabulary_file, num_oov_buckets=1) 1212 ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) 1213 1214 self.assertRaises(errors_impl.OpError, ids.eval) 1215 lookup_ops.tables_initializer().run() 1216 self.assertAllEqual((1, 2, 3), ids.eval()) 1217 self.assertEqual(1, 1218 len(ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS))) 1219 1220 def test_string_index_table_from_file_placeholder_filename(self): 1221 vocabulary_file = self._createVocabFile("f2i_vocab1.txt") 1222 with self.test_session(): 1223 vocabulary_placeholder = array_ops.placeholder(dtypes.string, []) 1224 table = lookup.index_table_from_file( 1225 vocabulary_file=vocabulary_placeholder, num_oov_buckets=1) 1226 ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) 1227 1228 self.assertRaises(errors_impl.OpError, ids.eval) 1229 1230 feed_dict = {vocabulary_placeholder.name: vocabulary_file} 1231 lookup_ops.tables_initializer().run(feed_dict=feed_dict) 1232 self.assertAllEqual((1, 2, 3), ids.eval()) 1233 self.assertEqual(0, 1234 len(ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS))) 1235 1236 def 
test_int32_index_table_from_file(self): 1237 vocabulary_file = self._createVocabFile( 1238 "f2i_vocab2.txt", values=("42", "1", "-1000")) 1239 with self.test_session(): 1240 table = lookup.index_table_from_file( 1241 vocabulary_file=vocabulary_file, num_oov_buckets=1, 1242 key_dtype=dtypes.int32) 1243 ids = table.lookup( 1244 constant_op.constant((1, -1000, 11), dtype=dtypes.int32)) 1245 1246 self.assertRaises(errors_impl.OpError, ids.eval) 1247 lookup_ops.tables_initializer().run() 1248 self.assertAllEqual((1, 2, 3), ids.eval()) 1249 1250 def test_int64_index_table_from_file(self): 1251 vocabulary_file = self._createVocabFile( 1252 "f2i_vocab3.txt", values=("42", "1", "-1000")) 1253 with self.test_session(): 1254 table = lookup.index_table_from_file( 1255 vocabulary_file=vocabulary_file, num_oov_buckets=1, 1256 key_dtype=dtypes.int64) 1257 ids = table.lookup( 1258 constant_op.constant((1, -1000, 11), dtype=dtypes.int64)) 1259 1260 self.assertRaises(errors_impl.OpError, ids.eval) 1261 lookup_ops.tables_initializer().run() 1262 self.assertAllEqual((1, 2, 3), ids.eval()) 1263 1264 def test_index_table_from_file_with_default_value(self): 1265 default_value = -42 1266 vocabulary_file = self._createVocabFile("f2i_vocab4.txt") 1267 with self.test_session(): 1268 table = lookup.index_table_from_file( 1269 vocabulary_file=vocabulary_file, default_value=default_value) 1270 ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) 1271 1272 self.assertRaises(errors_impl.OpError, ids.eval) 1273 lookup_ops.tables_initializer().run() 1274 self.assertAllEqual((1, 2, default_value), ids.eval()) 1275 1276 def test_index_table_from_file_with_oov_buckets(self): 1277 vocabulary_file = self._createVocabFile("f2i_vocab5.txt") 1278 with self.test_session(): 1279 table = lookup.index_table_from_file( 1280 vocabulary_file=vocabulary_file, num_oov_buckets=1000) 1281 ids = table.lookup( 1282 constant_op.constant(["salad", "surgery", "tarkus", "toccata"])) 1283 1284 
self.assertRaises(errors_impl.OpError, ids.eval) 1285 lookup_ops.tables_initializer().run() 1286 self.assertAllEqual( 1287 ( 1288 1, # From vocabulary file. 1289 2, # From vocabulary file. 1290 867, # 3 + fingerprint("tarkus") mod 300. 1291 860), # 3 + fingerprint("toccata") mod 300. 1292 ids.eval()) 1293 1294 def test_index_table_from_file_fails_with_empty_vocabulary_file_name(self): 1295 self.assertRaises( 1296 ValueError, 1297 lookup.index_table_from_file, 1298 vocabulary_file="") 1299 1300 def test_index_table_from_file_fails_with_empty_vocabulary(self): 1301 self.assertRaises( 1302 ValueError, 1303 lookup.index_table_from_file, 1304 vocabulary_file=None) 1305 1306 def test_index_table_from_file_with_vocab_size_too_small(self): 1307 vocabulary_file = self._createVocabFile("f2i_vocab6.txt") 1308 with self.test_session(): 1309 table = lookup.index_table_from_file( 1310 vocabulary_file=vocabulary_file, vocab_size=2) 1311 ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) 1312 1313 self.assertRaises(errors_impl.OpError, ids.eval) 1314 lookup_ops.tables_initializer().run() 1315 self.assertAllEqual((1, -1, -1), ids.eval()) 1316 self.assertEqual(2, table.size().eval()) 1317 1318 def test_index_table_from_file_with_vocab_size_too_large(self): 1319 vocabulary_file = self._createVocabFile("f2i_vocab7.txt") 1320 with self.test_session(): 1321 table = lookup.index_table_from_file( 1322 vocabulary_file=vocabulary_file, vocab_size=4) 1323 self.assertRaisesRegexp(errors_impl.InvalidArgumentError, 1324 "Invalid vocab_size", table.init.run) 1325 1326 def test_index_table_from_file_with_vocab_size(self): 1327 vocabulary_file = self._createVocabFile("f2i_vocab8.txt") 1328 1329 self.assertRaises( 1330 ValueError, 1331 lookup.index_table_from_file, 1332 vocabulary_file=vocabulary_file, 1333 vocab_size=0) 1334 1335 with self.test_session(): 1336 table = lookup.index_table_from_file( 1337 vocabulary_file=vocabulary_file, vocab_size=3) 1338 ids = 
table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) 1339 1340 self.assertRaises(errors_impl.OpError, ids.eval) 1341 lookup_ops.tables_initializer().run() 1342 self.assertAllEqual((1, 2, -1), ids.eval()) 1343 self.assertEqual(3, table.size().eval()) 1344 1345 def test_index_table_from_file_with_invalid_hashers(self): 1346 vocabulary_file = self._createVocabFile("invalid_hasher.txt") 1347 with self.test_session(): 1348 with self.assertRaises(TypeError): 1349 lookup.index_table_from_file( 1350 vocabulary_file=vocabulary_file, 1351 vocab_size=3, 1352 num_oov_buckets=1, 1353 hasher_spec=1) 1354 1355 table = lookup.index_table_from_file( 1356 vocabulary_file=vocabulary_file, 1357 vocab_size=3, 1358 num_oov_buckets=1, 1359 hasher_spec=lookup.HasherSpec("my-awesome-hash", None)) 1360 1361 self.assertRaises(ValueError, table.lookup, 1362 constant_op.constant(["salad", "surgery", "tarkus"])) 1363 1364 1365 class KeyValueTensorInitializerTest(test.TestCase): 1366 1367 def test_string(self): 1368 with ops.Graph().as_default(), self.test_session(): 1369 init = lookup.KeyValueTensorInitializer( 1370 ("brain", "salad", "surgery"), (0, 1, 2), dtypes.string, dtypes.int64) 1371 table = lookup.HashTable(init, default_value=-1) 1372 table.init.run() 1373 1374 def test_int64(self): 1375 with ops.Graph().as_default(), self.test_session(): 1376 init = lookup.KeyValueTensorInitializer( 1377 (42, 1, -1000), (0, 1, 2), dtypes.int64, dtypes.int64) 1378 table = lookup.HashTable(init, default_value=-1) 1379 table.init.run() 1380 1381 def test_int32(self): 1382 with ops.Graph().as_default(), self.test_session(): 1383 init = lookup.KeyValueTensorInitializer( 1384 (42, 1, -1000), (0, 1, 2), dtypes.int32, dtypes.int64) 1385 table = lookup.HashTable(init, default_value=-1) 1386 with self.assertRaisesRegexp( 1387 errors_impl.OpError, "No OpKernel was registered"): 1388 table.init.run() 1389 1390 1391 class IndexTableFromTensor(test.TestCase): 1392 1393 def 
test_index_table_from_tensor_with_tensor_init(self): 1394 with self.test_session(): 1395 table = lookup.index_table_from_tensor( 1396 mapping=("brain", "salad", "surgery"), num_oov_buckets=1) 1397 ids = table.lookup(constant_op.constant(("salad", "surgery", "tarkus"))) 1398 1399 self.assertRaises(errors_impl.OpError, ids.eval) 1400 lookup_ops.tables_initializer().run() 1401 self.assertAllEqual((1, 2, 3), ids.eval()) 1402 1403 def test_int32_index_table_from_tensor_with_tensor_init(self): 1404 with self.test_session(): 1405 table = lookup.index_table_from_tensor( 1406 mapping=(42, 1, -1000), num_oov_buckets=1, dtype=dtypes.int32) 1407 ids = table.lookup( 1408 constant_op.constant((1, -1000, 11), dtype=dtypes.int32)) 1409 1410 self.assertRaises(errors_impl.OpError, ids.eval) 1411 lookup_ops.tables_initializer().run() 1412 self.assertAllEqual((1, 2, 3), ids.eval()) 1413 1414 def test_int64_index_table_from_tensor_with_tensor_init(self): 1415 with self.test_session(): 1416 table = lookup.index_table_from_tensor( 1417 mapping=(42, 1, -1000), num_oov_buckets=1, dtype=dtypes.int64) 1418 ids = table.lookup( 1419 constant_op.constant((1, -1000, 11), dtype=dtypes.int64)) 1420 1421 self.assertRaises(errors_impl.OpError, ids.eval) 1422 lookup_ops.tables_initializer().run() 1423 self.assertAllEqual((1, 2, 3), ids.eval()) 1424 1425 def test_index_table_from_tensor_with_default_value(self): 1426 default_value = -42 1427 with self.test_session(): 1428 table = lookup.index_table_from_tensor( 1429 mapping=["brain", "salad", "surgery"], default_value=default_value) 1430 ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) 1431 1432 self.assertRaises(errors_impl.OpError, ids.eval) 1433 lookup_ops.tables_initializer().run() 1434 self.assertAllEqual((1, 2, default_value), ids.eval()) 1435 1436 def test_index_table_from_tensor_missing_mapping(self): 1437 with self.test_session(): 1438 with self.assertRaisesRegexp(ValueError, "mapping must be specified"): 1439 
lookup.index_table_from_tensor(mapping=None, num_oov_buckets=1) 1440 1441 def test_index_table_from_tensor_empty_mapping(self): 1442 with self.test_session(): 1443 table = lookup.index_table_from_tensor( 1444 mapping=np.array([], dtype=np.str_), num_oov_buckets=1) 1445 ids = table.lookup(constant_op.constant(["salad", "surgery", "brain"])) 1446 self.assertRaises(errors_impl.OpError, ids.eval) 1447 with self.assertRaisesRegexp( 1448 errors_impl.OpError, "keys and values cannot be empty"): 1449 lookup_ops.tables_initializer().run() 1450 1451 def test_index_table_from_tensor_with_invalid_hashers(self): 1452 with self.test_session(): 1453 with self.assertRaises(TypeError): 1454 lookup.index_table_from_tensor( 1455 mapping=["brain", "salad", "surgery"], 1456 num_oov_buckets=1, 1457 hasher_spec=1) 1458 1459 table = lookup.index_table_from_tensor( 1460 mapping=["brain", "salad", "surgery"], 1461 num_oov_buckets=1, 1462 hasher_spec=lookup.HasherSpec("my-awesome-hash", None)) 1463 1464 self.assertRaises(ValueError, table.lookup, 1465 constant_op.constant(["salad", "surgery", "tarkus"])) 1466 1467 1468 class StringToIndexTest(test.TestCase): 1469 1470 def test_string_to_index(self): 1471 with self.test_session(): 1472 mapping_strings = constant_op.constant(["brain", "salad", "surgery"]) 1473 feats = constant_op.constant(["salad", "surgery", "tarkus"]) 1474 indices = lookup.string_to_index(feats, mapping=mapping_strings) 1475 1476 self.assertRaises(errors_impl.OpError, indices.eval) 1477 lookup_ops.tables_initializer().run() 1478 1479 self.assertAllEqual((1, 2, -1), indices.eval()) 1480 1481 def test_duplicate_entries(self): 1482 with self.test_session(): 1483 mapping_strings = constant_op.constant(["hello", "hello"]) 1484 feats = constant_op.constant(["hello", "hola"]) 1485 _ = lookup.string_to_index(feats, mapping=mapping_strings) 1486 1487 self.assertRaises(errors_impl.OpError, 1488 lookup_ops.tables_initializer().run) 1489 1490 def 
test_string_to_index_with_default_value(self): 1491 default_value = -42 1492 with self.test_session(): 1493 mapping_strings = constant_op.constant(["brain", "salad", "surgery"]) 1494 feats = constant_op.constant(["salad", "surgery", "tarkus"]) 1495 indices = lookup.string_to_index( 1496 feats, mapping=mapping_strings, default_value=default_value) 1497 self.assertRaises(errors_impl.OpError, indices.eval) 1498 1499 lookup_ops.tables_initializer().run() 1500 self.assertAllEqual((1, 2, default_value), indices.eval()) 1501 1502 1503 class IndexToStringTableFromFileTest(test.TestCase): 1504 1505 def _createVocabFile(self, basename): 1506 vocabulary_file = os.path.join(self.get_temp_dir(), basename) 1507 with open(vocabulary_file, "w") as f: 1508 f.write("\n".join(["brain", "salad", "surgery"]) + "\n") 1509 return vocabulary_file 1510 1511 def test_index_to_string_table(self): 1512 vocabulary_file = self._createVocabFile("i2f_vocab1.txt") 1513 with self.test_session(): 1514 table = lookup.index_to_string_table_from_file( 1515 vocabulary_file=vocabulary_file) 1516 features = table.lookup(constant_op.constant([0, 1, 2, 3], dtypes.int64)) 1517 self.assertRaises(errors_impl.OpError, features.eval) 1518 lookup_ops.tables_initializer().run() 1519 self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"), 1520 features.eval()) 1521 1522 def test_index_to_string_table_with_default_value(self): 1523 default_value = b"NONE" 1524 vocabulary_file = self._createVocabFile("f2i_vocab2.txt") 1525 with self.test_session(): 1526 table = lookup.index_to_string_table_from_file( 1527 vocabulary_file=vocabulary_file, default_value=default_value) 1528 features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64)) 1529 self.assertRaises(errors_impl.OpError, features.eval) 1530 lookup_ops.tables_initializer().run() 1531 self.assertAllEqual((b"salad", b"surgery", default_value), 1532 features.eval()) 1533 1534 def test_index_to_string_table_with_vocab_size_too_small(self): 1535 
default_value = b"NONE" 1536 vocabulary_file = self._createVocabFile("f2i_vocab2.txt") 1537 with self.test_session(): 1538 table = lookup.index_to_string_table_from_file( 1539 vocabulary_file=vocabulary_file, 1540 vocab_size=2, 1541 default_value=default_value) 1542 features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64)) 1543 self.assertRaises(errors_impl.OpError, features.eval) 1544 lookup_ops.tables_initializer().run() 1545 self.assertAllEqual((b"salad", default_value, default_value), 1546 features.eval()) 1547 1548 def test_index_to_string_table_with_vocab_size_too_large(self): 1549 vocabulary_file = self._createVocabFile("f2i_vocab6.txt") 1550 with self.test_session(): 1551 table = lookup.index_to_string_table_from_file( 1552 vocabulary_file=vocabulary_file, vocab_size=4) 1553 features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64)) 1554 1555 self.assertRaises(errors_impl.OpError, features.eval) 1556 init = lookup_ops.tables_initializer() 1557 self.assertRaisesRegexp(errors_impl.InvalidArgumentError, 1558 "Invalid vocab_size", init.run) 1559 1560 def test_index_to_string_table_with_vocab_size(self): 1561 vocabulary_file = self._createVocabFile("f2i_vocab7.txt") 1562 with self.test_session(): 1563 table = lookup.index_to_string_table_from_file( 1564 vocabulary_file=vocabulary_file, vocab_size=3) 1565 features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64)) 1566 1567 self.assertRaises(errors_impl.OpError, features.eval) 1568 lookup_ops.tables_initializer().run() 1569 self.assertAllEqual((b"salad", b"surgery", b"UNK"), features.eval()) 1570 1571 1572 class IndexToStringTableFromTensorTest(test.TestCase): 1573 1574 def test_index_to_string_table_from_tensor(self): 1575 with self.test_session(): 1576 mapping_strings = constant_op.constant(["brain", "salad", "surgery"]) 1577 table = lookup.index_to_string_table_from_tensor( 1578 mapping=mapping_strings) 1579 1580 indices = constant_op.constant([0, 1, 2, 3], dtypes.int64) 
1581 features = table.lookup(indices) 1582 self.assertRaises(errors_impl.OpError, features.eval) 1583 lookup_ops.tables_initializer().run() 1584 1585 self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"), 1586 features.eval()) 1587 1588 def test_duplicate_entries(self): 1589 with self.test_session(): 1590 mapping_strings = constant_op.constant(["hello", "hello"]) 1591 table = lookup.index_to_string_table_from_tensor( 1592 mapping=mapping_strings) 1593 indices = constant_op.constant([0, 1, 4], dtypes.int64) 1594 features = table.lookup(indices) 1595 lookup_ops.tables_initializer().run() 1596 self.assertAllEqual((b"hello", b"hello", b"UNK"), features.eval()) 1597 1598 def test_index_to_string_with_default_value(self): 1599 default_value = b"NONE" 1600 with self.test_session(): 1601 mapping_strings = constant_op.constant(["brain", "salad", "surgery"]) 1602 table = lookup.index_to_string_table_from_tensor( 1603 mapping=mapping_strings, default_value=default_value) 1604 indices = constant_op.constant([1, 2, 4], dtypes.int64) 1605 features = table.lookup(indices) 1606 self.assertRaises(errors_impl.OpError, features.eval) 1607 1608 lookup_ops.tables_initializer().run() 1609 self.assertAllEqual((b"salad", b"surgery", default_value), 1610 features.eval()) 1611 1612 1613 class IndexToStringTest(test.TestCase): 1614 1615 def test_index_to_string(self): 1616 with self.test_session(): 1617 mapping_strings = constant_op.constant(["brain", "salad", "surgery"]) 1618 indices = constant_op.constant([0, 1, 2, 3], dtypes.int64) 1619 feats = lookup.index_to_string(indices, mapping=mapping_strings) 1620 1621 self.assertRaises(errors_impl.OpError, feats.eval) 1622 lookup_ops.tables_initializer().run() 1623 1624 self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"), 1625 feats.eval()) 1626 1627 def test_duplicate_entries(self): 1628 with self.test_session(): 1629 mapping_strings = constant_op.constant(["hello", "hello"]) 1630 indices = constant_op.constant([0, 1, 4], 
dtypes.int64) 1631 feats = lookup.index_to_string(indices, mapping=mapping_strings) 1632 lookup_ops.tables_initializer().run() 1633 self.assertAllEqual((b"hello", b"hello", b"UNK"), feats.eval()) 1634 1635 self.assertRaises(errors_impl.OpError, 1636 lookup_ops.tables_initializer().run) 1637 1638 def test_index_to_string_with_default_value(self): 1639 default_value = b"NONE" 1640 with self.test_session(): 1641 mapping_strings = constant_op.constant(["brain", "salad", "surgery"]) 1642 indices = constant_op.constant([1, 2, 4], dtypes.int64) 1643 feats = lookup.index_to_string( 1644 indices, mapping=mapping_strings, default_value=default_value) 1645 self.assertRaises(errors_impl.OpError, feats.eval) 1646 1647 lookup_ops.tables_initializer().run() 1648 self.assertAllEqual((b"salad", b"surgery", default_value), feats.eval()) 1649 1650 1651 class InitializeTableFromFileOpTest(test.TestCase): 1652 1653 def _createVocabFile(self, basename, values=("brain", "salad", "surgery")): 1654 vocabulary_file = os.path.join(self.get_temp_dir(), basename) 1655 with open(vocabulary_file, "w") as f: 1656 f.write("\n".join(values) + "\n") 1657 return vocabulary_file 1658 1659 @test_util.run_in_graph_and_eager_modes() 1660 def testInitializeStringTable(self): 1661 vocabulary_file = self._createVocabFile("one_column_1.txt") 1662 default_value = -1 1663 table = lookup.HashTable( 1664 lookup.TextFileInitializer(vocabulary_file, dtypes.string, 1665 lookup.TextFileIndex.WHOLE_LINE, 1666 dtypes.int64, 1667 lookup.TextFileIndex.LINE_NUMBER), 1668 default_value) 1669 self.evaluate(table.init) 1670 1671 output = table.lookup(constant_op.constant(["brain", "salad", "tank"])) 1672 1673 result = self.evaluate(output) 1674 self.assertAllEqual([0, 1, -1], result) 1675 1676 def testInitializeInt64Table(self): 1677 vocabulary_file = self._createVocabFile( 1678 "one_column_int64.txt", values=("42", "1", "-1000")) 1679 1680 with self.test_session(): 1681 default_value = -1 1682 table = lookup.HashTable( 
1683 lookup.TextFileInitializer(vocabulary_file, dtypes.int64, 1684 lookup.TextFileIndex.WHOLE_LINE, 1685 dtypes.int64, 1686 lookup.TextFileIndex.LINE_NUMBER), 1687 default_value) 1688 table.init.run() 1689 1690 output = table.lookup( 1691 constant_op.constant((42, 1, 11), dtype=dtypes.int64)) 1692 1693 result = output.eval() 1694 self.assertAllEqual([0, 1, -1], result) 1695 1696 def testInitializeIndexTable(self): 1697 vocabulary_file = self._createVocabFile("one_column_2.txt") 1698 1699 with self.test_session(): 1700 default_value = "UNK" 1701 key_index = lookup.TextFileIndex.LINE_NUMBER 1702 value_index = lookup.TextFileIndex.WHOLE_LINE 1703 table = lookup.HashTable( 1704 lookup.TextFileInitializer(vocabulary_file, dtypes.int64, 1705 key_index, dtypes.string, value_index), 1706 default_value) 1707 table.init.run() 1708 1709 input_values = constant_op.constant([0, 1, 2, 3], dtypes.int64) 1710 output = table.lookup(input_values) 1711 1712 result = output.eval() 1713 self.assertAllEqual([b"brain", b"salad", b"surgery", b"UNK"], result) 1714 1715 def testMultiColumn(self): 1716 vocabulary_file = os.path.join(self.get_temp_dir(), "three_columns.txt") 1717 with open(vocabulary_file, "w") as f: 1718 f.write("\n".join(["0\tbrain\t1", "1\tsalad\t5", "2\tsurgery\t6"]) + "\n") 1719 1720 with self.test_session(): 1721 default_value = -1 1722 key_index = 1 1723 value_index = 2 1724 1725 table = lookup.HashTable( 1726 lookup.TextFileInitializer(vocabulary_file, dtypes.string, 1727 key_index, dtypes.int64, value_index), 1728 default_value) 1729 table.init.run() 1730 1731 input_string = constant_op.constant(["brain", "salad", "surgery"]) 1732 output = table.lookup(input_string) 1733 1734 result = output.eval() 1735 self.assertAllEqual([1, 5, 6], result) 1736 1737 def testInvalidDataTypeInMultiColumn(self): 1738 vocabulary_file = os.path.join(self.get_temp_dir(), "three_columns.txt") 1739 with open(vocabulary_file, "w") as f: 1740 f.write("\n".join(["0\tbrain\t1", "1\tsalad\t5", 
"2\tsurgery\t6"]) + "\n") 1741 1742 with self.test_session(): 1743 default_value = -1 1744 key_index = 2 1745 value_index = 1 1746 table = lookup.HashTable( 1747 lookup.TextFileInitializer(vocabulary_file, dtypes.string, 1748 key_index, dtypes.int64, value_index), 1749 default_value) 1750 with self.assertRaisesOpError("is not a valid"): 1751 table.init.run() 1752 1753 def testInvalidDataType(self): 1754 vocabulary_file = self._createVocabFile("one_column_3.txt") 1755 1756 with self.test_session(): 1757 default_value = "UNK" 1758 key_index = lookup.TextFileIndex.WHOLE_LINE 1759 value_index = lookup.TextFileIndex.LINE_NUMBER 1760 1761 with self.assertRaises(ValueError): 1762 lookup.HashTable( 1763 lookup.TextFileInitializer(vocabulary_file, dtypes.int64, 1764 key_index, dtypes.string, 1765 value_index), default_value) 1766 1767 def testInvalidIndex(self): 1768 vocabulary_file = self._createVocabFile("one_column_4.txt") 1769 with self.test_session(): 1770 default_value = -1 1771 key_index = 1 # second column of the line 1772 value_index = lookup.TextFileIndex.LINE_NUMBER 1773 table = lookup.HashTable( 1774 lookup.TextFileInitializer(vocabulary_file, dtypes.string, 1775 key_index, dtypes.int64, value_index), 1776 default_value) 1777 1778 with self.assertRaisesOpError("Invalid number of columns"): 1779 table.init.run() 1780 1781 def testInitializeSameTableWithMultipleNodes(self): 1782 vocabulary_file = self._createVocabFile("one_column_5.txt") 1783 1784 with self.test_session() as sess: 1785 shared_name = "shared-one-columm" 1786 default_value = -1 1787 table1 = lookup.HashTable( 1788 lookup.TextFileInitializer(vocabulary_file, dtypes.string, 1789 lookup.TextFileIndex.WHOLE_LINE, 1790 dtypes.int64, 1791 lookup.TextFileIndex.LINE_NUMBER), 1792 default_value, 1793 shared_name=shared_name) 1794 table2 = lookup.HashTable( 1795 lookup.TextFileInitializer(vocabulary_file, dtypes.string, 1796 lookup.TextFileIndex.WHOLE_LINE, 1797 dtypes.int64, 1798 
lookup.TextFileIndex.LINE_NUMBER), 1799 default_value, 1800 shared_name=shared_name) 1801 table3 = lookup.HashTable( 1802 lookup.TextFileInitializer(vocabulary_file, dtypes.string, 1803 lookup.TextFileIndex.WHOLE_LINE, 1804 dtypes.int64, 1805 lookup.TextFileIndex.LINE_NUMBER), 1806 default_value, 1807 shared_name=shared_name) 1808 1809 lookup_ops.tables_initializer().run() 1810 1811 input_string = constant_op.constant(["brain", "salad", "tank"]) 1812 1813 output1 = table1.lookup(input_string) 1814 output2 = table2.lookup(input_string) 1815 output3 = table3.lookup(input_string) 1816 1817 out1, out2, out3 = sess.run([output1, output2, output3]) 1818 self.assertAllEqual([0, 1, -1], out1) 1819 self.assertAllEqual([0, 1, -1], out2) 1820 self.assertAllEqual([0, 1, -1], out3) 1821 1822 def testInitializeTableWithNoFilename(self): 1823 with self.test_session(): 1824 default_value = -1 1825 with self.assertRaises(ValueError): 1826 lookup.HashTable( 1827 lookup.TextFileInitializer( 1828 "", dtypes.string, lookup.TextFileIndex.WHOLE_LINE, 1829 dtypes.int64, lookup.TextFileIndex.LINE_NUMBER), 1830 default_value) 1831 1832 def testInitializeWithVocabSize(self): 1833 with self.test_session(): 1834 default_value = -1 1835 vocab_size = 3 1836 vocabulary_file1 = self._createVocabFile("one_column6.txt") 1837 table1 = lookup.HashTable( 1838 lookup.TextFileInitializer( 1839 vocabulary_file1, 1840 dtypes.string, 1841 lookup.TextFileIndex.WHOLE_LINE, 1842 dtypes.int64, 1843 lookup.TextFileIndex.LINE_NUMBER, 1844 vocab_size=vocab_size), 1845 default_value) 1846 1847 # Initialize from file. 
1848 table1.init.run() 1849 self.assertEquals(vocab_size, table1.size().eval()) 1850 1851 vocabulary_file2 = self._createVocabFile("one_column7.txt") 1852 vocab_size = 5 1853 table2 = lookup.HashTable( 1854 lookup.TextFileInitializer( 1855 vocabulary_file2, 1856 dtypes.string, 1857 lookup.TextFileIndex.WHOLE_LINE, 1858 dtypes.int64, 1859 lookup.TextFileIndex.LINE_NUMBER, 1860 vocab_size=vocab_size), 1861 default_value) 1862 with self.assertRaisesOpError("Invalid vocab_size"): 1863 table2.init.run() 1864 1865 vocab_size = 1 1866 vocabulary_file3 = self._createVocabFile("one_column3.txt") 1867 table3 = lookup.HashTable( 1868 lookup.TextFileInitializer( 1869 vocabulary_file3, 1870 dtypes.string, 1871 lookup.TextFileIndex.WHOLE_LINE, 1872 dtypes.int64, 1873 lookup.TextFileIndex.LINE_NUMBER, 1874 vocab_size=vocab_size), 1875 default_value) 1876 1877 # Smaller vocab size reads only vocab_size records. 1878 table3.init.run() 1879 self.assertEquals(vocab_size, table3.size().eval()) 1880 1881 def testFeedVocabularyName(self): 1882 vocabulary_file = self._createVocabFile("feed_vocabulary.txt") 1883 1884 with self.test_session(): 1885 default_value = -1 1886 table = lookup.HashTable( 1887 lookup.TextFileInitializer("old_file.txt", dtypes.string, 1888 lookup.TextFileIndex.WHOLE_LINE, 1889 dtypes.int64, 1890 lookup.TextFileIndex.LINE_NUMBER), 1891 default_value) 1892 1893 # Initialize with non existing file (old_file.txt) should fail. 1894 # TODO(yleon): Update message, which might change per FileSystem. 1895 with self.assertRaisesOpError("old_file.txt"): 1896 table.init.run() 1897 1898 # Initialize the model feeding the vocabulary file. 
1899 filenames = ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS) 1900 table.init.run(feed_dict={filenames[0]: vocabulary_file}) 1901 1902 input_string = constant_op.constant(["brain", "salad", "tank"]) 1903 output = table.lookup(input_string) 1904 1905 result = output.eval() 1906 self.assertAllEqual([0, 1, -1], result) 1907 1908 def testInvalidFilenames(self): 1909 vocabulary_file = self._createVocabFile("filename_shape.txt") 1910 1911 with self.test_session(): 1912 default_value = -1 1913 1914 # Invalid data type 1915 other_type = constant_op.constant(1) 1916 with self.assertRaises(ValueError): 1917 lookup.HashTable( 1918 lookup.TextFileInitializer( 1919 other_type, dtypes.string, lookup.TextFileIndex.WHOLE_LINE, 1920 dtypes.int64, lookup.TextFileIndex.LINE_NUMBER), 1921 default_value) 1922 1923 # Non-scalar filename 1924 filenames = constant_op.constant([vocabulary_file, vocabulary_file]) 1925 with self.assertRaises(ValueError): 1926 lookup.HashTable( 1927 lookup.TextFileInitializer( 1928 filenames, dtypes.string, lookup.TextFileIndex.WHOLE_LINE, 1929 dtypes.int64, lookup.TextFileIndex.LINE_NUMBER), 1930 default_value) 1931 1932 def testIdToStringTable(self): 1933 vocab_file = self._createVocabFile("feat_to_id_1.txt") 1934 with self.test_session(): 1935 default_value = "UNK" 1936 vocab_size = 3 1937 table = lookup.HashTable( 1938 lookup.TextFileStringTableInitializer( 1939 vocab_file, vocab_size=vocab_size), 1940 default_value) 1941 1942 table.init.run() 1943 1944 input_values = constant_op.constant([0, 1, 2, 3], dtypes.int64) 1945 1946 out = table.lookup(input_values) 1947 self.assertAllEqual([b"brain", b"salad", b"surgery", b"UNK"], out.eval()) 1948 self.assertEquals(vocab_size, table.size().eval()) 1949 1950 def testStringToIdTable(self): 1951 vocab_file = self._createVocabFile("feat_to_id_2.txt") 1952 with self.test_session(): 1953 default_value = -1 1954 vocab_size = 3 1955 table = lookup.HashTable( 1956 lookup.TextFileIdTableInitializer( 1957 vocab_file, 
vocab_size=vocab_size), 1958 default_value) 1959 table.init.run() 1960 1961 input_string = constant_op.constant(["brain", "salad", "surgery", "UNK"]) 1962 1963 out = table.lookup(input_string) 1964 self.assertAllEqual([0, 1, 2, -1], out.eval()) 1965 self.assertEquals(vocab_size, table.size().eval()) 1966 1967 def testInt64ToIdTable(self): 1968 vocab_file = self._createVocabFile( 1969 "feat_to_id_3.txt", values=("42", "1", "-1000")) 1970 with self.test_session(): 1971 default_value = -1 1972 vocab_size = 3 1973 table = lookup.HashTable( 1974 lookup.TextFileIdTableInitializer( 1975 vocab_file, vocab_size=vocab_size, key_dtype=dtypes.int64), 1976 default_value) 1977 table.init.run() 1978 1979 out = table.lookup( 1980 constant_op.constant((42, 1, -1000, 11), dtype=dtypes.int64)) 1981 self.assertAllEqual((0, 1, 2, -1), out.eval()) 1982 self.assertEquals(vocab_size, table.size().eval()) 1983 1984 1985 class IdTableWithHashBucketsTest(test.TestCase): 1986 1987 def _createVocabFile(self, basename, values=("brain", "salad", "surgery")): 1988 vocabulary_file = os.path.join(self.get_temp_dir(), basename) 1989 with open(vocabulary_file, "w") as f: 1990 f.write("\n".join(values) + "\n") 1991 return vocabulary_file 1992 1993 def testStringIdTableWithHashBuckets(self): 1994 vocab_file = self._createVocabFile("feat_to_id_1.txt") 1995 with self.test_session(): 1996 default_value = -1 1997 vocab_size = 3 1998 oov_buckets = 1 1999 table = lookup.IdTableWithHashBuckets( 2000 lookup.HashTable( 2001 lookup.TextFileIdTableInitializer( 2002 vocab_file, vocab_size=vocab_size), 2003 default_value), 2004 oov_buckets) 2005 2006 table.init.run() 2007 2008 input_string = constant_op.constant(["brain", "salad", "surgery", "UNK"]) 2009 2010 out = table.lookup(input_string) 2011 self.assertAllEqual([0, 1, 2, 3], out.eval()) 2012 self.assertEquals(vocab_size + oov_buckets, table.size().eval()) 2013 2014 def testInt32IdTableWithHashBuckets(self): 2015 vocab_file = 
self._createVocabFile("feat_to_id_2.txt", ("42", "1", "-1000")) 2016 with self.test_session(): 2017 default_value = -1 2018 vocab_size = 3 2019 oov_buckets = 1 2020 table = lookup.IdTableWithHashBuckets( 2021 lookup.HashTable( 2022 lookup.TextFileIdTableInitializer( 2023 vocab_file, vocab_size=vocab_size, key_dtype=dtypes.int64), 2024 default_value), 2025 oov_buckets, 2026 key_dtype=dtypes.int32) 2027 2028 table.init.run() 2029 2030 values = constant_op.constant((42, 1, -1000, 11), dtype=dtypes.int32) 2031 2032 out = table.lookup(values) 2033 self.assertAllEqual([0, 1, 2, 3], out.eval()) 2034 self.assertEquals(vocab_size + oov_buckets, table.size().eval()) 2035 2036 def testInt64IdTableWithHashBuckets(self): 2037 vocab_file = self._createVocabFile("feat_to_id_3.txt", ("42", "1", "-1000")) 2038 with self.test_session(): 2039 default_value = -1 2040 vocab_size = 3 2041 oov_buckets = 1 2042 table = lookup.IdTableWithHashBuckets( 2043 lookup.HashTable( 2044 lookup.TextFileIdTableInitializer( 2045 vocab_file, vocab_size=vocab_size, key_dtype=dtypes.int64), 2046 default_value), 2047 oov_buckets) 2048 2049 table.init.run() 2050 2051 values = constant_op.constant((42, 1, -1000, 11), dtype=dtypes.int64) 2052 2053 out = table.lookup(values) 2054 self.assertAllEqual([0, 1, 2, 3], out.eval()) 2055 self.assertEquals(vocab_size + oov_buckets, table.size().eval()) 2056 2057 def testStringIdTableWithOnlyHashBucket(self): 2058 with self.test_session(): 2059 oov_buckets = 5 2060 2061 # Set a table that only uses hash buckets, for each input value returns 2062 # an id calculated by fingerprint("input") mod oov_buckets. 2063 table = lookup.IdTableWithHashBuckets(None, oov_buckets) 2064 table.init.run() 2065 2066 values = constant_op.constant(("brain", "salad", "surgery")) 2067 2068 out = table.lookup(values) 2069 self.assertAllEqual( 2070 [ 2071 3, # fingerprint("brain") mod 5. 2072 1, # fingerprint("salad") mod 5. 
2073 4 # fingerprint("surgery") mod 5 2074 ], 2075 out.eval()) 2076 self.assertEquals(oov_buckets, table.size().eval()) 2077 2078 def testInt32IdTableWithOnlyHashBucket(self): 2079 with self.test_session(): 2080 oov_buckets = 5 2081 2082 # Set a table that only uses hash buckets, for each input value returns 2083 # an id calculated by fingerprint("input") mod oov_buckets. 2084 table = lookup.IdTableWithHashBuckets( 2085 None, oov_buckets, key_dtype=dtypes.int32) 2086 table.init.run() 2087 2088 input_string = constant_op.constant([42, 1, -1000], dtype=dtypes.int32) 2089 2090 out = table.lookup(input_string) 2091 self.assertAllEqual( 2092 [ 2093 1, # fingerprint("42") mod 5. 2094 4, # fingerprint("1") mod 5. 2095 2 # fingerprint("-1000") mod 5 2096 ], 2097 out.eval()) 2098 self.assertEquals(oov_buckets, table.size().eval()) 2099 2100 def testFloat64IdTableWithOnlyHashBucket(self): 2101 with self.test_session(): 2102 with self.assertRaisesRegexp(TypeError, "Invalid key_dtype"): 2103 lookup.IdTableWithHashBuckets( 2104 None, num_oov_buckets=5, key_dtype=dtypes.float64) 2105 2106 def testBoolIdTableWithOnlyHashBucket(self): 2107 with self.test_session(): 2108 with self.assertRaisesRegexp(TypeError, "Invalid key_dtype"): 2109 lookup.IdTableWithHashBuckets( 2110 None, num_oov_buckets=5, key_dtype=dtypes.bool) 2111 2112 def testIdTableWithHashBucketsWithMultipleInitializers(self): 2113 vocab_file = self._createVocabFile("feat_to_id_4.txt") 2114 with self.test_session() as sess: 2115 default_value = -1 2116 vocab_size = 3 2117 oov_buckets = 3 2118 2119 vocab_table = lookup.HashTable( 2120 lookup.TextFileIdTableInitializer( 2121 vocab_file, vocab_size=vocab_size), 2122 default_value) 2123 table1 = lookup.IdTableWithHashBuckets( 2124 vocab_table, 2125 oov_buckets, 2126 hasher_spec=lookup.FastHashSpec, 2127 name="table1") 2128 2129 table2 = lookup.IdTableWithHashBuckets( 2130 vocab_table, 2131 oov_buckets, 2132 hasher_spec=lookup.StrongHashSpec((1, 2)), 2133 name="table2") 
2134 2135 lookup_ops.tables_initializer().run() 2136 2137 input_string = constant_op.constant( 2138 ["fruit", "brain", "salad", "surgery", "UNK"]) 2139 2140 out1 = table1.lookup(input_string) 2141 out2 = table2.lookup(input_string) 2142 2143 out1, out2 = sess.run([out1, out2]) 2144 self.assertAllEqual([5, 0, 1, 2, 5], out1) 2145 self.assertAllEqual([5, 0, 1, 2, 3], out2) 2146 self.assertEquals(vocab_size + oov_buckets, table1.size().eval()) 2147 self.assertEquals(vocab_size + oov_buckets, table2.size().eval()) 2148 test_util.assert_ops_in_graph({ 2149 "table1_Lookup/hash_bucket": "StringToHashBucketFast", 2150 "table2_Lookup/hash_bucket": "StringToHashBucketStrong", 2151 }, sess.graph) 2152 2153 def testIdTableWithHashBucketsInitializationAcrossSessions(self): 2154 vocab_file = self._createVocabFile("feat_to_id_5.txt") 2155 shared_name = "across-sessions" 2156 with self.test_session(): 2157 default_value = -1 2158 vocab_size = 3 2159 oov_buckets = 1 2160 table1 = lookup.IdTableWithHashBuckets( 2161 lookup.HashTable( 2162 lookup.TextFileIdTableInitializer( 2163 vocab_file, vocab_size=vocab_size), 2164 default_value, 2165 shared_name=shared_name), 2166 oov_buckets) 2167 2168 table1.init.run() 2169 2170 input_string_1 = constant_op.constant( 2171 ["brain", "salad", "surgery", "UNK"]) 2172 2173 out1 = table1.lookup(input_string_1) 2174 2175 self.assertAllEqual([0, 1, 2, 3], out1.eval()) 2176 self.assertEquals(vocab_size + oov_buckets, table1.size().eval()) 2177 2178 with self.test_session(): 2179 default_value = -1 2180 vocab_size = 3 2181 oov_buckets = 1 2182 2183 # Underlying lookup table already initialized in previous session. 
2184 # No need to call table2.init.run() 2185 table2 = lookup.IdTableWithHashBuckets( 2186 lookup.HashTable( 2187 lookup.TextFileIdTableInitializer( 2188 vocab_file, vocab_size=vocab_size), 2189 default_value, 2190 shared_name=shared_name), 2191 oov_buckets) 2192 2193 input_string_2 = constant_op.constant(["fruit", "salad", "UNK"]) 2194 2195 out2 = table2.lookup(input_string_2) 2196 2197 self.assertAllEqual([3, 1, 3], out2.eval()) 2198 self.assertEquals(vocab_size + oov_buckets, table2.size().eval()) 2199 2200 def testIdTableWithHashBucketsWithMultipleInitializersDifferentDefault(self): 2201 vocab_file = self._createVocabFile("feat_to_id_6.txt") 2202 with self.test_session() as sess: 2203 default_value1 = -1 2204 vocab_size = 3 2205 oov_buckets = 0 2206 table1 = lookup.IdTableWithHashBuckets( 2207 lookup.HashTable( 2208 lookup.TextFileIdTableInitializer( 2209 vocab_file, vocab_size=vocab_size), 2210 default_value1), 2211 oov_buckets) 2212 2213 default_value2 = -2 2214 table2 = lookup.IdTableWithHashBuckets( 2215 lookup.HashTable( 2216 lookup.TextFileIdTableInitializer( 2217 vocab_file, vocab_size=vocab_size), 2218 default_value2), 2219 oov_buckets) 2220 2221 lookup_ops.tables_initializer().run() 2222 2223 input_string_1 = constant_op.constant( 2224 ["brain", "salad", "surgery", "UNK"]) 2225 input_string_2 = constant_op.constant(["fruit", "salad", "UNK"]) 2226 2227 out1 = table1.lookup(input_string_1) 2228 out2 = table2.lookup(input_string_2) 2229 2230 out1, out2 = sess.run([out1, out2]) 2231 self.assertAllEqual([0, 1, 2, -1], out1) 2232 self.assertAllEqual([-2, 1, -2], out2) 2233 self.assertEquals(vocab_size + oov_buckets, table1.size().eval()) 2234 self.assertEquals(vocab_size + oov_buckets, table2.size().eval()) 2235 2236 def testSparseTensor(self): 2237 vocab_file = self._createVocabFile("feat_to_id_7.txt") 2238 input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]] 2239 input_shape = [4, 4] 2240 with self.test_session() as sess: 2241 sp_features = 
sparse_tensor.SparseTensor( 2242 constant_op.constant(input_indices, dtypes.int64), 2243 constant_op.constant(["brain", "salad", "brain", "surgery", "tarkus"], 2244 dtypes.string), 2245 constant_op.constant(input_shape, dtypes.int64)) 2246 2247 table = lookup.IdTableWithHashBuckets( 2248 lookup.HashTable( 2249 lookup.TextFileIdTableInitializer( 2250 vocab_file, vocab_size=3), 2251 -1), 2252 1) 2253 table.init.run() 2254 2255 sp_ids = table.lookup(sp_features) 2256 2257 self.assertAllEqual([5], sp_ids.values._shape_as_list()) 2258 2259 sp_ids_ind, sp_ids_val, sp_ids_shape = sess.run( 2260 [sp_ids.indices, sp_ids.values, sp_ids.dense_shape]) 2261 2262 self.assertAllEqual(input_indices, sp_ids_ind) 2263 self.assertAllEqual([0, 1, 0, 2, 3], sp_ids_val) 2264 self.assertAllEqual(input_shape, sp_ids_shape) 2265 2266 def testInt32SparseTensor(self): 2267 input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]] 2268 input_shape = [4, 4] 2269 with self.test_session() as sess: 2270 sp_features = sparse_tensor.SparseTensor( 2271 constant_op.constant(input_indices, dtypes.int64), 2272 constant_op.constant([42, 1, 42, -1000, 11], dtypes.int32), 2273 constant_op.constant(input_shape, dtypes.int64)) 2274 2275 table = lookup.IdTableWithHashBuckets( 2276 lookup.HashTable( 2277 lookup.KeyValueTensorInitializer( 2278 (42, 1, -1000), (0, 1, 2), dtypes.int64, dtypes.int64), 2279 -1), 2280 1, 2281 key_dtype=dtypes.int32) 2282 table.init.run() 2283 2284 sp_ids = table.lookup(sp_features) 2285 2286 self.assertAllEqual([5], sp_ids.values._shape_as_list()) 2287 2288 sp_ids_ind, sp_ids_val, sp_ids_shape = sess.run( 2289 [sp_ids.indices, sp_ids.values, sp_ids.dense_shape]) 2290 2291 self.assertAllEqual(input_indices, sp_ids_ind) 2292 self.assertAllEqual([0, 1, 0, 2, 3], sp_ids_val) 2293 self.assertAllEqual(input_shape, sp_ids_shape) 2294 2295 def testInt64SparseTensor(self): 2296 input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]] 2297 input_shape = [4, 4] 2298 with self.test_session() as 
sess: 2299 sp_features = sparse_tensor.SparseTensor( 2300 constant_op.constant(input_indices, dtypes.int64), 2301 constant_op.constant([42, 1, 42, -1000, 11], dtypes.int64), 2302 constant_op.constant(input_shape, dtypes.int64)) 2303 2304 table = lookup.IdTableWithHashBuckets( 2305 lookup.HashTable( 2306 lookup.KeyValueTensorInitializer( 2307 (42, 1, -1000), (0, 1, 2), dtypes.int64, dtypes.int64), 2308 -1), 2309 1, 2310 key_dtype=dtypes.int64) 2311 table.init.run() 2312 2313 sp_ids = table.lookup(sp_features) 2314 2315 self.assertAllEqual([5], sp_ids.values._shape_as_list()) 2316 2317 sp_ids_ind, sp_ids_val, sp_ids_shape = sess.run( 2318 [sp_ids.indices, sp_ids.values, sp_ids.dense_shape]) 2319 2320 self.assertAllEqual(input_indices, sp_ids_ind) 2321 self.assertAllEqual([0, 1, 0, 2, 3], sp_ids_val) 2322 self.assertAllEqual(input_shape, sp_ids_shape) 2323 2324 def testIdTableWithHashBucketsWithInvalidHashers(self): 2325 vocab_file = self._createVocabFile("feat_to_id_4.txt") 2326 with self.test_session(): 2327 default_value = -1 2328 vocab_size = 3 2329 oov_buckets = 1 2330 lookup_table = lookup.HashTable( 2331 lookup.TextFileIdTableInitializer( 2332 vocab_file, vocab_size=vocab_size), 2333 default_value) 2334 2335 with self.assertRaises(TypeError): 2336 lookup.IdTableWithHashBuckets( 2337 lookup_table, oov_buckets, hasher_spec=1) 2338 2339 table = lookup.IdTableWithHashBuckets( 2340 lookup_table, 2341 oov_buckets, 2342 hasher_spec=lookup.HasherSpec("my-awesome-hash", None)) 2343 2344 input_string = constant_op.constant(["brain", "salad", "surgery", "UNK"]) 2345 2346 with self.assertRaises(ValueError): 2347 table.lookup(input_string) 2348 2349 with self.assertRaises(ValueError): 2350 table = lookup.IdTableWithHashBuckets( 2351 lookup_table, 2352 oov_buckets, 2353 hasher_spec=lookup.StrongHashSpec([])) 2354 2355 with self.assertRaises(ValueError): 2356 table = lookup.IdTableWithHashBuckets( 2357 lookup_table, 2358 oov_buckets, 2359 hasher_spec=lookup.StrongHashSpec([1, 
2, 3])) 2360 2361 with self.assertRaises(TypeError): 2362 table = lookup.IdTableWithHashBuckets( 2363 lookup_table, 2364 oov_buckets, 2365 hasher_spec=lookup.StrongHashSpec([None, 2])) 2366 2367 2368 if __name__ == "__main__": 2369 test.main() 2370