# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Lookup table operations."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import gen_lookup_ops
from tensorflow.python.ops import lookup_ops
# pylint: disable=unused-import
from tensorflow.python.ops.lookup_ops import FastHashSpec
from tensorflow.python.ops.lookup_ops import HasherSpec
from tensorflow.python.ops.lookup_ops import HashTable
from tensorflow.python.ops.lookup_ops import IdTableWithHashBuckets
from tensorflow.python.ops.lookup_ops import index_table_from_file
from tensorflow.python.ops.lookup_ops import index_to_string_table_from_file
from tensorflow.python.ops.lookup_ops import InitializableLookupTableBase
from tensorflow.python.ops.lookup_ops import KeyValueTensorInitializer
from tensorflow.python.ops.lookup_ops import LookupInterface
from tensorflow.python.ops.lookup_ops import StrongHashSpec
from tensorflow.python.ops.lookup_ops import TableInitializerBase
from tensorflow.python.ops.lookup_ops import TextFileIdTableInitializer
from tensorflow.python.ops.lookup_ops import TextFileIndex
from tensorflow.python.ops.lookup_ops import TextFileInitializer
from tensorflow.python.ops.lookup_ops import TextFileStringTableInitializer
# pylint: enable=unused-import
from tensorflow.python.training.saver import BaseSaverBuilder
from tensorflow.python.util.deprecation import deprecated


@deprecated("2017-04-10", "Use `index_table_from_file`.")
def string_to_index_table_from_file(vocabulary_file=None,
                                    num_oov_buckets=0,
                                    vocab_size=None,
                                    default_value=-1,
                                    hasher_spec=FastHashSpec,
                                    name=None):
  """Returns a string-keyed index table built from a vocabulary file.

  Deprecated wrapper around `index_table_from_file` that always uses
  `tf.string` as the key dtype. All other arguments are forwarded unchanged;
  see `index_table_from_file` for their full semantics.

  Args:
    vocabulary_file: The vocabulary filename.
    num_oov_buckets: The number of out-of-vocabulary buckets.
    vocab_size: Number of elements in the vocabulary, if known.
    default_value: The value to use for out-of-vocabulary feature values.
    hasher_spec: A `HasherSpec` to specify the hash function to use for
      assignment of out-of-vocabulary buckets.
    name: A name for this op (optional).

  Returns:
    The lookup table returned by `index_table_from_file`.
  """
  # Forward by keyword so this wrapper does not depend on the positional
  # parameter order of `index_table_from_file`.
  return index_table_from_file(
      vocabulary_file=vocabulary_file,
      num_oov_buckets=num_oov_buckets,
      vocab_size=vocab_size,
      default_value=default_value,
      hasher_spec=hasher_spec,
      key_dtype=dtypes.string,
      name=name)


@deprecated("2017-04-10", "Use `index_table_from_tensor`.")
def string_to_index_table_from_tensor(mapping,
                                      num_oov_buckets=0,
                                      default_value=-1,
                                      hasher_spec=FastHashSpec,
                                      name=None):
  """Returns a string-keyed index table built from a 1-D `mapping` tensor.

  Deprecated wrapper around `index_table_from_tensor` that additionally
  checks that `mapping` has a string dtype.

  Args:
    mapping: A 1-D string `Tensor` (or value convertible to one) whose
      elements are the keys; each key's position is its index.
    num_oov_buckets: The number of out-of-vocabulary buckets.
    default_value: The value to use for out-of-vocabulary feature values.
    hasher_spec: A `HasherSpec` to specify the hash function to use for
      assignment of out-of-vocabulary buckets.
    name: A name for this op (optional).

  Returns:
    The lookup table returned by `index_table_from_tensor`.

  Raises:
    ValueError: If `mapping` does not have a string dtype.
  """
  with ops.name_scope(name, "string_to_index") as scope:
    mapping = ops.convert_to_tensor(mapping)
    if dtypes.string != mapping.dtype.base_dtype:
      raise ValueError("string_to_index_table_from_tensor requires string.")
    # `scope` ends in "/", so the delegate re-enters this same name scope.
    return index_table_from_tensor(
        mapping=mapping,
        num_oov_buckets=num_oov_buckets,
        default_value=default_value,
        hasher_spec=hasher_spec,
        dtype=dtypes.string,
        name=scope)


