      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Embedding functions."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 from six.moves import xrange  # pylint: disable=redefined-builtin
     21 
     22 from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util
     23 from tensorflow.contrib.layers.python.ops import sparse_feature_cross_op
     24 
     25 from tensorflow.python.framework import constant_op
     26 from tensorflow.python.framework import dtypes
     27 from tensorflow.python.framework import ops
     28 from tensorflow.python.framework import sparse_tensor
     29 from tensorflow.python.framework import tensor_shape
     30 from tensorflow.python.ops import array_ops
     31 from tensorflow.python.ops import clip_ops
     32 from tensorflow.python.ops import control_flow_ops
     33 from tensorflow.python.ops import data_flow_ops
     34 from tensorflow.python.ops import embedding_ops
     35 from tensorflow.python.ops import math_ops
     36 from tensorflow.python.ops import resource_variable_ops
     37 from tensorflow.python.ops import sparse_ops
     38 from tensorflow.python.ops import variables
     39 from tensorflow.python.platform import tf_logging as logging
     40 
     41 __all__ = [
     42     "safe_embedding_lookup_sparse", "scattered_embedding_lookup",
     43     "scattered_embedding_lookup_sparse", "embedding_lookup_unique",
     44     "embedding_lookup_sparse_with_distributed_aggregation"
     45 ]
     46 
     47 
     48 def safe_embedding_lookup_sparse(embedding_weights,
     49                                  sparse_ids,
     50                                  sparse_weights=None,
     51                                  combiner=None,
     52                                  default_id=None,
     53                                  name=None,
     54                                  partition_strategy="div",
     55                                  max_norm=None):
     56   """Looks up embedding results, accounting for invalid IDs and empty features.
     57 
     58   The partitioned embedding tensors in `embedding_weights` must all have the
     59   same shape except for the first dimension, which is allowed to vary as the
     60   vocabulary size is not necessarily a multiple of `P`.  `embedding_weights`
     61   may be a `PartitionedVariable` as returned by using `tf.get_variable()` with a
     62   partitioner.
     63 
     64   Invalid IDs (< 0) are pruned from input IDs and weights, as well as any IDs
     65   with non-positive weight. For an entry with no features, the embedding vector
     66   for `default_id` is returned, or the 0-vector if `default_id` is not supplied.
     67 
     68   The ids and weights may be multi-dimensional. Embeddings are always aggregated
     69   along the last dimension.
     70 
     71   Args:
     72     embedding_weights:  A list of `P` float tensors or values representing
     73         partitioned embedding tensors.  Alternatively, a `PartitionedVariable`,
     74         created by partitioning along dimension 0.  The total unpartitioned
     75         shape should be `[e_0, e_1, ..., e_m]`, where `e_0` represents the
     76         vocab size and `e_1, ..., e_m` are the embedding dimensions.
     77     sparse_ids: `SparseTensor` of shape `[d_0, d_1, ..., d_n]` containing the
     78         ids. `d_0` is typically batch size.
     79     sparse_weights: `SparseTensor` of same shape as `sparse_ids`, containing
     80         float weights corresponding to `sparse_ids`, or `None` if all weights
     81         are assumed to be 1.0.
     82     combiner: A string specifying how to combine embedding results for each
     83         entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean"
     84         the default.
     85     default_id: The id to use for an entry with no features.
     86     name: A name for this operation (optional).
     87     partition_strategy: A string specifying the partitioning strategy.
     88         Currently `"div"` and `"mod"` are supported. Default is `"div"`.
     89     max_norm: If not None, all embeddings are l2-normalized to max_norm before
     90         combining.
     91 
     92 
     93   Returns:
     94     Dense tensor of shape `[d_0, d_1, ..., d_{n-1}, e_1, ..., e_m]`.
     95 
     96   Raises:
     97     ValueError: if `embedding_weights` is empty.
     98   """
     99   if combiner is None:
    100     logging.warn("The default value of combiner will change from \"mean\" "
    101                  "to \"sqrtn\" after 2016/11/01.")
    102     combiner = "mean"
    103   if embedding_weights is None:
    104     raise ValueError("Missing embedding_weights %s." % embedding_weights)
    105   if isinstance(embedding_weights, variables.PartitionedVariable):
    106     embedding_weights = list(embedding_weights)  # get underlying Variables.
    107   if not isinstance(embedding_weights, list):
    108     embedding_weights = [embedding_weights]
    109   if len(embedding_weights) < 1:
    110     raise ValueError("Missing embedding_weights %s." % embedding_weights)
    111 
    112   dtype = sparse_weights.dtype if sparse_weights is not None else None
    113   if isinstance(embedding_weights, variables.PartitionedVariable):
    114     embedding_weights = list(embedding_weights)
    115   embedding_weights = [
    116       ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
    117   ]
    118 
    119   contrib_tensor_util.assert_same_float_dtype(embedding_weights +
    120                                               [sparse_weights])
    121 
    122   with ops.name_scope(name, "embedding_lookup",
    123                       embedding_weights + [sparse_ids,
    124                                            sparse_weights]) as scope:
    125     # Reshape higher-rank sparse ids and weights to linear segment ids.
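    # For example (hypothetical shapes): sparse_ids with dense_shape
    # [d_0, d_1, d_2] = [2, 3, 4] is reshaped below to [6, 4], and the leading
    # dimensions are restored after the lookup.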
    126     original_shape = sparse_ids.dense_shape
    127     original_rank_dim = sparse_ids.dense_shape.get_shape()[0]
    128     original_rank = (
    129         array_ops.size(original_shape)
    130         if original_rank_dim.value is None
    131         else original_rank_dim.value)
    132     sparse_ids = sparse_ops.sparse_reshape(sparse_ids, [
    133         math_ops.reduce_prod(
    134             array_ops.slice(original_shape, [0], [original_rank - 1])),
    135         array_ops.gather(original_shape, original_rank - 1)])
    136     if sparse_weights is not None:
    137       sparse_weights = sparse_tensor.SparseTensor(
    138           sparse_ids.indices,
    139           sparse_weights.values, sparse_ids.dense_shape)
    140 
    141     # Prune invalid ids and weights.
    142     sparse_ids, sparse_weights = _prune_invalid_ids(sparse_ids, sparse_weights)
    143 
    144     # Fill in dummy values for empty features, if necessary.
    145     sparse_ids, is_row_empty = sparse_ops.sparse_fill_empty_rows(sparse_ids,
    146                                                                  default_id or
    147                                                                  0)
    148     if sparse_weights is not None:
    149       sparse_weights, _ = sparse_ops.sparse_fill_empty_rows(sparse_weights, 1.0)
    150 
    151     result = embedding_ops.embedding_lookup_sparse(
    152         embedding_weights,
    153         sparse_ids,
    154         sparse_weights,
    155         combiner=combiner,
    156         partition_strategy=partition_strategy,
    157         name=None if default_id is None else scope,
    158         max_norm=max_norm)
    159 
    160     if default_id is None:
    161       # Broadcast is_row_empty to the same shape as embedding_lookup_result,
    162       # for use in Select.
    163       is_row_empty = array_ops.tile(
    164           array_ops.reshape(is_row_empty, [-1, 1]),
    165           array_ops.stack([1, array_ops.shape(result)[1]]))
    166 
    167       result = array_ops.where(is_row_empty,
    168                                array_ops.zeros_like(result),
    169                                result,
    170                                name=scope)
    171 
    172     # Reshape back from linear ids into the higher-dimensional dense result.
    173     final_result = array_ops.reshape(
    174         result,
    175         array_ops.concat([
    176             array_ops.slice(
    177                 math_ops.cast(original_shape, dtypes.int32), [0],
    178                 [original_rank - 1]),
    179             array_ops.slice(array_ops.shape(result), [1], [-1])
    180         ], 0))
    181     final_result.set_shape(tensor_shape.unknown_shape(
    182         (original_rank_dim - 1).value).concatenate(result.get_shape()[1:]))
    183     return final_result
    184 
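# Illustrative sketch, not part of the original module: the `_example_*`
# helper below and its data are hypothetical. It shows one way
# `safe_embedding_lookup_sparse` might be called with a small in-memory
# embedding table, an invalid id and an empty row.
def _example_safe_embedding_lookup_sparse():
  """Illustrative sketch: embeds a ragged batch with invalid and empty rows."""
  # A [vocab_size=4, embedding_dim=2] table.
  embedding_weights = constant_op.constant(
      [[1., 1.], [2., 2.], [3., 3.], [4., 4.]])
  # Row 0 holds ids [0, 2], row 1 holds only the invalid id -1, row 2 is empty.
  sparse_ids = sparse_tensor.SparseTensor(
      indices=[[0, 0], [0, 1], [1, 0]],
      values=constant_op.constant([0, 2, -1], dtype=dtypes.int64),
      dense_shape=[3, 2])
  # Rows 1 and 2 come back as 0-vectors because `default_id` is not set.
  return safe_embedding_lookup_sparse(
      embedding_weights, sparse_ids, combiner="mean")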
    185 
    186 def _prune_invalid_ids(sparse_ids, sparse_weights):
    187   """Prune invalid IDs (< 0) from the input ids and weights."""
    188   is_id_valid = math_ops.greater_equal(sparse_ids.values, 0)
    189   if sparse_weights is not None:
    190     is_id_valid = math_ops.logical_and(
    191         is_id_valid, math_ops.greater(sparse_weights.values, 0))
    192   sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_id_valid)
    193   if sparse_weights is not None:
    194     sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_id_valid)
    195   return sparse_ids, sparse_weights
    196 
    197 
    198 def scattered_embedding_lookup(params,
    199                                values,
    200                                dimension,
    201                                name=None,
    202                                hash_key=None):
    203   """Looks up embeddings using parameter hashing for each value in `values`.
    204 
    205   The i-th embedding component of a value v in `values` is found by retrieving
    206   the weight whose index is a fingerprint of the pair (v,i).
    207   The concept is explored as "feature hashing" for model compression in this
    208   paper: http://arxiv.org/pdf/1504.04788.pdf
    209 
    210   Feature hashing has the pleasant effect of allowing us to compute an embedding
    211   without needing a pre-determined vocabulary, relieving some amount of process
    212   complexity. It also allows us to maintain embeddings for possibly
    213   trillions of features with a fixed amount of memory.
    214 
    215   Note that this is superior to out-of-vocabulary shared "hash buckets" in that
    216   the embedding is extremely likely to be unique for each token as opposed to
    217   being shared across probably-colliding tokens. The price is that we must
    218   compute a hash once for each scalar in the token's embedding as opposed to
    219   once per token.
    220 
    221   If `params` is a list, it represents a partition of the embedding parameters.
    222   Each tensor in the list should have the same length, except for the first ones
    223   which may have an additional element. For instance, 10 parameters can be
    224   partitioned into 4 tensors with lengths `[3, 3, 2, 2]`.
    225 
    226   Args:
    227     params: A `Tensor`, `list` of `Tensors`, or `PartitionedVariable`.
    228       Each tensor must be of rank 1 with fully-defined shape.
    229     values: `Tensor` of values to be embedded with shape `[d0, ..., dn]`.
    230     dimension: Embedding dimension.
    231     name: An optional name for this op.
    232     hash_key: Specify the hash_key that will be used by the `FingerprintCat64`
    233       function to combine the crossed fingerprints in SparseFeatureCrossOp
    234       (optional).
    235 
    236   Returns:
    237     A `Tensor` with shape `[d0, ..., dn, dimension]`.
    238 
    239   Raises:
    240     ValueError: if dimension is not positive or the partition size is invalid.
    241   """
    242   if dimension is None:
    243     raise ValueError("You must specify dimension.")
    244   return _sampled_scattered_embedding_lookup(
    245       params, values, dimension=dimension, sampled_candidates=None,
    246       hash_key=hash_key, name=name)
    247 
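# Illustrative sketch, not part of the original module: names and values below
# are hypothetical. Component i of a value's embedding is read from `params`
# at an index obtained by hashing (value, i), so no vocabulary is needed.
def _example_scattered_embedding_lookup():
  """Illustrative sketch: hashes two string tokens into 3-d embeddings."""
  # A single rank-1 parameter tensor; its length bounds the hash buckets.
  params = constant_op.constant([.1, .2, .3, .4, .5, .6, .7, .8])
  values = constant_op.constant(["foo", "bar"])
  # Returns a [2, 3] Tensor; equal tokens always receive equal embeddings.
  return scattered_embedding_lookup(params, values, dimension=3)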
    248 
    249 def _sampled_scattered_embedding_lookup(
    250     params, values, dimension=None, sampled_candidates=None, hash_key=None,
    251     name=None):
    252   """Looks up embeddings using parameter hashing for each value in `values`.
    253 
    254   This method looks up selected embedding dimensions if `sampled_candidates` is
    255   given, otherwise looks up all dimensions.
    256 
    257   The i-th embedding component of a value v in `values` is found by retrieving
    258   the weight whose index is a fingerprint of the pair (v,i).
    259   The concept is explored as "feature hashing" for model compression in this
    260   paper: http://arxiv.org/pdf/1504.04788.pdf
    261 
    262   Feature hashing has the pleasant effect of allowing us to compute an embedding
    263   without needing a pre-determined vocabulary, relieving some amount of process
    264   complexity. It also allows us to maintain embeddings for possibly
    265   trillions of features with a fixed amount of memory.
    266 
    267   Note that this is superior to out-of-vocabulary shared "hash buckets" in that
    268   the embedding is extremely likely to be unique for each token as opposed to
    269   being shared across probably-colliding tokens. The price is that we must
    270   compute a hash once for each scalar in the token's embedding as opposed to
    271   once per token.
    272 
    273   If `params` is a list, it represents a partition of the embedding parameters.
    274   Each tensor in the list should have the same length, except for the first ones
    275   which may have an additional element. For instance, 10 parameters can be
    276   partitioned into 4 tensors with lengths `[3, 3, 2, 2]`.
    277 
    278   Args:
    279     params: A `Tensor`, `list` of `Tensors`, or `PartitionedVariable`.
    280       Each tensor must be of rank 1 with fully-defined shape.
    281     values: `Tensor` of values to be embedded with shape `[d0, ..., dn]`.
    282     dimension: Embedding dimension. The user must specify either `dimension` or
    283       `sampled_candidates`.
    284     sampled_candidates: An optional `Tensor` of slice indices to keep along the
    285       final dimension with shape `[d0, ..., dn, N]`. If given, `dimension` is
    286       ignored. If `None`, looks up all candidates.
    287     hash_key: Specify the hash_key that will be used by the `FingerprintCat64`
    288       function to combine the crossed fingerprints in SparseFeatureCrossOp
    289       (optional).
    290     name: An optional name for this op.
    291 
    292   Returns:
    293     A `Tensor` with shape `[d0, ..., dn, dimension]`.
    294     If `sampled_candidates` is given, the output shape is `[d0, ..., dn, N]`
    295 
    296   Raises:
    297     ValueError: if dimension is not positive or the partition size is invalid.
    298   """
    299   if isinstance(params, variables.PartitionedVariable):
    300     params = list(params)
    301   if not isinstance(params, list):
    302     params = [params]
    303 
    304   with ops.name_scope(name, "scattered_embedding_lookup",
    305                       params + [dimension, values]):
    306     # Flatten the values
    307     values_shape = array_ops.shape(values)
    308     values = array_ops.reshape(values, [-1, 1])
    309 
    310     if sampled_candidates is None:
    311       if dimension is None:
    312         raise ValueError(
    313             "You must specify either dimension or sampled_candidates.")
    314       if dimension <= 0:
    315         raise ValueError("Dimension must be > 0. Got %d." % dimension)
    316       sampled_candidates = array_ops.tile(array_ops.expand_dims(
    317           math_ops.range(0, dimension), 0), array_ops.shape(values))
    318     else:
    319       dimension = array_ops.shape(sampled_candidates)[
    320           math_ops.subtract(array_ops.rank(sampled_candidates), 1)]
    321       sampled_candidates_shape = array_ops.shape(sampled_candidates)
    322       dimension_tensor = array_ops.reshape(dimension, shape=[1,])
    323       expected_shape = array_ops.concat([values_shape, dimension_tensor], 0)
    324       with ops.control_dependencies([control_flow_ops.Assert(
    325           math_ops.reduce_all(math_ops.equal(sampled_candidates_shape,
    326                                              expected_shape)),
    327           ["The shape of sampled_candidates: ", sampled_candidates_shape,
    328            " does not match the shape of values: ", values_shape])]):
    329         # Flatten sampled_candidates, same way as values are flattened.
    330         sampled_candidates = array_ops.reshape(sampled_candidates,
    331                                                [-1, dimension])
    332 
    333     num_partitions = len(params)
    334     partition_sizes = []
    335     for p in range(num_partitions):
    336       shape = params[p].get_shape()
    337       shape.assert_has_rank(1)
    338       shape.assert_is_fully_defined()
    339       partition_sizes.append(shape[0].value)
    340     num_params = sum(partition_sizes)  # Total number of parameters.
    341 
    342     # Assert the size of each partition.
    343     for p in range(num_partitions):
    344       expected_size = (num_params - p - 1) // num_partitions + 1
    345       if partition_sizes[p] != expected_size:
    346         raise ValueError("Tensor %d in params has size %d, expected %d." %
    347                          (p, partition_sizes[p], expected_size))
    348 
    349     # With two values v1 and v2 and 3 dimensions, we will cross
    350     # [[0, 1, 2], [0, 1, 2]] with [[v1], [v2]].
    351     tensors_to_cross = [sampled_candidates, values]
    352     ids = sparse_feature_cross_op.sparse_feature_cross(
    353         tensors_to_cross, hashed_output=True, num_buckets=num_params,
    354         hash_key=hash_key)
    355     ids = sparse_ops.sparse_tensor_to_dense(ids)
    356 
    357     # No need to validate the indices since we have checked the params
    358     # dimensions and we know the largest id.
    359     result = embedding_ops.embedding_lookup(
    360         params, ids, partition_strategy="div")
    361 
    362     return array_ops.reshape(result,
    363                              array_ops.concat([values_shape, [dimension]], 0))
    364 
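# Illustrative sketch, not part of the original module: names and values below
# are hypothetical. When `sampled_candidates` is given, only the listed
# embedding components are hashed and looked up for each value.
def _example_sampled_scattered_embedding_lookup():
  """Illustrative sketch: looks up two sampled components per value."""
  params = constant_op.constant([.1, .2, .3, .4, .5, .6, .7])
  values = constant_op.constant([1, 2], dtype=dtypes.int64)
  # Keep components [0, 2] for the first value and [1, 3] for the second.
  sampled_candidates = constant_op.constant([[0, 2], [1, 3]])
  # Returns a [2, 2] Tensor: one row per value, one column per candidate.
  return _sampled_scattered_embedding_lookup(
      params, values, sampled_candidates=sampled_candidates)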
    365 
    366 def scattered_embedding_lookup_sparse(params,
    367                                       sparse_values,
    368                                       dimension,
    369                                       combiner=None,
    370                                       default_value=None,
    371                                       name=None,
    372                                       hash_key=None):
    373   """Looks up embeddings of a sparse feature using parameter hashing.
    374 
    375   See `tf.contrib.layers.scattered_embedding_lookup` for embedding with hashing.
    376 
    377   Args:
    378     params: A `Tensor`, `list` of `Tensors`, or `PartitionedVariable`.
    379       Each tensor must be of rank 1 with fully-defined shape.
    380     sparse_values: A 2-D `SparseTensor` containing the values to be embedded.
    381       Some rows may be empty.
    382     dimension: Embedding dimension
    383     combiner: A string specifying how to combine embedding results for each
    384         entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean"
    385         the default.
    386     default_value: The value to use for an entry with no features.
    387     name: An optional name for this op.
    388     hash_key: Specify the hash_key that will be used by the `FingerprintCat64`
    389       function to combine the crossed fingerprints in SparseFeatureCrossOp
    390       (optional).
    391 
    392   Returns:
    393      Dense tensor with shape `[N, dimension]`, where `N` is the number of rows
    394        in sparse_values.
    395 
    396   Raises:
    397     TypeError: If sparse_values is not a SparseTensor.
    398     ValueError: If combiner is not one of {"mean", "sqrtn", "sum"}.
    399   """
    400   if combiner is None:
    401     logging.warn("The default value of combiner will change from \"mean\" "
    402                  "to \"sqrtn\" after 2016/11/01.")
    403     combiner = "mean"
    404   if isinstance(params, variables.PartitionedVariable):
    405     params = list(params)
    406   if not isinstance(params, list):
    407     params = [params]
    408   if not isinstance(sparse_values, sparse_tensor.SparseTensor):
    409     raise TypeError("sparse_values must be SparseTensor")
    410 
    411   with ops.name_scope(name, "scattered_embedding_lookup_sparse",
    412                       params + [sparse_values]) as scope:
    413     # Fill in the empty rows.
    414     if default_value is None:
    415       # Random default values to reduce the risk of collision.
    416       if sparse_values.dtype == dtypes.string:
    417         default_value = "6ZxWzWOHxZ"
    418       else:
    419         default_value = 1288896567
    420     sparse_values, _ = sparse_ops.sparse_fill_empty_rows(
    421         sparse_values, default_value)
    422 
    423     segment_ids = sparse_values.indices[:, 0]
    424     if segment_ids.dtype != dtypes.int32:
    425       segment_ids = math_ops.cast(segment_ids, dtypes.int32)
    426 
    427     values = sparse_values.values
    428     values, idx = array_ops.unique(values)
    429 
    430     embeddings = scattered_embedding_lookup(
    431         params, values, dimension, hash_key=hash_key)
    432 
    433     if combiner == "sum":
    434       embeddings = math_ops.sparse_segment_sum(embeddings, idx, segment_ids,
    435                                                name=scope)
    436     elif combiner == "mean":
    437       embeddings = math_ops.sparse_segment_mean(embeddings, idx, segment_ids,
    438                                                 name=scope)
    439     elif combiner == "sqrtn":
    440       embeddings = math_ops.sparse_segment_sqrt_n(embeddings, idx, segment_ids,
    441                                                   name=scope)
    442     else:
    443       raise ValueError("Combiner must be one of 'mean', 'sqrtn' or 'sum'.")
    444 
    445     return embeddings
    446 
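# Illustrative sketch, not part of the original module: names and values below
# are hypothetical. Each row of the SparseTensor is embedded with parameter
# hashing and the per-row embeddings are combined with the given combiner.
def _example_scattered_embedding_lookup_sparse():
  """Illustrative sketch: three rows of string features, one of them empty."""
  params = constant_op.constant([.1, .2, .3, .4, .5, .6, .7, .8])
  sparse_values = sparse_tensor.SparseTensor(
      indices=[[0, 0], [0, 1], [2, 0]],
      values=["foo", "bar", "foo"],
      dense_shape=[3, 2])
  # The empty row 1 is filled with a default value before hashing.
  return scattered_embedding_lookup_sparse(
      params, sparse_values, dimension=4, combiner="mean")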
    447 
    448 def embedding_lookup_unique(params, ids, name=None):
    449   """Version of embedding_lookup that avoids duplicate lookups.
    450 
    451   This can save communication in the case of repeated ids.
    452   It has the same interface as embedding_lookup, except that it also supports
    453   multi-dimensional `ids`, avoiding reshapes of the input/output to fit gather.
    454 
    455   Args:
    456     params: A list of tensors with the same shape and type, or a
    457       `PartitionedVariable`. Shape `[index, d1, d2, ...]`.
    458     ids: A `Tensor` with type `int32` or `int64` containing the ids to be
    459       looked up in `params`. Shape `[ids1, ids2, ...]`.
    460     name: A name for this operation (optional).
    461 
    462   Returns:
    463     A `Tensor` with the same type as the tensors in `params` and dimension of
    464     `[ids1, ids2, d1, d2, ...]`.
    465 
    466   Raises:
    467     ValueError: If `params` is empty.
    468   """
    469   with ops.name_scope(name, "EmbeddingLookupUnique", [params, ids]):
    470     ids = ops.convert_to_tensor(ids)
    471     shape = array_ops.shape(ids)
    472     ids_flat = array_ops.reshape(
    473         ids, math_ops.reduce_prod(shape, keep_dims=True))
    474     unique_ids, idx = array_ops.unique(ids_flat)
    475     unique_embeddings = embedding_ops.embedding_lookup(params, unique_ids)
    476     embeds_flat = array_ops.gather(unique_embeddings, idx)
    477     embed_shape = array_ops.concat(
    478         [shape, array_ops.shape(unique_embeddings)[1:]], 0)
    479     embeds = array_ops.reshape(embeds_flat, embed_shape)
    480     embeds.set_shape(ids.get_shape().concatenate(
    481         unique_embeddings.get_shape()[1:]))
    482     return embeds
    483 
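# Illustrative sketch, not part of the original module: names and values below
# are hypothetical. Duplicate ids are looked up only once and then gathered
# back into place, which saves communication for heavily repeated ids.
def _example_embedding_lookup_unique():
  """Illustrative sketch: multi-dimensional ids with many repeats."""
  params = constant_op.constant([[0., 0.], [1., 1.], [2., 2.]])
  ids = constant_op.constant([[2, 2, 2], [1, 2, 1]])
  # Only rows 1 and 2 of `params` are fetched; the result has shape [2, 3, 2]
  # and matches a plain embedding_lookup on the same inputs.
  return embedding_lookup_unique(params, ids)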
    484 
    485 def _sampled_scattered_embedding_lookup_sparse(params,
    486                                                sp_values,
    487                                                dimension=None,
    488                                                sampled_candidates=None,
    489                                                hash_key=None,
    490                                                with_sign_hash=False,
    491                                                name=None):
    492   """Looks up embeddings using parameter hashing for sparse values.
    493 
    494   This method looks up selected embedding dimensions if `sampled_candidates` is
    495   given, otherwise looks up all dimensions.
    496 
    497   The i-th embedding component of a value v in `values` is found by retrieving
    498   the weight whose index is a fingerprint of the pair (v,i).
    499   The concept is explored as "feature hashing" for model compression in this
    500   paper: http://arxiv.org/pdf/1504.04788.pdf
    501 
    502   This is logically equivalent to:
    503   * Transforming `sp_values` (which has shape `[d0, d1]`) into a one-hot
    504     `Tensor` of shape `[d0, N]`.
    505   * Multiplying with a `Tensor` `h` of shape `[N, dimension]`, where
    506     `h(i, j) = params[hash(i, j)]`.
    507 
    508   Args:
    509     params: A float `Tensor` with rank 1 and fully-defined shape.
    510     sp_values: A 2D `SparseTensor` to be embedded with shape `[d0, d1]`.
    511     dimension: An int `Tensor` of the final dimension. The user needs to provide
    512       either `dimension` or `sampled_candidates`.
    513     sampled_candidates: An optional `Tensor` of column indices to keep along
    514       the final dimension with shape `[d0, N]`. If given, `dimension` is
    515       ignored. If `None`, looks up all candidates.
    516     hash_key: Specify the hash_key that will be used by the `FingerprintCat64`
    517       function to combine the crossed fingerprints in SparseFeatureCrossOp
    518       (optional).
    519     with_sign_hash:  A `bool` indicating whether `h(i, j)` should be multiplied
    520       by `+1` or `-1`, where the value selected is determined by hashing
    521       `(i, j)`. This is often necessary to remove bias resulting from hash
    522       collisions.
    523     name: An optional name for this op.
    524 
    525   Returns:
    526     A `Tensor` of shape `[d0, dimension]`.
    527     If `sampled_candidates` is given, the output shape is `[d0, N]`.
    528 
    529   Raises:
    530     TypeError: If sp_values is not `SparseTensor`.
    531     ValueError: If both `dimension` and `sampled_candidates` are `None`.
    532   """
    533   if not isinstance(sp_values, sparse_tensor.SparseTensor):
    534     raise TypeError("sp_values must be SparseTensor")
    535 
    536   with ops.name_scope(
    537       name=name,
    538       default_name="sampled_scattered_embedding_lookup_sparse",
    539       values=[sp_values, params, dimension, sampled_candidates]) as name_scope:
    540     segment_ids = sp_values.indices[:, 0]
    541     if sampled_candidates is not None:
    542       # Replicate sampled_candidates so there is one row corresponding to each
    543       # element in sp_values.values
    544       sampled_candidates = array_ops.gather(sampled_candidates, segment_ids)
    545 
    546     embeddings = _sampled_scattered_embedding_lookup(
    547         params, sp_values.values, dimension=dimension,
    548         sampled_candidates=sampled_candidates,
    549         hash_key=hash_key, name="values_lookup")
    550     if with_sign_hash:
    551       signs = _sampled_scattered_embedding_lookup(
    552           array_ops.constant([-1., 1.]), sp_values.values, dimension=dimension,
    553           sampled_candidates=sampled_candidates, hash_key=hash_key,
    554           name="signs_lookup")
    555       embeddings = math_ops.multiply(signs, embeddings, name="signs_hash")
    556 
    557     if segment_ids.dtype != dtypes.int32:
    558       segment_ids = math_ops.cast(segment_ids, dtypes.int32)
    559     num_segments = array_ops.shape(sp_values)[0]
    560 
    561     return math_ops.unsorted_segment_sum(embeddings, segment_ids,
    562                                          num_segments=num_segments,
    563                                          name=name_scope)
    564 
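# Illustrative sketch, not part of the original module: names and values below
# are hypothetical. This is logically a one-hot encoding of `sp_values`
# multiplied by a hashed parameter matrix, with optional +/-1 sign hashing to
# reduce the bias introduced by hash collisions.
def _example_sampled_scattered_embedding_lookup_sparse():
  """Illustrative sketch: embeds a [2, 2] SparseTensor into a [2, 4] Tensor."""
  params = constant_op.constant([.1, .2, .3, .4, .5])
  sp_values = sparse_tensor.SparseTensor(
      indices=[[0, 0], [0, 1], [1, 0]],
      values=["a", "b", "c"],
      dense_shape=[2, 2])
  return _sampled_scattered_embedding_lookup_sparse(
      params, sp_values, dimension=4, with_sign_hash=True)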
    565 
    566 def embedding_lookup_sparse_with_distributed_aggregation(
    567     params,
    568     sp_ids,
    569     sp_weights,
    570     partition_strategy="mod",
    571     name=None,
    572     combiner=None,
    573     max_norm=None):
    574   """Computes embeddings for the given ids and weights.
    575 
    576   Embeddings belonging to the same param are aggregated on that device first.
    577   This op aims to decrease data transmission and improve parallelism. See
    578   `tf.nn.embedding_lookup_sparse` for the functionality and example of this op.
    579 
    580   Args:
    581     params: A single tensor representing the complete embedding tensor,
    582       or a list of P tensors all of same shape except for the first dimension,
    583       representing sharded embedding tensors.  Alternatively, a
    584       `PartitionedVariable`, created by partitioning along dimension 0. Each
    585       element must be appropriately sized for the given `partition_strategy`.
    586     sp_ids: N x M SparseTensor of int64 ids (typically from FeatureValueToId),
    587       where N is typically batch size and M is arbitrary.
    588     sp_weights: either a SparseTensor of float / double weights, or None to
    589       indicate all weights should be taken to be 1. If specified, sp_weights
    590       must have exactly the same shape and indices as sp_ids.
    591     partition_strategy: A string specifying the partitioning strategy, relevant
    592       if `len(params) > 1`. Currently `"div"` and `"mod"` are supported. Default
    593       is `"mod"`. See `tf.nn.embedding_lookup` for more details.
    594     name: Optional name for the op.
    595     combiner: A string specifying the reduction op. Currently "mean", "sqrtn"
    596       and "sum" are supported.
    597       "sum" computes the weighted sum of the embedding results for each row.
    598       "mean" is the weighted sum divided by the total weight.
    599       "sqrtn" is the weighted sum divided by the square root of the sum of the
    600       squares of the weights.
    601     max_norm: If not None, each embedding is normalized to have l2 norm equal
    602       to max_norm before combining.
    603 
    604   Returns:
    605     A dense tensor representing the combined embeddings for the
    606     sparse ids. For each row in the dense tensor represented by sp_ids, the op
    607     looks up the embeddings for all ids in that row, multiplies them by the
    608     corresponding weight, and combines these embeddings as specified.
    609 
    610   Raises:
    611     TypeError: If sp_ids is not a SparseTensor, or if sp_weights is neither
    612       None nor SparseTensor.
    613     ValueError: If combiner is not one of {"mean", "sqrtn", "sum"}.
    614   """
    615   if combiner is None:
    616     logging.warn("The default value of combiner will change from \"mean\" "
    617                  "to \"sqrtn\" after 2016/11/01.")
    618     combiner = "mean"
    619   if combiner not in ("mean", "sqrtn", "sum"):
    620     raise ValueError("combiner must be one of 'mean', 'sqrtn' or 'sum'")
    621   if isinstance(params, variables.PartitionedVariable):
    622     params = list(params)  # Iterate to get the underlying Variables.
    623   if not isinstance(params, list):
    624     params = [params]
    625   if not isinstance(sp_ids, sparse_tensor.SparseTensor):
    626     raise TypeError("sp_ids must be SparseTensor")
    627   ignore_weights = sp_weights is None
    628   if not ignore_weights:
    629     if not isinstance(sp_weights, sparse_tensor.SparseTensor):
    630       raise TypeError("sp_weights must be either None or SparseTensor")
    631     sp_ids.values.get_shape().assert_is_compatible_with(
    632         sp_weights.values.get_shape())
    633     sp_ids.indices.get_shape().assert_is_compatible_with(
    634         sp_weights.indices.get_shape())
    635     sp_ids.dense_shape.get_shape().assert_is_compatible_with(
    636         sp_weights.dense_shape.get_shape())
    637     # TODO(yleon): Add enhanced node assertions to verify that sp_ids and
    638     # sp_weights have equal indices and shapes.
    639 
    640   with ops.name_scope(name, "embedding_lookup_sparse",
    641                       params + [sp_ids]) as name:
    642     segment_ids = sp_ids.indices[:, 0]
    643     if segment_ids.dtype != dtypes.int32:
    644       segment_ids = math_ops.cast(segment_ids, dtypes.int32)
    645 
    646     ids = sp_ids.values
    647     if ignore_weights:
    648       ids, idx = array_ops.unique(ids)
    649     else:
    650       idx = None
    651 
    652     weights = None if ignore_weights else sp_weights.values
    653     embeddings = _embedding_lookup_with_distributed_aggregation(
    654         params,
    655         ids,
    656         partition_strategy=partition_strategy,
    657         max_norm=max_norm,
    658         weights=weights,
    659         idx=idx,
    660         segment_ids=segment_ids)
    661     # Set weights to all ones if weights are ignored.
    662     if ignore_weights:
    663       weights = array_ops.fill([array_ops.shape(segment_ids)[0]], 1)
    664     if weights.dtype != embeddings.dtype:
    665       weights = math_ops.cast(weights, embeddings.dtype)
    666     # Reshape weights.
    667     ones = array_ops.fill(
    668         array_ops.expand_dims(array_ops.rank(embeddings) - 1, 0), 1)
    669     bcast_weights_shape = array_ops.concat([array_ops.shape(weights), ones], 0)
    670     orig_weights_shape = weights.get_shape()
    671     weights = array_ops.reshape(weights, bcast_weights_shape)
    672     if embeddings.get_shape().ndims is not None:
    673       weights.set_shape(
    674           orig_weights_shape.concatenate(
    675               [1 for _ in range(embeddings.get_shape().ndims - 1)]))
    676 
    677     if combiner == "mean":
    678       weight_sum = math_ops.segment_sum(weights, segment_ids)
    679       embeddings = math_ops.div(embeddings, weight_sum)
    680     elif combiner == "sqrtn":
    681       weights_squared = math_ops.pow(weights, 2)
    682       weight_sum = math_ops.segment_sum(weights_squared, segment_ids)
    683       weight_sum_sqrt = math_ops.sqrt(weight_sum)
    684       embeddings = math_ops.div(embeddings, weight_sum_sqrt)
    685     elif combiner != "sum":
    686       assert False, "Unrecognized combiner"
    687     return embeddings
    688 
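# Illustrative sketch, not part of the original module: the shard variables,
# ids and shapes below are hypothetical. With several shards, each shard sums
# the rows it owns on its own device before the partial results are combined.
def _example_distributed_aggregation_lookup():
  """Illustrative sketch: mean-combines ids looked up from two shards."""
  # A vocabulary of 4 ids split across two shards of two rows each.
  shard_0 = variables.Variable([[1., 1.], [2., 2.]])
  shard_1 = variables.Variable([[3., 3.], [4., 4.]])
  sp_ids = sparse_tensor.SparseTensor(
      indices=[[0, 0], [0, 1], [1, 0]],
      values=constant_op.constant([0, 3, 2], dtype=dtypes.int64),
      dense_shape=[2, 2])
  # Under the default "mod" strategy ids 0 and 2 live on shard_0 and ids 1 and
  # 3 live on shard_1, so each shard aggregates its own lookups first.
  return embedding_lookup_sparse_with_distributed_aggregation(
      [shard_0, shard_1], sp_ids, None, combiner="mean")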
    689 
    690 def _do_gather(params, ids, name=None):
    691   """Deals with doing gather differently for resource variables."""
    692   if isinstance(params, resource_variable_ops.ResourceVariable):
    693     return params.sparse_read(ids, name=name)
    694   return array_ops.gather(params, ids, name=name)
    695 
    696 
    697 def _embedding_lookup_with_distributed_aggregation(params,
    698                                                    ids,
    699                                                    partition_strategy="mod",
    700                                                    name=None,
    701                                                    max_norm=None,
    702                                                    weights=None,
    703                                                    idx=None,
    704                                                    segment_ids=None):
    705   """Lookup helper for embedding_lookup_sparse_with_distributed_aggregation."""
    706   if params is None or params == []:  # pylint: disable=g-explicit-bool-comparison
    707     raise ValueError("Need at least one param")
    708   if isinstance(params, variables.PartitionedVariable):
    709     params = list(params)  # Iterate to get the underlying Variables.
    710   if not isinstance(params, list):
    711     params = [params]
    712 
    713   def maybe_normalize(x):
    714     if max_norm is not None:
    715       if x.get_shape().ndims is not None:
    716         ndims = x.get_shape().ndims
    717       else:
    718         ndims = array_ops.size(array_ops.shape(x))
    719       return clip_ops.clip_by_norm(x, max_norm, axes=list(range(1, ndims)))
    720     return x
    721 
    722   with ops.name_scope(name, "embedding_lookup_with_distributed_aggregation",
    723                       params + [ids]) as name:
    724     np = len(params)  # Number of partitions
    725     # Preserve the resource variable status to avoid accidental dense reads.
    726     if not any(
    727         isinstance(p, resource_variable_ops.ResourceVariable) for p in params):
    728       params = ops.convert_n_to_tensor_or_indexed_slices(params, name="params")
    729     if np == 1:
    730       with ops.colocate_with(params[0]):
    731         ret = maybe_normalize(_do_gather(params[0], ids))
    732         ignore_weights = weights is None
    733         if not ignore_weights:
    734           if weights.dtype != ret.dtype:
    735             weights = math_ops.cast(weights, ret.dtype)
    736           # Reshape to allow broadcast
    737           ones = array_ops.fill(
    738               array_ops.expand_dims(array_ops.rank(ret) - 1, 0), 1)
    739           bcast_weights_shape = array_ops.concat(
    740               [array_ops.shape(weights), ones], 0)
    741           orig_weights_shape = weights.get_shape()
    742           weights = array_ops.reshape(weights, bcast_weights_shape)
    743           # Set weights shape after reshape
    744           if ret.get_shape().ndims is not None:
    745             weights.set_shape(
    746                 orig_weights_shape.concatenate(
    747                     [1 for _ in range(ret.get_shape().ndims - 1)]))
    748           ret *= weights
    749           return math_ops.segment_sum(ret, segment_ids, name=name)
    750         else:
    751           return math_ops.sparse_segment_sum(ret, idx, segment_ids, name=name)
    752     else:
    753       ids = ops.convert_to_tensor(ids, name="ids")
    754       flat_ids = array_ops.reshape(ids, [-1])
    755       original_indices = math_ops.range(array_ops.size(flat_ids))
    756 
    757       # Create p_assignments and set new_ids depending on the strategy.
    758       if partition_strategy == "mod":
    759         p_assignments = flat_ids % np
    760         new_ids = flat_ids // np
    761       elif partition_strategy == "div":
    762         # Compute num_total_ids as the sum of dim-0 of params, then assign to
    763         # partitions based on a constant number of ids per partition. Optimize
    764         # if we already know the full shape statically.
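        # For example (hypothetical sizes): with num_total_ids = 10 and np = 3,
        # ids_per_partition = 3 and extras = 1, so partition 0 owns ids 0..3
        # while partitions 1 and 2 own ids 4..6 and 7..9 respectively.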
    765         dim_0_size = params[0].get_shape()[0]
    766         for p in xrange(1, np):
    767           dim_0_size += params[p].get_shape()[0]
    768         if dim_0_size.value:
    769           num_total_ids = constant_op.constant(dim_0_size.value, flat_ids.dtype)
    770         else:
    771           dim_0_sizes = []
    772           for p in xrange(np):
    773             if params[p].get_shape()[0].value is not None:
    774               dim_0_sizes.append(params[p].get_shape()[0].value)
    775             else:
    776               with ops.colocate_with(params[p]):
    777                 dim_0_sizes.append(array_ops.shape(params[p])[0])
    778           num_total_ids = math_ops.reduce_sum(
    779               math_ops.cast(array_ops.stack(dim_0_sizes), flat_ids.dtype))
    780         ids_per_partition = num_total_ids // np
    781         extras = num_total_ids % np
    782 
    783         p_assignments = math_ops.maximum(flat_ids // (ids_per_partition + 1), (
    784             flat_ids - extras) // ids_per_partition)
    785 
    786         # Emulate a conditional using a boolean indicator tensor
    787         is_in_first_extras_partitions = math_ops.cast(p_assignments < extras,
    788                                                       flat_ids.dtype)
    789         new_ids = (is_in_first_extras_partitions * (flat_ids %
    790                                                     (ids_per_partition + 1)) +
    791                    (1 - is_in_first_extras_partitions) * (
    792                        (flat_ids - extras) % ids_per_partition))
    793       else:
    794         raise ValueError("Unrecognized partition strategy: " +
    795                          partition_strategy)
    796 
    797       # Cast partition assignments to int32 for use in dynamic_partition.
    798       # There really should not be more than 2^32 partitions.
    799       p_assignments = math_ops.cast(p_assignments, dtypes.int32)
    800       # Partition list of ids based on assignments into np separate lists
    801       gather_ids = data_flow_ops.dynamic_partition(new_ids, p_assignments, np)
    802       # Similarly, partition the original indices.
    803       pindices = data_flow_ops.dynamic_partition(original_indices,
    804                                                  p_assignments, np)
    805       # Do np separate lookups, finding embeddings for plist[p] in params[p]
    806       partitioned_result = []
    807       for p in xrange(np):
    808         with ops.colocate_with(params[p]):
    809           partitioned_result.append(_do_gather(params[p], gather_ids[p]))
    810 
    811       ignore_weights = weights is None
    812       if not ignore_weights:
    813         # Partition weights according to pindices.
    814         partitioned_weight = []
    815         for p in xrange(np):
    816           partitioned_weight.append(array_ops.gather(weights, pindices[p]))
    817       # Reshape each partition result.
    818       element_shape = params[0].get_shape()[1:]
    819       for p in params[1:]:
    820         element_shape = element_shape.merge_with(p.get_shape()[1:])
    821       if element_shape.is_fully_defined():
    822         for p in xrange(np):
    823           with ops.colocate_with(params[p]):
    824             partitioned_result[p] = array_ops.reshape(
    825                 partitioned_result[p],
    826                 array_ops.concat([array_ops.shape(pindices[p]), element_shape],
    827                                  0))
    828       else:
    829         with ops.colocate_with(params[0]):
    830           params_shape = array_ops.shape(params[0])
    831         for p in xrange(np):
    832           with ops.colocate_with(params[p]):
    833             partitioned_result[p] = array_ops.reshape(
    834                 partitioned_result[p],
    835                 array_ops.concat([
    836                     array_ops.shape(pindices[p]), array_ops.slice(
    837                         params_shape, [1], [-1])
    838                 ], 0))
    839       # Normalize each partition result.
    840       for p in xrange(np):
    841         with ops.colocate_with(params[p]):
    842           partitioned_result[p] = maybe_normalize(partitioned_result[p])
    843       if not ignore_weights:
    844         # Multiply each partition result with partition weights.
    845         for p in xrange(np):
    846           with ops.colocate_with(params[p]):
    847             if partitioned_weight[p].dtype != partitioned_result[p].dtype:
    848               partitioned_weight[p] = math_ops.cast(partitioned_weight[p],
    849                                                     partitioned_result[p].dtype)
    850             # Reshape partition weights.
    851             ones = array_ops.fill(
    852                 array_ops.expand_dims(
    853                     array_ops.rank(partitioned_result[p]) - 1, 0), 1)
    854             bcast_weights_shape = array_ops.concat(
    855                 [array_ops.shape(partitioned_weight[p]), ones], 0)
    856             orig_weights_shape = partitioned_weight[p].get_shape()
    857             partitioned_weight[p] = array_ops.reshape(partitioned_weight[p],
    858                                                       bcast_weights_shape)
    859             if partitioned_result[p].get_shape().ndims is not None:
    860               partitioned_weight[p].set_shape(
    861                   orig_weights_shape.concatenate([
    862                       1
    863                       for _ in range(partitioned_result[p].get_shape().ndims -
    864                                      1)
    865                   ]))
    866             partitioned_result[p] *= partitioned_weight[p]
    867       partitioned_segment_ids = []
    868       for p in xrange(np):
    869         if not ignore_weights:
    870           # Partition segment_ids according to pindices.
    871           p_segment_ids = array_ops.gather(segment_ids, pindices[p])
    872           # Number the p_segment_ids to meet segment_sum's requirements. Note
    873           # that unique_p_segment_ids contains unique segment ids of this
    874           # partition and these ids' order is unchanged.
    875           unique_p_segment_ids, unique_p_segment_idx = array_ops.unique(
    876               p_segment_ids)
    877           partitioned_segment_ids.append(unique_p_segment_ids)
    878           # segment_sum this partition's result.
    879           with ops.colocate_with(params[p]):
    880             partitioned_result[p] = math_ops.segment_sum(
    881                 partitioned_result[p], unique_p_segment_idx)
    882         else:
    883           # When ignoring weights, we need the indices of the elements of idx
    884           # and segment_ids that fall in this partition.
    885           _, exclude_idx = array_ops.setdiff1d(idx, pindices[p])
    886           all_idx = math_ops.range(array_ops.shape(idx)[0])
    887           _, include_idx = array_ops.setdiff1d(all_idx, exclude_idx)
    888           # Gather segment_ids and idx according to these indices.
    889           p_segment_ids = array_ops.gather(segment_ids, include_idx)
    890           p_idx = array_ops.gather(idx, include_idx)
    891           # Number the p_segment_ids, same as in the weighted case above.
    892           unique_p_segment_ids, unique_p_segment_idx = array_ops.unique(
    893               p_segment_ids)
    894           _, unique_p_idx_idx = array_ops.unique(p_idx)
    895           partitioned_segment_ids.append(unique_p_segment_ids)
    896           with ops.colocate_with(params[p]):
    897             partitioned_result[p] = math_ops.sparse_segment_sum(
    898                 partitioned_result[p], unique_p_idx_idx, unique_p_segment_idx)
    899       # Concat each partition's segment_ids and result for final segment_sum.
    900       concat_segment_ids = array_ops.concat(partitioned_segment_ids, 0)
    901       concat_partitioned_result = array_ops.concat(partitioned_result, 0)
    902       return math_ops.unsorted_segment_sum(
    903           concat_partitioned_result,
    904           concat_segment_ids,
    905           math_ops.reduce_max(concat_segment_ids) + 1,
    906           name=name)
    907