# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Embedding functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util
from tensorflow.contrib.layers.python.ops import sparse_feature_cross_op

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging

__all__ = [
    "safe_embedding_lookup_sparse", "scattered_embedding_lookup",
    "scattered_embedding_lookup_sparse", "embedding_lookup_unique",
    "embedding_lookup_sparse_with_distributed_aggregation"
]


def safe_embedding_lookup_sparse(embedding_weights,
                                 sparse_ids,
                                 sparse_weights=None,
                                 combiner=None,
                                 default_id=None,
                                 name=None,
                                 partition_strategy="div",
                                 max_norm=None):
  """Looks up embedding results, accounting for invalid IDs and empty features.

  The partitioned embedding tensors in `embedding_weights` must all be the same
  shape except for the first dimension. The first dimension is allowed to vary
  as the vocabulary size is not necessarily a multiple of `P`.
  `embedding_weights` may be a `PartitionedVariable` as returned by
  `tf.get_variable()` with a partitioner.

  Invalid IDs (< 0) are pruned from input IDs and weights, as are any IDs
  with non-positive weight. For an entry with no features, the embedding vector
  for `default_id` is returned, or the 0-vector if `default_id` is not supplied.

  The ids and weights may be multi-dimensional. Embeddings are always aggregated
  along the last dimension.

  Args:
    embedding_weights: A list of `P` float tensors or values representing
      partitioned embedding tensors. Alternatively, a `PartitionedVariable`,
      created by partitioning along dimension 0. The total unpartitioned
      shape should be `[e_0, e_1, ..., e_m]`, where `e_0` represents the
      vocab size and `e_1, ..., e_m` are the embedding dimensions.
    sparse_ids: `SparseTensor` of shape `[d_0, d_1, ..., d_n]` containing the
      ids. `d_0` is typically batch size.
    sparse_weights: `SparseTensor` of same shape as `sparse_ids`, containing
      float weights corresponding to `sparse_ids`, or `None` if all weights
      are assumed to be 1.0.
    combiner: A string specifying how to combine embedding results for each
      entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean"
      the default.
    default_id: The id to use for an entry with no features.
    name: A name for this operation (optional).
    partition_strategy: A string specifying the partitioning strategy.
      Currently `"div"` and `"mod"` are supported. Default is `"div"`.
    max_norm: If not None, all embeddings are l2-normalized to max_norm before
      combining.

  Returns:
    Dense tensor of shape `[d_0, d_1, ..., d_{n-1}, e_1, ..., e_m]`.

  Raises:
    ValueError: if `embedding_weights` is empty.
  """
  if combiner is None:
    logging.warn("The default value of combiner will change from \"mean\" "
                 "to \"sqrtn\" after 2016/11/01.")
    combiner = "mean"
  if embedding_weights is None:
    raise ValueError("Missing embedding_weights %s." % embedding_weights)
  if isinstance(embedding_weights, variables.PartitionedVariable):
    embedding_weights = list(embedding_weights)  # get underlying Variables.
  if not isinstance(embedding_weights, list):
    embedding_weights = [embedding_weights]
  if len(embedding_weights) < 1:
    raise ValueError("Missing embedding_weights %s." % embedding_weights)

  dtype = sparse_weights.dtype if sparse_weights is not None else None
  if isinstance(embedding_weights, variables.PartitionedVariable):
    embedding_weights = list(embedding_weights)
  embedding_weights = [
      ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
  ]

  contrib_tensor_util.assert_same_float_dtype(embedding_weights +
                                              [sparse_weights])

  with ops.name_scope(name, "embedding_lookup",
                      embedding_weights + [sparse_ids,
                                           sparse_weights]) as scope:
    # Reshape higher-rank sparse ids and weights to linear segment ids.
    original_shape = sparse_ids.dense_shape
    original_rank_dim = sparse_ids.dense_shape.get_shape()[0]
    original_rank = (
        array_ops.size(original_shape)
        if original_rank_dim.value is None
        else original_rank_dim.value)
    sparse_ids = sparse_ops.sparse_reshape(sparse_ids, [
        math_ops.reduce_prod(
            array_ops.slice(original_shape, [0], [original_rank - 1])),
        array_ops.gather(original_shape, original_rank - 1)])
    if sparse_weights is not None:
      sparse_weights = sparse_tensor.SparseTensor(
          sparse_ids.indices,
          sparse_weights.values, sparse_ids.dense_shape)

    # Prune invalid ids and weights.
    sparse_ids, sparse_weights = _prune_invalid_ids(sparse_ids, sparse_weights)

    # Fill in dummy values for empty features, if necessary.
    sparse_ids, is_row_empty = sparse_ops.sparse_fill_empty_rows(sparse_ids,
                                                                 default_id or
                                                                 0)
    if sparse_weights is not None:
      sparse_weights, _ = sparse_ops.sparse_fill_empty_rows(sparse_weights, 1.0)

    result = embedding_ops.embedding_lookup_sparse(
        embedding_weights,
        sparse_ids,
        sparse_weights,
        combiner=combiner,
        partition_strategy=partition_strategy,
        name=None if default_id is None else scope,
        max_norm=max_norm)

    if default_id is None:
      # Broadcast is_row_empty to the same shape as embedding_lookup_result,
      # for use in Select.
      is_row_empty = array_ops.tile(
          array_ops.reshape(is_row_empty, [-1, 1]),
          array_ops.stack([1, array_ops.shape(result)[1]]))

      result = array_ops.where(is_row_empty,
                               array_ops.zeros_like(result),
                               result,
                               name=scope)

    # Reshape back from linear ids into the higher-dimensional dense result.
    final_result = array_ops.reshape(
        result,
        array_ops.concat([
            array_ops.slice(
                math_ops.cast(original_shape, dtypes.int32), [0],
                [original_rank - 1]),
            array_ops.slice(array_ops.shape(result), [1], [-1])
        ], 0))
    final_result.set_shape(tensor_shape.unknown_shape(
        (original_rank_dim - 1).value).concatenate(result.get_shape()[1:]))
    return final_result
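

# A minimal usage sketch (illustrative only, not exercised by this module):
# look up mean-combined embeddings for a batch where one id is invalid (< 0)
# and one row has no features. Assumes a TF 1.x graph/session environment;
# the table and values below are made up for the example.
#
#   weights = variables.Variable([[0., 0.], [1., 1.], [2., 2.], [3., 3.]])
#   ids = sparse_tensor.SparseTensor(
#       indices=[[0, 0], [0, 1], [2, 0]],
#       values=constant_op.constant([1, -1, 3], dtypes.int64),
#       dense_shape=[3, 2])          # row 1 is empty -> 0-vector
#   embedded = safe_embedding_lookup_sparse([weights], ids, combiner="mean")
#   # embedded has shape [3, 2]: [[1., 1.], [0., 0.], [3., 3.]]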


def _prune_invalid_ids(sparse_ids, sparse_weights):
  """Prune invalid IDs (< 0) from the input ids and weights."""
  is_id_valid = math_ops.greater_equal(sparse_ids.values, 0)
  if sparse_weights is not None:
    is_id_valid = math_ops.logical_and(
        is_id_valid, math_ops.greater(sparse_weights.values, 0))
  sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_id_valid)
  if sparse_weights is not None:
    sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_id_valid)
  return sparse_ids, sparse_weights


def scattered_embedding_lookup(params,
                               values,
                               dimension,
                               name=None,
                               hash_key=None):
  """Looks up embeddings using parameter hashing for each value in `values`.

  The i-th embedding component of a value v in `values` is found by retrieving
  the weight whose index is a fingerprint of the pair (v, i).
  The concept is explored as "feature hashing" for model compression in this
  paper: http://arxiv.org/pdf/1504.04788.pdf

  Feature hashing has the pleasant effect of allowing us to compute an embedding
  without needing a pre-determined vocabulary, relieving some amount of process
  complexity. It also allows us to maintain embeddings for possibly trillions of
  features with a fixed amount of memory.

  Note that this is superior to out-of-vocabulary shared "hash buckets" in that
  the embedding is extremely likely to be unique for each token as opposed to
  being shared across probably-colliding tokens. The price is that we must
  compute a hash once for each scalar in the token's embedding as opposed to
  once per token.

  If `params` is a list, it represents a partition of the embedding parameters.
  Each tensor in the list should have the same length, except for the first ones
  which may have an additional element. For instance 10 parameters can be
  partitioned into 4 tensors with lengths `[3, 3, 2, 2]`.

  Args:
    params: A `Tensor`, `list` of `Tensors`, or `PartitionedVariable`.
      Each tensor must be of rank 1 with fully-defined shape.
    values: `Tensor` of values to be embedded with shape `[d0, ..., dn]`.
    dimension: Embedding dimension.
    name: An optional name for this op.
    hash_key: Specify the hash_key that will be used by the `FingerprintCat64`
      function to combine the crossed fingerprints in `SparseFeatureCrossOp`
      (optional).

  Returns:
    A `Tensor` with shape `[d0, ..., dn, dimension]`.

  Raises:
    ValueError: if dimension is not positive or the partition size is invalid.
  """
  if dimension is None:
    raise ValueError("You must specify dimension.")
  return _sampled_scattered_embedding_lookup(
      params, values, dimension=dimension, sampled_candidates=None,
      hash_key=hash_key, name=name)
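

# A minimal usage sketch (illustrative only): embed two string tokens into a
# 3-dimensional space backed by a single rank-1 vector of 16 hashed weights.
# The bucket count and tokens are made up for the example.
#
#   params = variables.Variable(array_ops.ones([16]))
#   tokens = constant_op.constant(["apple", "banana"])
#   emb = scattered_embedding_lookup(params, tokens, dimension=3)
#   # emb has shape [2, 3]; emb[i, j] is the weight whose index is a
#   # fingerprint of the pair (tokens[i], j), bucketed into the 16 slots.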


def _sampled_scattered_embedding_lookup(
    params, values, dimension=None, sampled_candidates=None, hash_key=None,
    name=None):
  """Looks up embeddings using parameter hashing for each value in `values`.

  This method looks up selected embedding dimensions if `sampled_candidates` is
  given, otherwise looks up all dimensions.

  The i-th embedding component of a value v in `values` is found by retrieving
  the weight whose index is a fingerprint of the pair (v, i).
  The concept is explored as "feature hashing" for model compression in this
  paper: http://arxiv.org/pdf/1504.04788.pdf

  Feature hashing has the pleasant effect of allowing us to compute an embedding
  without needing a pre-determined vocabulary, relieving some amount of process
  complexity. It also allows us to maintain embeddings for possibly trillions of
  features with a fixed amount of memory.

  Note that this is superior to out-of-vocabulary shared "hash buckets" in that
  the embedding is extremely likely to be unique for each token as opposed to
  being shared across probably-colliding tokens. The price is that we must
  compute a hash once for each scalar in the token's embedding as opposed to
  once per token.

  If `params` is a list, it represents a partition of the embedding parameters.
  Each tensor in the list should have the same length, except for the first ones
  which may have an additional element. For instance 10 parameters can be
  partitioned into 4 tensors with lengths `[3, 3, 2, 2]`.

  Args:
    params: A `Tensor`, `list` of `Tensors`, or `PartitionedVariable`.
      Each tensor must be of rank 1 with fully-defined shape.
    values: `Tensor` of values to be embedded with shape `[d0, ..., dn]`.
    dimension: Embedding dimension. The user must specify either `dimension` or
      `sampled_candidates`.
    sampled_candidates: An optional `Tensor` of slice indices to keep along the
      final dimension with shape `[d0, ..., dn, N]`. If given, `dimension` is
      ignored. If `None`, looks up all candidates.
    hash_key: Specify the hash_key that will be used by the `FingerprintCat64`
      function to combine the crossed fingerprints in `SparseFeatureCrossOp`
      (optional).
    name: An optional name for this op.

  Returns:
    A `Tensor` with shape `[d0, ..., dn, dimension]`.
    If `sampled_candidates` is given, the output shape is `[d0, ..., dn, N]`.

  Raises:
    ValueError: if dimension is not positive or the partition size is invalid.
  """
  if isinstance(params, variables.PartitionedVariable):
    params = list(params)
  if not isinstance(params, list):
    params = [params]

  with ops.name_scope(name, "scattered_embedding_lookup",
                      params + [dimension, values]):
    # Flatten the values.
    values_shape = array_ops.shape(values)
    values = array_ops.reshape(values, [-1, 1])

    if sampled_candidates is None:
      if dimension is None:
        raise ValueError(
            "You must specify either dimension or sampled_candidates.")
      if dimension <= 0:
        raise ValueError("Dimension must be >0. Given is %d" % dimension)
      sampled_candidates = array_ops.tile(array_ops.expand_dims(
          math_ops.range(0, dimension), 0), array_ops.shape(values))
    else:
      dimension = array_ops.shape(sampled_candidates)[
          math_ops.subtract(array_ops.rank(sampled_candidates), 1)]
      sampled_candidates_shape = array_ops.shape(sampled_candidates)
      dimension_tensor = array_ops.reshape(dimension, shape=[1,])
      expected_shape = array_ops.concat([values_shape, dimension_tensor], 0)
      with ops.control_dependencies([control_flow_ops.Assert(
          math_ops.reduce_all(math_ops.equal(sampled_candidates_shape,
                                             expected_shape)),
          ["The shape of sampled_candidates: ", sampled_candidates_shape,
           " does not match the shape of values: ", values_shape])]):
        # Flatten sampled_candidates, same way as values are flattened.
        sampled_candidates = array_ops.reshape(sampled_candidates,
                                               [-1, dimension])

    num_partitions = len(params)
    partition_sizes = []
    for p in range(num_partitions):
      shape = params[p].get_shape()
      shape.assert_has_rank(1)
      shape.assert_is_fully_defined()
      partition_sizes.append(shape[0].value)
    num_params = sum(partition_sizes)  # Total number of parameters.

    # Assert the size of each partition.
    for p in range(num_partitions):
      expected_size = (num_params - p - 1) // num_partitions + 1
      if partition_sizes[p] != expected_size:
        raise ValueError("Tensor %d in params has size %d, expected %d." %
                         (p, partition_sizes[p], expected_size))

    # With two values v1 and v2 and 3 dimensions, we will cross
    # [[0, 1, 2], [0, 1, 2]] with [[v1], [v2]].
    tensors_to_cross = [sampled_candidates, values]
    ids = sparse_feature_cross_op.sparse_feature_cross(
        tensors_to_cross, hashed_output=True, num_buckets=num_params,
        hash_key=hash_key)
    ids = sparse_ops.sparse_tensor_to_dense(ids)

    # No need to validate the indices since we have checked the params
    # dimensions and we know the largest id.
    result = embedding_ops.embedding_lookup(
        params, ids, partition_strategy="div")

    return array_ops.reshape(result,
                             array_ops.concat([values_shape, [dimension]], 0))
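

# A minimal usage sketch (illustrative only) for the helper above: when
# `sampled_candidates` is supplied, only the listed columns of each value's
# embedding are looked up. The bucket count, values and candidate columns are
# made up for the example.
#
#   params = variables.Variable(array_ops.ones([64]))
#   values = constant_op.constant(["a", "b"])
#   candidates = constant_op.constant([[0, 3], [1, 2]], dtypes.int64)
#   emb = _sampled_scattered_embedding_lookup(
#       params, values, sampled_candidates=candidates)
#   # emb has shape [2, 2]: one weight per (value, sampled column) pair.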


def scattered_embedding_lookup_sparse(params,
                                      sparse_values,
                                      dimension,
                                      combiner=None,
                                      default_value=None,
                                      name=None,
                                      hash_key=None):
  """Looks up embeddings of a sparse feature using parameter hashing.

  See `tf.contrib.layers.scattered_embedding_lookup` for embedding with hashing.

  Args:
    params: A `Tensor`, `list` of `Tensors`, or `PartitionedVariable`.
      Each tensor must be of rank 1 with fully-defined shape.
    sparse_values: A 2-D `SparseTensor` containing the values to be embedded.
      Some rows may be empty.
    dimension: Embedding dimension.
    combiner: A string specifying how to combine embedding results for each
      entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean"
      the default.
    default_value: The value to use for an entry with no features.
    name: An optional name for this op.
    hash_key: Specify the hash_key that will be used by the `FingerprintCat64`
      function to combine the crossed fingerprints in `SparseFeatureCrossOp`
      (optional).

  Returns:
    Dense tensor with shape [N, dimension] with N the number of rows in
    sparse_values.

  Raises:
    TypeError: If sparse_values is not a SparseTensor.
    ValueError: If combiner is not one of {"mean", "sqrtn", "sum"}.
  """
  if combiner is None:
    logging.warn("The default value of combiner will change from \"mean\" "
                 "to \"sqrtn\" after 2016/11/01.")
    combiner = "mean"
  if isinstance(params, variables.PartitionedVariable):
    params = list(params)
  if not isinstance(params, list):
    params = [params]
  if not isinstance(sparse_values, sparse_tensor.SparseTensor):
    raise TypeError("sparse_values must be SparseTensor")

  with ops.name_scope(name, "scattered_embedding_lookup_sparse",
                      params + [sparse_values]) as scope:
    # Fill in the empty rows.
    if default_value is None:
      # Random default values to reduce the risk of collision.
      if sparse_values.dtype == dtypes.string:
        default_value = "6ZxWzWOHxZ"
      else:
        default_value = 1288896567
    sparse_values, _ = sparse_ops.sparse_fill_empty_rows(
        sparse_values, default_value)

    segment_ids = sparse_values.indices[:, 0]
    if segment_ids.dtype != dtypes.int32:
      segment_ids = math_ops.cast(segment_ids, dtypes.int32)

    values = sparse_values.values
    values, idx = array_ops.unique(values)

    embeddings = scattered_embedding_lookup(
        params, values, dimension, hash_key=hash_key)

    if combiner == "sum":
      embeddings = math_ops.sparse_segment_sum(embeddings, idx, segment_ids,
                                               name=scope)
    elif combiner == "mean":
      embeddings = math_ops.sparse_segment_mean(embeddings, idx, segment_ids,
                                                name=scope)
    elif combiner == "sqrtn":
      embeddings = math_ops.sparse_segment_sqrt_n(embeddings, idx, segment_ids,
                                                  name=scope)
    else:
      raise ValueError("Combiner must be one of 'mean', 'sqrtn' or 'sum'.")

    return embeddings
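

# A minimal usage sketch (illustrative only): embed a sparse batch of string
# features with hashed parameters and combine the embeddings in each row with
# "mean". The bucket count, values and dimension are made up for the example.
#
#   params = variables.Variable(array_ops.ones([32]))
#   sp = sparse_tensor.SparseTensor(
#       indices=[[0, 0], [0, 1], [1, 0]],
#       values=["a", "b", "a"],
#       dense_shape=[2, 2])
#   emb = scattered_embedding_lookup_sparse(params, sp, dimension=4,
#                                           combiner="mean")
#   # emb has shape [2, 4]: one combined embedding per row of `sp`.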


def embedding_lookup_unique(params, ids, name=None):
  """Version of embedding_lookup that avoids duplicate lookups.

  This can save communication in the case of repeated ids.
  Same interface as embedding_lookup, except it supports multi-dimensional
  `ids`, which avoids having to reshape the input/output to fit gather.

  Args:
    params: A list of tensors with the same shape and type, or a
      `PartitionedVariable`. Shape `[index, d1, d2, ...]`.
    ids: A `Tensor` with type `int32` or `int64` containing the ids to be
      looked up in `params`. Shape `[ids1, ids2, ...]`.
    name: A name for this operation (optional).

  Returns:
    A `Tensor` with the same type as the tensors in `params` and dimension of
    `[ids1, ids2, d1, d2, ...]`.

  Raises:
    ValueError: If `params` is empty.
  """
  with ops.name_scope(name, "EmbeddingLookupUnique", [params, ids]):
    ids = ops.convert_to_tensor(ids)
    shape = array_ops.shape(ids)
    ids_flat = array_ops.reshape(
        ids, math_ops.reduce_prod(shape, keep_dims=True))
    unique_ids, idx = array_ops.unique(ids_flat)
    unique_embeddings = embedding_ops.embedding_lookup(params, unique_ids)
    embeds_flat = array_ops.gather(unique_embeddings, idx)
    embed_shape = array_ops.concat(
        [shape, array_ops.shape(unique_embeddings)[1:]], 0)
    embeds = array_ops.reshape(embeds_flat, embed_shape)
    embeds.set_shape(ids.get_shape().concatenate(
        unique_embeddings.get_shape()[1:]))
    return embeds
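

# A minimal usage sketch (illustrative only): repeated ids are looked up only
# once and then gathered back into the original, possibly multi-dimensional,
# shape. The table and ids are made up for the example.
#
#   params = constant_op.constant([[0., 0.], [1., 1.], [2., 2.]])
#   ids = constant_op.constant([[2, 2], [1, 2]])
#   emb = embedding_lookup_unique(params, ids)
#   # emb has shape [2, 2, 2]; only the unique ids {1, 2} hit `params`.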


def _sampled_scattered_embedding_lookup_sparse(params,
                                               sp_values,
                                               dimension=None,
                                               sampled_candidates=None,
                                               hash_key=None,
                                               with_sign_hash=False,
                                               name=None):
  """Looks up embeddings using parameter hashing for sparse values.

  This method looks up selected embedding dimensions if `sampled_candidates` is
  given, otherwise looks up all dimensions.

  The i-th embedding component of a value v in `sp_values` is found by
  retrieving the weight whose index is a fingerprint of the pair (v, i).
  The concept is explored as "feature hashing" for model compression in this
  paper: http://arxiv.org/pdf/1504.04788.pdf

  This is logically equivalent to:
  * Transforming `sp_values` (which has shape `[d0, d1]`) into a one-hot
    `Tensor` of shape `[d0, N]`.
  * Multiplying with a `Tensor` `h` of shape `[N, dimension]`, where
    `h(i, j) = params[hash(i, j)]`.

  Args:
    params: A float `Tensor` with rank 1 and fully-defined shape.
    sp_values: A 2D `SparseTensor` to be embedded with shape `[d0, d1]`.
    dimension: An int `Tensor` of the final dimension. The user needs to
      provide either `dimension` or `sampled_candidates`.
    sampled_candidates: An optional `Tensor` of column indices to keep along
      the final dimension with shape `[d0, N]`. If given, `dimension` is
      ignored. If `None`, looks up all candidates.
    hash_key: Specify the hash_key that will be used by the `FingerprintCat64`
      function to combine the crossed fingerprints in `SparseFeatureCrossOp`
      (optional).
    with_sign_hash: A `bool` indicating whether `h(i, j)` should be multiplied
      by `+1` or `-1`, where the value selected is determined by hashing
      `(i, j)`. This is often necessary to remove bias resulting from hash
      collisions.
    name: An optional name for this op.

  Returns:
    A `Tensor` of shape `[d0, dimension]`.
    If `sampled_candidates` is given, the output shape is `[d0, N]`.

  Raises:
    TypeError: If sp_values is not `SparseTensor`.
    ValueError: If both `dimension` and `sampled_candidates` are `None`.
  """
  if not isinstance(sp_values, sparse_tensor.SparseTensor):
    raise TypeError("sp_values must be SparseTensor")

  with ops.name_scope(
      name=name,
      default_name="sampled_scattered_embedding_lookup_sparse",
      values=[sp_values, params, dimension, sampled_candidates]) as name_scope:
    segment_ids = sp_values.indices[:, 0]
    if sampled_candidates is not None:
      # Tile sampled_candidates so there is one row corresponding to each
      # element in sp_values.values.
      sampled_candidates = array_ops.gather(sampled_candidates, segment_ids)

    embeddings = _sampled_scattered_embedding_lookup(
        params, sp_values.values, dimension=dimension,
        sampled_candidates=sampled_candidates,
        hash_key=hash_key, name="values_lookup")
    if with_sign_hash:
      signs = _sampled_scattered_embedding_lookup(
          array_ops.constant([-1., 1.]), sp_values.values, dimension=dimension,
          sampled_candidates=sampled_candidates, hash_key=hash_key,
          name="signs_lookup")
      embeddings = math_ops.multiply(signs, embeddings, name="signs_hash")

    if segment_ids.dtype != dtypes.int32:
      segment_ids = math_ops.cast(segment_ids, dtypes.int32)
    num_segments = array_ops.shape(sp_values)[0]

    return math_ops.unsorted_segment_sum(embeddings, segment_ids,
                                         num_segments=num_segments,
                                         name=name_scope)
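

# A minimal usage sketch (illustrative only) for the helper above: embed a
# 2-row sparse batch into 4 hashed dimensions, flipping the sign of each
# weight based on a second hash to reduce collision bias. The bucket count
# and values are made up for the example.
#
#   params = variables.Variable(array_ops.ones([64]))
#   sp = sparse_tensor.SparseTensor(
#       indices=[[0, 0], [1, 0]], values=["x", "y"], dense_shape=[2, 1])
#   emb = _sampled_scattered_embedding_lookup_sparse(
#       params, sp, dimension=4, with_sign_hash=True)
#   # emb has shape [2, 4].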


def embedding_lookup_sparse_with_distributed_aggregation(
    params,
    sp_ids,
    sp_weights,
    partition_strategy="mod",
    name=None,
    combiner=None,
    max_norm=None):
  """Computes embeddings for the given ids and weights.

  Embeddings belonging to the same param are aggregated on that device first.
  This op is intended to decrease data transmission and improve parallelism.
  See `tf.nn.embedding_lookup_sparse` for the functionality and example of
  this op.

  Args:
    params: A single tensor representing the complete embedding tensor,
      or a list of P tensors all of same shape except for the first dimension,
      representing sharded embedding tensors. Alternatively, a
      `PartitionedVariable`, created by partitioning along dimension 0. Each
      element must be appropriately sized for the given `partition_strategy`.
    sp_ids: N x M SparseTensor of int64 ids (typically from FeatureValueToId),
      where N is typically batch size and M is arbitrary.
    sp_weights: either a SparseTensor of float / double weights, or None to
      indicate all weights should be taken to be 1. If specified, sp_weights
      must have exactly the same shape and indices as sp_ids.
    partition_strategy: A string specifying the partitioning strategy, relevant
      if `len(params) > 1`. Currently `"div"` and `"mod"` are supported. Default
      is `"mod"`. See `tf.nn.embedding_lookup` for more details.
    name: Optional name for the op.
    combiner: A string specifying the reduction op. Currently "mean", "sqrtn"
      and "sum" are supported.
      "sum" computes the weighted sum of the embedding results for each row.
      "mean" is the weighted sum divided by the total weight.
      "sqrtn" is the weighted sum divided by the square root of the sum of the
      squares of the weights.
    max_norm: If not None, each embedding is normalized to have l2 norm equal
      to max_norm before combining.

  Returns:
    A dense tensor representing the combined embeddings for the
    sparse ids. For each row in the dense tensor represented by sp_ids, the op
    looks up the embeddings for all ids in that row, multiplies them by the
    corresponding weight, and combines these embeddings as specified.

  Raises:
    TypeError: If sp_ids is not a SparseTensor, or if sp_weights is neither
      None nor SparseTensor.
    ValueError: If combiner is not one of {"mean", "sqrtn", "sum"}.
  """
  if combiner is None:
    logging.warn("The default value of combiner will change from \"mean\" "
                 "to \"sqrtn\" after 2016/11/01.")
    combiner = "mean"
  if combiner not in ("mean", "sqrtn", "sum"):
    raise ValueError("combiner must be one of 'mean', 'sqrtn' or 'sum'")
  if isinstance(params, variables.PartitionedVariable):
    params = list(params)  # Iterate to get the underlying Variables.
  if not isinstance(params, list):
    params = [params]
  if not isinstance(sp_ids, sparse_tensor.SparseTensor):
    raise TypeError("sp_ids must be SparseTensor")
  ignore_weights = sp_weights is None
  if not ignore_weights:
    if not isinstance(sp_weights, sparse_tensor.SparseTensor):
      raise TypeError("sp_weights must be either None or SparseTensor")
    sp_ids.values.get_shape().assert_is_compatible_with(
        sp_weights.values.get_shape())
    sp_ids.indices.get_shape().assert_is_compatible_with(
        sp_weights.indices.get_shape())
    sp_ids.dense_shape.get_shape().assert_is_compatible_with(
        sp_weights.dense_shape.get_shape())
    # TODO(yleon): Add enhanced node assertions to verify that sp_ids and
    # sp_weights have equal indices and shapes.

  with ops.name_scope(name, "embedding_lookup_sparse",
                      params + [sp_ids]) as name:
    segment_ids = sp_ids.indices[:, 0]
    if segment_ids.dtype != dtypes.int32:
      segment_ids = math_ops.cast(segment_ids, dtypes.int32)

    ids = sp_ids.values
    if ignore_weights:
      ids, idx = array_ops.unique(ids)
    else:
      idx = None

    weights = None if ignore_weights else sp_weights.values
    embeddings = _embedding_lookup_with_distributed_aggregation(
        params,
        ids,
        partition_strategy=partition_strategy,
        max_norm=max_norm,
        weights=weights,
        idx=idx,
        segment_ids=segment_ids)
    # Set weights to all ones if weights are ignored.
    if ignore_weights:
      weights = array_ops.fill([array_ops.shape(segment_ids)[0]], 1)
    if weights.dtype != embeddings.dtype:
      weights = math_ops.cast(weights, embeddings.dtype)
    # Reshape weights.
    ones = array_ops.fill(
        array_ops.expand_dims(array_ops.rank(embeddings) - 1, 0), 1)
    bcast_weights_shape = array_ops.concat([array_ops.shape(weights), ones], 0)
    orig_weights_shape = weights.get_shape()
    weights = array_ops.reshape(weights, bcast_weights_shape)
    if embeddings.get_shape().ndims is not None:
      weights.set_shape(
          orig_weights_shape.concatenate(
              [1 for _ in range(embeddings.get_shape().ndims - 1)]))

    if combiner == "mean":
      weight_sum = math_ops.segment_sum(weights, segment_ids)
      embeddings = math_ops.div(embeddings, weight_sum)
    elif combiner == "sqrtn":
      weights_squared = math_ops.pow(weights, 2)
      weight_sum = math_ops.segment_sum(weights_squared, segment_ids)
      weight_sum_sqrt = math_ops.sqrt(weight_sum)
      embeddings = math_ops.div(embeddings, weight_sum_sqrt)
    elif combiner != "sum":
      assert False, "Unrecognized combiner"
    return embeddings
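

# A minimal usage sketch (illustrative only): the interface mirrors
# tf.nn.embedding_lookup_sparse, but partial sums are computed next to each
# parameter shard before the final combination. The shard sizes and ids below
# are made up for the example.
#
#   shards = [variables.Variable(array_ops.ones([2, 4])),
#             variables.Variable(array_ops.ones([2, 4]))]   # 4 ids in total
#   sp_ids = sparse_tensor.SparseTensor(
#       indices=[[0, 0], [0, 1], [1, 0]],
#       values=constant_op.constant([0, 3, 2], dtypes.int64),
#       dense_shape=[2, 2])
#   emb = embedding_lookup_sparse_with_distributed_aggregation(
#       shards, sp_ids, None, combiner="mean")
#   # emb has shape [2, 4]: one mean-combined embedding per row of `sp_ids`.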


def _do_gather(params, ids, name=None):
  """Deals with doing gather differently for resource variables."""
  if isinstance(params, resource_variable_ops.ResourceVariable):
    return params.sparse_read(ids, name=name)
  return array_ops.gather(params, ids, name=name)


def _embedding_lookup_with_distributed_aggregation(params,
                                                   ids,
                                                   partition_strategy="mod",
                                                   name=None,
                                                   max_norm=None,
                                                   weights=None,
                                                   idx=None,
                                                   segment_ids=None):
  """Lookup helper for embedding_lookup_sparse_with_distributed_aggregation."""
  if params is None or params == []:  # pylint: disable=g-explicit-bool-comparison
    raise ValueError("Need at least one param")
  if isinstance(params, variables.PartitionedVariable):
    params = list(params)  # Iterate to get the underlying Variables.
  if not isinstance(params, list):
    params = [params]

  def maybe_normalize(x):
    if max_norm is not None:
      if x.get_shape().ndims is not None:
        ndims = x.get_shape().ndims
      else:
        ndims = array_ops.size(array_ops.shape(x))
      return clip_ops.clip_by_norm(x, max_norm, axes=list(range(1, ndims)))
    return x

  with ops.name_scope(name, "embedding_lookup_with_distributed_aggregation",
                      params + [ids]) as name:
    np = len(params)  # Number of partitions
    # Preserve the resource variable status to avoid accidental dense reads.
    if not any(
        isinstance(p, resource_variable_ops.ResourceVariable) for p in params):
      params = ops.convert_n_to_tensor_or_indexed_slices(params, name="params")
    if np == 1:
      with ops.colocate_with(params[0]):
        ret = maybe_normalize(_do_gather(params[0], ids))
        ignore_weights = weights is None
        if not ignore_weights:
          if weights.dtype != ret.dtype:
            weights = math_ops.cast(weights, ret.dtype)
          # Reshape to allow broadcast
          ones = array_ops.fill(
              array_ops.expand_dims(array_ops.rank(ret) - 1, 0), 1)
          bcast_weights_shape = array_ops.concat(
              [array_ops.shape(weights), ones], 0)
          orig_weights_shape = weights.get_shape()
          weights = array_ops.reshape(weights, bcast_weights_shape)
          # Set weights shape after reshape
          if ret.get_shape().ndims is not None:
            weights.set_shape(
                orig_weights_shape.concatenate(
                    [1 for _ in range(ret.get_shape().ndims - 1)]))
          ret *= weights
          return math_ops.segment_sum(ret, segment_ids, name=name)
        else:
          return math_ops.sparse_segment_sum(ret, idx, segment_ids, name=name)
    else:
      ids = ops.convert_to_tensor(ids, name="ids")
      flat_ids = array_ops.reshape(ids, [-1])
      original_indices = math_ops.range(array_ops.size(flat_ids))

      # Create p_assignments and set new_ids depending on the strategy.
      if partition_strategy == "mod":
        p_assignments = flat_ids % np
        new_ids = flat_ids // np
      elif partition_strategy == "div":
        # Compute num_total_ids as the sum of dim-0 of params, then assign to
        # partitions based on a constant number of ids per partition. Optimize
        # if we already know the full shape statically.
        dim_0_size = params[0].get_shape()[0]
        for p in xrange(1, np):
          dim_0_size += params[p].get_shape()[0]
        if dim_0_size.value:
          num_total_ids = constant_op.constant(dim_0_size.value, flat_ids.dtype)
        else:
          dim_0_sizes = []
          for p in xrange(np):
            if params[p].get_shape()[0].value is not None:
              dim_0_sizes.append(params[p].get_shape()[0].value)
            else:
              with ops.colocate_with(params[p]):
                dim_0_sizes.append(array_ops.shape(params[p])[0])
          num_total_ids = math_ops.reduce_sum(
              math_ops.cast(array_ops.stack(dim_0_sizes), flat_ids.dtype))
        ids_per_partition = num_total_ids // np
        extras = num_total_ids % np

        p_assignments = math_ops.maximum(flat_ids // (ids_per_partition + 1), (
            flat_ids - extras) // ids_per_partition)

        # Emulate a conditional using a boolean indicator tensor
        is_in_first_extras_partitions = math_ops.cast(p_assignments < extras,
                                                      flat_ids.dtype)
        new_ids = (is_in_first_extras_partitions * (flat_ids %
                                                    (ids_per_partition + 1)) +
                   (1 - is_in_first_extras_partitions) * (
                       (flat_ids - extras) % ids_per_partition))
      else:
        raise ValueError("Unrecognized partition strategy: " +
                         partition_strategy)

      # Cast partition assignments to int32 for use in dynamic_partition.
      # There really should not be more than 2^32 partitions.
      p_assignments = math_ops.cast(p_assignments, dtypes.int32)
      # Partition list of ids based on assignments into np separate lists
      gather_ids = data_flow_ops.dynamic_partition(new_ids, p_assignments, np)
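      # (Illustrative note, not in the original comments: with np = 2 and the
      # default "mod" strategy, flat_ids [0, 3, 2, 5] gives p_assignments
      # [0, 1, 0, 1] and new_ids [0, 1, 1, 2], so gather_ids becomes
      # [[0, 1], [1, 2]] -- the local rows each shard looks up.)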
      # Similarly, partition the original indices.
      pindices = data_flow_ops.dynamic_partition(original_indices,
                                                 p_assignments, np)
      # Do np separate lookups, finding embeddings for plist[p] in params[p]
      partitioned_result = []
      for p in xrange(np):
        with ops.colocate_with(params[p]):
          partitioned_result.append(_do_gather(params[p], gather_ids[p]))

      ignore_weights = weights is None
      if not ignore_weights:
        # Partition weights according to pindices.
        partitioned_weight = []
        for p in xrange(np):
          partitioned_weight.append(array_ops.gather(weights, pindices[p]))
      # Reshape each partition result.
      element_shape = params[0].get_shape()[1:]
      for p in params[1:]:
        element_shape = element_shape.merge_with(p.get_shape()[1:])
      if element_shape.is_fully_defined():
        for p in xrange(np):
          with ops.colocate_with(params[p]):
            partitioned_result[p] = array_ops.reshape(
                partitioned_result[p],
                array_ops.concat([array_ops.shape(pindices[p]), element_shape],
                                 0))
      else:
        with ops.colocate_with(params[0]):
          params_shape = array_ops.shape(params[0])
        for p in xrange(np):
          with ops.colocate_with(params[p]):
            partitioned_result[p] = array_ops.reshape(
                partitioned_result[p],
                array_ops.concat([
                    array_ops.shape(pindices[p]), array_ops.slice(
                        params_shape, [1], [-1])
                ], 0))
      # Normalize each partition result.
      for p in xrange(np):
        with ops.colocate_with(params[p]):
          partitioned_result[p] = maybe_normalize(partitioned_result[p])
      if not ignore_weights:
        # Multiply each partition result with partition weights.
        for p in xrange(np):
          with ops.colocate_with(params[p]):
            if partitioned_weight[p].dtype != partitioned_result[p].dtype:
              partitioned_weight[p] = math_ops.cast(partitioned_weight[p],
                                                    partitioned_result[p].dtype)
            # Reshape partition weights.
            ones = array_ops.fill(
                array_ops.expand_dims(
                    array_ops.rank(partitioned_result[p]) - 1, 0), 1)
            bcast_weights_shape = array_ops.concat(
                [array_ops.shape(partitioned_weight[p]), ones], 0)
            orig_weights_shape = partitioned_weight[p].get_shape()
            partitioned_weight[p] = array_ops.reshape(partitioned_weight[p],
                                                      bcast_weights_shape)
            if partitioned_result[p].get_shape().ndims is not None:
              partitioned_weight[p].set_shape(
                  orig_weights_shape.concatenate([
                      1
                      for _ in range(partitioned_result[p].get_shape().ndims -
                                     1)
                  ]))
            partitioned_result[p] *= partitioned_weight[p]
      partitioned_segment_ids = []
      for p in xrange(np):
        if not ignore_weights:
          # Partition segment_ids according to pindices.
          p_segment_ids = array_ops.gather(segment_ids, pindices[p])
          # Number the p_segment_ids to meet segment_sum's requirements. Note
          # that unique_p_segment_ids contains the unique segment ids of this
          # partition and their order is unchanged.
          unique_p_segment_ids, unique_p_segment_idx = array_ops.unique(
              p_segment_ids)
          partitioned_segment_ids.append(unique_p_segment_ids)
          # segment_sum this partition's result.
          with ops.colocate_with(params[p]):
            partitioned_result[p] = math_ops.segment_sum(
                partitioned_result[p], unique_p_segment_idx)
        else:
          # When weights are ignored, we need the positions within `idx` and
          # `segment_ids` of the elements assigned to this partition.
          _, exclude_idx = array_ops.setdiff1d(idx, pindices[p])
          all_idx = math_ops.range(array_ops.shape(idx)[0])
          _, include_idx = array_ops.setdiff1d(all_idx, exclude_idx)
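          # (Illustrative note, not in the original comments: the two
          # setdiff1d calls compute the complement of a complement -- the
          # positions in `idx` whose values fall in this partition's
          # pindices[p] -- since there is no direct "is in" op to use here.)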
          # Gather segment_ids and idx according to these positions.
          p_segment_ids = array_ops.gather(segment_ids, include_idx)
          p_idx = array_ops.gather(idx, include_idx)
          # Number the p_segment_ids, same as in the weighted case above.
          unique_p_segment_ids, unique_p_segment_idx = array_ops.unique(
              p_segment_ids)
          _, unique_p_idx_idx = array_ops.unique(p_idx)
          partitioned_segment_ids.append(unique_p_segment_ids)
          with ops.colocate_with(params[p]):
            partitioned_result[p] = math_ops.sparse_segment_sum(
                partitioned_result[p], unique_p_idx_idx, unique_p_segment_idx)
      # Concat each partition's segment_ids and result for final segment_sum.
      concat_segment_ids = array_ops.concat(partitioned_segment_ids, 0)
      concat_partitioned_result = array_ops.concat(partitioned_result, 0)
      return math_ops.unsorted_segment_sum(
          concat_partitioned_result,
          concat_segment_ids,
          math_ops.reduce_max(concat_segment_ids) + 1,
          name=name)