def index_table_from_tensor(mapping,
                            num_oov_buckets=0,
                            default_value=-1,
                            hasher_spec=FastHashSpec,
                            dtype=dtypes.string,
                            name=None):
  """Returns a lookup table that maps string or integer keys to `int64` IDs.

  This operation constructs a lookup table to convert a tensor of keys into
  `int64` IDs. The table is initialized from a 1-D `mapping` tensor where each
  element is a key and the corresponding index within the tensor is the value.

  Any lookup of an out-of-vocabulary key will return a bucket ID based on its
  hash if `num_oov_buckets` is greater than zero. Otherwise it is assigned the
  `default_value`.
  The bucket ID range is `[mapping size, mapping size + num_oov_buckets - 1]`.

  The underlying table must be initialized by calling
  `tf.tables_initializer().run()` or `table.init.run()` once.

  Elements in `mapping` cannot have duplicates, otherwise when executing the
  table initializer op, it will throw a `FailedPreconditionError`.

  Sample Usages:

  ```python
  mapping_strings = tf.constant(["emerson", "lake", "palmer"])
  table = tf.contrib.lookup.index_table_from_tensor(
      mapping=mapping_strings, num_oov_buckets=1, default_value=-1)
  features = tf.constant(["emerson", "lake", "and", "palmer"])
  ids = table.lookup(features)
  ...
  tf.tables_initializer().run()

  ids.eval()  ==> [0, 1, 3, 2]
  ```

  Here "and" is out of vocabulary; with a mapping of size 3 and a single OOV
  bucket, the only possible bucket ID is 3.

  Args:
    mapping: A 1-D `Tensor` that specifies the mapping of keys to indices. The
      type of this object must be castable to `dtype`.
    num_oov_buckets: The number of out-of-vocabulary buckets.
    default_value: The value to use for out-of-vocabulary feature values.
      Defaults to -1.
    hasher_spec: A `HasherSpec` to specify the hash function to use for
      assignment of out-of-vocabulary buckets.
    dtype: The type of values passed to `lookup`. Only string and integers are
      supported.
    name: A name for this op (optional).

  Returns:
    The lookup table to map an input `Tensor` to index `int64` `Tensor`.

  Raises:
    ValueError: If `mapping` is invalid.
    ValueError: If `num_oov_buckets` is negative.
  """
  if mapping is None:
    raise ValueError("mapping must be specified.")
  return lookup_ops.index_table_from_tensor(
      vocabulary_list=mapping,
      num_oov_buckets=num_oov_buckets,
      default_value=default_value,
      hasher_spec=hasher_spec,
      dtype=dtype,
      name=name)


@deprecated(
    "2017-01-07", "This op will be removed after the deprecation date. "
    "Please switch to index_table_from_tensor and call the lookup "
    "method of the returned table.")
def string_to_index(tensor, mapping, default_value=-1, name=None):
  """Maps `tensor` of strings into `int64` indices based on `mapping`.

  This operation converts `tensor` of strings into `int64` indices.
  The mapping is initialized from a string `mapping` tensor where each element
  is a key and corresponding index within the tensor is the value.

  Any entry in the input which does not have a corresponding entry in 'mapping'
  (an out-of-vocabulary entry) is assigned the `default_value`.

  Elements in `mapping` cannot be duplicated, otherwise the initialization
  will throw a FailedPreconditionError.

  The underlying table must be initialized by calling
  `tf.tables_initializer().run()` once.

  For example:

  ```python
  mapping_strings = tf.constant(["emerson", "lake", "palmer"])
  feats = tf.constant(["emerson", "lake", "and", "palmer"])
  ids = tf.contrib.lookup.string_to_index(
      feats, mapping=mapping_strings, default_value=-1)
  ...
  tf.tables_initializer().run()

  ids.eval()  ==> [0, 1, -1, 2]
  ```

  Args:
    tensor: A 1-D input `Tensor` with the strings to map to indices.
    mapping: A 1-D string `Tensor` that specifies the mapping of strings to
      indices.
    default_value: The `int64` value to use for out-of-vocabulary strings.
      Defaults to -1.
    name: A name for this op (optional).

  Returns:
    The mapped indices. It has the same shape and tensor type (dense or sparse)
    as `tensor`.
  """
  table = index_table_from_tensor(
      mapping=mapping, default_value=default_value, name=name)
  return table.lookup(tensor)


def index_to_string_table_from_tensor(mapping, default_value="UNK", name=None):
  """Returns a lookup table that maps a `Tensor` of indices into strings.

  This operation constructs a lookup table to map int64 indices into string
  values. The mapping is initialized from a string `mapping` 1-D `Tensor` where
  each element is a value and the corresponding index within the tensor is the
  key.

  Any input which does not have a corresponding index in 'mapping'
  (an out-of-vocabulary entry) is assigned the `default_value`.

  The underlying table must be initialized by calling
  `tf.tables_initializer().run()` or `table.init.run()` once.

  Elements in `mapping` cannot have duplicates, otherwise when executing the
  table initializer op, it will throw a `FailedPreconditionError`.

  Sample Usages:

  ```python
  mapping_string = tf.constant(["emerson", "lake", "palmer"])
  indices = tf.constant([1, 5], tf.int64)
  table = tf.contrib.lookup.index_to_string_table_from_tensor(
      mapping_string, default_value="UNKNOWN")
  values = table.lookup(indices)
  ...
  tf.tables_initializer().run()

  values.eval()  ==> ["lake", "UNKNOWN"]
  ```

  Args:
    mapping: A 1-D string `Tensor` that specifies the strings to map from
      indices.
    default_value: The value to use for out-of-vocabulary indices.
    name: A name for this op (optional).

  Returns:
    The lookup table that maps `int64` index `Tensor`s to their associated
    string values.

  Raises:
    ValueError: when `mapping` is not set.
  """

  if mapping is None:
    raise ValueError("mapping must be specified.")

  return lookup_ops.index_to_string_table_from_tensor(
      vocabulary_list=mapping, default_value=default_value, name=name)


@deprecated(
    "2017-01-07", "This op will be removed after the deprecation date. "
    "Please switch to index_to_string_table_from_tensor and call the lookup "
    "method of the returned table.")
def index_to_string(tensor, mapping, default_value="UNK", name=None):
  """Maps `tensor` of indices into string values based on `mapping`.

  This operation converts `int64` indices into string values. The mapping is
  initialized from a string `mapping` tensor where each element is a value and
  the corresponding index within the tensor is the key.

  Any input which does not have a corresponding index in 'mapping'
  (an out-of-vocabulary entry) is assigned the `default_value`.

  The underlying table must be initialized by calling
  `tf.tables_initializer().run()` once.

  For example:

  ```python
  mapping_string = tf.constant(["emerson", "lake", "palmer"])
  indices = tf.constant([1, 5], tf.int64)
  values = tf.contrib.lookup.index_to_string(
      indices, mapping=mapping_string, default_value="UNKNOWN")
  ...
  tf.tables_initializer().run()

  values.eval()  ==> ["lake", "UNKNOWN"]
  ```

  Args:
    tensor: A `int64` `Tensor` with the indices to map to strings.
    mapping: A 1-D string `Tensor` that specifies the strings to map from
      indices.
    default_value: The string value to use for out-of-vocabulary indices.
    name: A name for this op (optional).

  Returns:
    The string values associated with the indices. The resultant dense
    feature value tensor has the same shape as `tensor`.
  """
  table = index_to_string_table_from_tensor(
      mapping=mapping, default_value=default_value, name=name)
  return table.lookup(tensor)


class MutableHashTable(LookupInterface):
  """A generic mutable hash table implementation.

  Data can be inserted by calling the insert method. It does not support
  initialization via the init method.

  Example usage:

  ```python
  table = tf.contrib.lookup.MutableHashTable(key_dtype=tf.string,
                                             value_dtype=tf.int64,
                                             default_value=-1)
  table.insert(keys, values)
  out = table.lookup(query_keys)
  print(out.eval())
  ```
  """

  def __init__(self,
               key_dtype,
               value_dtype,
               default_value,
               shared_name=None,
               name="MutableHashTable",
               checkpoint=True):
    """Creates an empty `MutableHashTable` object.

    Creates a table, the type of its keys and values are specified by key_dtype
    and value_dtype, respectively.

    Args:
      key_dtype: the type of the key tensors.
      value_dtype: the type of the value tensors.
      default_value: The value to use if a key is missing in the table. Its
        shape becomes the shape of every value stored in the table.
      shared_name: If non-empty, this table will be shared under
        the given name across multiple sessions.
      name: A name for the operation (optional). Required when `checkpoint`
        is True, because it is used to build the checkpoint keys.
      checkpoint: if True, the contents of the table are saved to and restored
        from checkpoints. If `shared_name` is empty for a checkpointed table, it
        is shared using the table node name.

    Returns:
      A `MutableHashTable` object.

    Raises:
      ValueError: If checkpoint is True and no name was specified.
    """
    if checkpoint and name is None:
      # `_Saveable` builds its checkpoint keys from `name`; fail fast with the
      # documented error instead of a TypeError after the op has been created.
      raise ValueError("name must be specified when checkpoint is True.")

    self._default_value = ops.convert_to_tensor(default_value,
                                                dtype=value_dtype)
    self._value_shape = self._default_value.get_shape()

    # The table must be shared if checkpointing is requested for multi-worker
    # training to work correctly. Use the node name if no shared_name has been
    # explicitly specified.
    use_node_name_sharing = checkpoint and shared_name is None
    # pylint: disable=protected-access
    if self._default_value.get_shape().ndims == 0:
      # Scalar default value: each key maps to a single scalar.
      self._table_ref = gen_lookup_ops._mutable_hash_table_v2(
          shared_name=shared_name,
          use_node_name_sharing=use_node_name_sharing,
          key_dtype=key_dtype,
          value_dtype=value_dtype,
          name=name)
    else:
      # Non-scalar default value: each key maps to a tensor of that shape.
      self._table_ref = gen_lookup_ops._mutable_hash_table_of_tensors_v2(
          shared_name=shared_name,
          use_node_name_sharing=use_node_name_sharing,
          key_dtype=key_dtype,
          value_dtype=value_dtype,
          value_shape=self._default_value.get_shape(),
          name=name)
    # pylint: enable=protected-access
    super(MutableHashTable, self).__init__(key_dtype, value_dtype,
                                           self._table_ref.op.name.split(
                                               "/")[-1])

    if checkpoint:
      saveable = MutableHashTable._Saveable(self, name)
      ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)

  def size(self, name=None):
    """Compute the number of elements in this table.

    Args:
      name: A name for the operation (optional).

    Returns:
      A scalar tensor containing the number of elements in this table.
    """
    with ops.name_scope(name, "%s_Size" % self._name,
                        [self._table_ref]) as name:
      with ops.colocate_with(self._table_ref):
        # pylint: disable=protected-access
        return gen_lookup_ops._lookup_table_size_v2(self._table_ref, name=name)

  def lookup(self, keys, name=None):
    """Looks up `keys` in a table, outputs the corresponding values.

    The `default_value` is used for keys not present in the table.

    Args:
      keys: Keys to look up. Can be a tensor of any shape. Must match the
        table's key_dtype.
      name: A name for the operation (optional).

    Returns:
      A tensor containing the values in the same shape as `keys` using the
        table's value type.

    Raises:
      TypeError: when `keys` do not match the table data types.
    """
    if keys.dtype.base_dtype != self._key_dtype:
      raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
                      (self._key_dtype, keys.dtype))

    with ops.name_scope(name, "%s_lookup_table_find" % self._name,
                        (self._table_ref, keys, self._default_value)) as name:
      with ops.colocate_with(self._table_ref):
        # pylint: disable=protected-access
        values = gen_lookup_ops._lookup_table_find_v2(
            self._table_ref, keys, self._default_value, name=name)

    # Static result shape: the shape of `keys` followed by the value shape.
    values.set_shape(keys.get_shape().concatenate(self._value_shape))
    return values

  def insert(self, keys, values, name=None):
    """Associates `keys` with `values`.

    Args:
      keys: Keys to insert. Can be a tensor of any shape. Must match the
        table's key type.
      values: Values to be associated with keys. Must be a tensor of the same
        shape as `keys` and match the table's value type.
      name: A name for the operation (optional).

    Returns:
      The created Operation.

    Raises:
      TypeError: when `keys` or `values` doesn't match the table data
        types.
    """
    # pylint: disable=protected-access
    lookup_ops._check_table_dtypes(self, keys.dtype, values.dtype)
    # pylint: enable=protected-access
    with ops.name_scope(name, "%s_lookup_table_insert" % self._name,
                        [self._table_ref, keys, values]) as name:
      with ops.colocate_with(self._table_ref):
        # pylint: disable=protected-access
        op = gen_lookup_ops._lookup_table_insert_v2(
            self._table_ref, keys, values, name=name)
    return op

  def export(self, name=None):
    """Returns tensors of all keys and values in the table.

    Args:
      name: A name for the operation (optional).

    Returns:
      A pair of tensors with the first tensor containing all keys and the
        second tensors containing all values in the table.
    """
    with ops.name_scope(name, "%s_lookup_table_export_values" % self._name,
                        [self._table_ref]) as name:
      with ops.colocate_with(self._table_ref):
        # pylint: disable=protected-access
        exported_keys, exported_values = gen_lookup_ops._lookup_table_export_v2(
            self._table_ref, self._key_dtype, self._value_dtype, name=name)

    exported_values.set_shape(exported_keys.get_shape().concatenate(
        self._value_shape))
    return exported_keys, exported_values

  class _Saveable(BaseSaverBuilder.SaveableObject):
    """SaveableObject implementation for MutableHashTable."""

    def __init__(self, table, name):
      # Keys and values are saved under "<name>-keys" / "<name>-values".
      tensors = table.export()
      specs = [
          BaseSaverBuilder.SaveSpec(tensors[0], "", name + "-keys"),
          BaseSaverBuilder.SaveSpec(tensors[1], "", name + "-values")
      ]
      # pylint: disable=protected-access
      super(MutableHashTable._Saveable, self).__init__(table, specs, name)

    def restore(self, restored_tensors, unused_restored_shapes):
      # pylint: disable=protected-access
      with ops.colocate_with(self.op._table_ref):
        return gen_lookup_ops._lookup_table_import_v2(
            self.op._table_ref, restored_tensors[0], restored_tensors[1])


class MutableDenseHashTable(LookupInterface):
  """A generic mutable hash table implementation using tensors as backing store.

  Data can be inserted by calling the insert method. It does not support
  initialization via the init method.

  It uses "open addressing" with quadratic reprobing to resolve collisions.
  Compared to `MutableHashTable` the insert and lookup operations in a
  `MutableDenseHashTable` are typically faster, but memory usage can be higher.
  However, `MutableDenseHashTable` does not require additional memory for
  temporary tensors created during checkpointing and restore operations.

  Example usage:

  ```python
  table = tf.contrib.lookup.MutableDenseHashTable(key_dtype=tf.int64,
                                                  value_dtype=tf.int64,
                                                  default_value=-1,
                                                  empty_key=0)
  table.insert(keys, values)
  out = table.lookup(query_keys)
  print(out.eval())
  ```
  """

  # TODO(andreasst): consider extracting common code with MutableHashTable into
  # a common superclass.
  def __init__(self,
               key_dtype,
               value_dtype,
               default_value,
               empty_key,
               initial_num_buckets=None,
               shared_name=None,
               name="MutableDenseHashTable",
               checkpoint=True):
    """Creates an empty `MutableDenseHashTable` object.

    Creates a table, the type of its keys and values are specified by key_dtype
    and value_dtype, respectively.

    Args:
      key_dtype: the type of the key tensors.
      value_dtype: the type of the value tensors.
      default_value: The value to use if a key is missing in the table. Its
        shape becomes the shape of every value stored in the table.
      empty_key: the key to use to represent empty buckets internally. Must not
        be used in insert or lookup operations.
      initial_num_buckets: the initial number of buckets.
      shared_name: If non-empty, this table will be shared under
        the given name across multiple sessions.
      name: A name for the operation (optional). Required when `checkpoint`
        is True, because it is used to build the checkpoint keys.
      checkpoint: if True, the contents of the table are saved to and restored
        from checkpoints. If `shared_name` is empty for a checkpointed table, it
        is shared using the table node name.

    Returns:
      A `MutableDenseHashTable` object.

    Raises:
      ValueError: If checkpoint is True and no name was specified.
    """
    if checkpoint and name is None:
      # `_Saveable` builds its checkpoint keys from `name`; fail fast with the
      # documented error instead of a TypeError after the op has been created.
      raise ValueError("name must be specified when checkpoint is True.")

    self._default_value = ops.convert_to_tensor(
        default_value, dtype=value_dtype)
    self._value_shape = self._default_value.get_shape()

    # The table must be shared if checkpointing is requested for multi-worker
    # training to work correctly. Use the node name if no shared_name has been
    # explicitly specified.
    use_node_name_sharing = checkpoint and shared_name is None
    # The key dtype of the table op is taken from `empty_key`, so convert it
    # to `key_dtype` explicitly.
    empty_key = ops.convert_to_tensor(empty_key, dtype=key_dtype)
    # pylint: disable=protected-access
    self._table_ref = gen_lookup_ops._mutable_dense_hash_table_v2(
        empty_key=empty_key,
        shared_name=shared_name,
        use_node_name_sharing=use_node_name_sharing,
        value_dtype=value_dtype,
        value_shape=self._value_shape,
        initial_num_buckets=initial_num_buckets,
        name=name)
    # pylint: enable=protected-access
    super(MutableDenseHashTable, self).__init__(
        key_dtype, value_dtype, self._table_ref.op.name.split("/")[-1])

    if checkpoint:
      saveable = MutableDenseHashTable._Saveable(self, name)
      ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)

  def size(self, name=None):
    """Compute the number of elements in this table.

    Args:
      name: A name for the operation (optional).

    Returns:
      A scalar tensor containing the number of elements in this table.
    """
    with ops.name_scope(name, "%s_Size" % self._name,
                        [self._table_ref]) as name:
      with ops.colocate_with(self._table_ref):
        # pylint: disable=protected-access
        return gen_lookup_ops._lookup_table_size_v2(self._table_ref, name=name)

  def lookup(self, keys, name=None):
    """Looks up `keys` in a table, outputs the corresponding values.

    The `default_value` is used for keys not present in the table.

    Args:
      keys: Keys to look up. Can be a tensor of any shape. Must match the
        table's key_dtype.
      name: A name for the operation (optional).

    Returns:
      A tensor containing the values in the same shape as `keys` using the
        table's value type.

    Raises:
      TypeError: when `keys` do not match the table data types.
    """
    if keys.dtype.base_dtype != self._key_dtype:
      raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
                      (self._key_dtype, keys.dtype))

    with ops.name_scope(name, "%s_lookup_table_find" % self._name,
                        [self._table_ref, keys]) as name:
      with ops.colocate_with(self._table_ref):
        # pylint: disable=protected-access
        values = gen_lookup_ops._lookup_table_find_v2(
            self._table_ref, keys, self._default_value, name=name)

    # Only the leading dimension of `keys` is propagated into the static
    # result shape, followed by the value shape.
    if keys.get_shape().ndims is not None and keys.get_shape().ndims > 0:
      values.set_shape(
          tensor_shape.TensorShape([keys.get_shape().dims[0]]).concatenate(
              self._value_shape))
    return values

  def insert(self, keys, values, name=None):
    """Associates `keys` with `values`.

    Args:
      keys: Keys to insert. Can be a tensor of any shape. Must match the
        table's key type.
      values: Values to be associated with keys. Must be a tensor of the same
        shape as `keys` and match the table's value type.
      name: A name for the operation (optional).

    Returns:
      The created Operation.

    Raises:
      TypeError: when `keys` or `values` doesn't match the table data
        types.
    """
    # pylint: disable=protected-access
    lookup_ops._check_table_dtypes(self, keys.dtype, values.dtype)
    # pylint: enable=protected-access
    with ops.name_scope(name, "%s_lookup_table_insert" % self._name,
                        [self._table_ref, keys, values]) as name:
      with ops.colocate_with(self._table_ref):
        # pylint: disable=protected-access
        op = gen_lookup_ops._lookup_table_insert_v2(
            self._table_ref, keys, values, name=name)
      return op

  def export(self, name=None):
    """Returns tensors of all keys and values in the table.

    Args:
      name: A name for the operation (optional).

    Returns:
      A pair of tensors with the first tensor containing all keys and the
        second tensors containing all values in the table.
    """
    with ops.name_scope(name, "%s_lookup_table_export_values" % self._name,
                        [self._table_ref]) as name:
      with ops.colocate_with(self._table_ref):
        # pylint: disable=protected-access
        exported_keys, exported_values = gen_lookup_ops._lookup_table_export_v2(
            self._table_ref, self._key_dtype, self._value_dtype, name=name)

    exported_values.set_shape(exported_keys.get_shape().concatenate(
        self._value_shape))
    return exported_keys, exported_values

  class _Saveable(BaseSaverBuilder.SaveableObject):
    """SaveableObject implementation for MutableDenseHashTable."""

    def __init__(self, table, name):
      # Keys and values are saved under "<name>-keys" / "<name>-values".
      tensors = table.export()
      specs = [
          BaseSaverBuilder.SaveSpec(tensors[0], "", name + "-keys"),
          BaseSaverBuilder.SaveSpec(tensors[1], "", name + "-values")
      ]
      # pylint: disable=protected-access
      super(MutableDenseHashTable._Saveable, self).__init__(table, specs, name)

    def restore(self, restored_tensors, unused_restored_shapes):
      # pylint: disable=protected-access
      with ops.colocate_with(self.op._table_ref):
        return gen_lookup_ops._lookup_table_import_v2(
            self.op._table_ref, restored_tensors[0], restored_tensors[1])