# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Operations for embeddings."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
# Imports gradient definitions.
from tensorflow.python.ops import data_flow_grad  # pylint: disable=unused-import
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import tf_export


def _gather(params, ids, name=None):
  """Helper function for _embedding_lookup_and_transform.

  This function gathers embeddings from a single tensor. The gather deals with
  resource variables specially.

  Args:
    params: A `Tensor` of embeddings.
    ids: A `Tensor` indexing the embeddings to be retrieved from `params`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` with the same type as `params`.
  """
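  # For resource variables, `sparse_read` gathers the requested rows directly
  # from the variable, avoiding a read of its full dense value followed by a
  # gather.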
  if isinstance(params, resource_variable_ops.ResourceVariable):
    return params.sparse_read(ids, name=name)
  else:
    return array_ops.gather(params, ids, name=name)


def _clip(params, ids, max_norm):
  """Helper function for _embedding_lookup_and_transform.

  This function optionally clips embeddings to an l2-norm of max_norm.

  Args:
    params: A `Tensor` of embeddings retrieved by `_gather`.
    ids: The `ids` argument that was passed to `_gather`.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is
      larger than this value.

  Returns:
    A `Tensor` with the same type as `params`.
  """
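  # Illustrative sketch (hypothetical shapes): if `params` was gathered for a
  # 1-D `ids` and has shape [num_ids, dim], then
  #   _clip(params, ids, max_norm=1.0)
  # caps the l2 norm of each [dim]-sized row at 1.0. The norm is taken over
  # the trailing axes of `params` beyond the rank of `ids`.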

  def _rank(x):
    """Helper function to retrieve the rank of a tensor.

    Args:
      x: Something convertible to `Tensor`.

    Returns:
      Either a pair `(rank, True)` where `rank` is an integer or a pair
      `(rank, False)` where `rank` is an integer `Tensor`. In either case,
      `rank` is the rank of `x`.
    """
    rank = ops.convert_to_tensor(x).get_shape().ndims
    if rank:
      return rank, True
    else:
      return array_ops.rank(x), False

  if max_norm is None:
    return params
  ids_rank, ids_static = _rank(ids)
  params_rank, params_static = _rank(params)
  return clip_ops.clip_by_norm(
      params,
      max_norm,
      axes=(list(range(ids_rank, params_rank))
            if ids_static and params_static
            else math_ops.range(ids_rank, params_rank)))


def _embedding_lookup_and_transform(params,
                                    ids,
                                    partition_strategy="mod",
                                    name=None,
                                    max_norm=None,
                                    transform_fn=None):
  """Helper function for embedding_lookup and _compute_sampled_logits.

  This function is a generalization of embedding_lookup that optionally
  applies a caller-specified transformation to each embedding. This is
  done through the `transform_fn` argument. If provided, the function is
  applied to each partitioned tensor of retrieved embeddings, colocated
  with the embeddings. This function will be called with a single `Tensor`
  argument of the same type as the `params` tensor and should return a
  `Tensor`. The shape of the argument will be the same as `params` except
  for the size of the first dimension. The first dimension of the result's
  shape must be the same size as the argument's.

  Args:
    params: See embedding_lookup.
    ids: See embedding_lookup.
    partition_strategy: See embedding_lookup.
    name: See embedding_lookup.
    max_norm: See embedding_lookup.
    transform_fn: An optional function to apply to each retrieved embedding.
      If max_norm is provided, transform_fn is applied to the norm-limited
      embeddings.

  Returns:
    See embedding_lookup for details.
  Raises:
    ValueError: If `params` is empty.
  """
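  # A hypothetical example of a valid transform_fn (illustrative only): it
  # receives an [n, dim] slice of gathered embeddings and must return a tensor
  # whose first dimension is still n, e.g.
  #   transform_fn = lambda embs: math_ops.tanh(embs)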
  if params is None or params in ((), []):
    raise ValueError("Need at least one param")
  if isinstance(params, variables.PartitionedVariable):
    params = list(params)  # Iterate to get the underlying Variables.
  if not isinstance(params, list):
    params = [params]

  with ops.name_scope(name, "embedding_lookup", params + [ids]) as name:
    np = len(params)  # Number of partitions
    # Preserve the resource variable status to avoid accidental dense reads.
    if not any(
        isinstance(p, resource_variable_ops.ResourceVariable) for p in params):
      params = ops.convert_n_to_tensor_or_indexed_slices(params, name="params")
    ids = ops.convert_to_tensor(ids, name="ids")
    if np == 1 and (not transform_fn or ids.get_shape().ndims == 1):
      with ops.colocate_with(params[0]):
        result = _clip(_gather(params[0], ids, name=name), ids, max_norm)
        if transform_fn:
          result = transform_fn(result)
        return result
    else:
      # Flatten the ids. There are two cases where we need to do this.
      # - There is more than one params tensor.
      # - There is a transform_fn and ids is not statically known to be 1-D.
      #   We must flatten in this case because transform_fn expects a flat
      #   tensor of embeddings.
      flat_ids = array_ops.reshape(ids, [-1])
      original_indices = math_ops.range(array_ops.size(flat_ids))

      # Create p_assignments and set new_ids depending on the strategy.
      if partition_strategy == "mod":
        p_assignments = flat_ids % np
        new_ids = flat_ids // np
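        # Illustrative: with np = 5, id 12 goes to partition 12 % 5 = 2 and
        # becomes row 12 // 5 = 2 within that partition, matching the "mod"
        # example in the embedding_lookup docstring.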
      elif partition_strategy == "div":
        # Compute num_total_ids as the sum of dim-0 of params, then assign to
        # partitions based on a constant number of ids per partition. Optimize
        # if we already know the full shape statically.
        dim_0_size = params[0].get_shape()[0]
        for p in xrange(1, np):
          dim_0_size += params[p].get_shape()[0]
        if dim_0_size.value:
          num_total_ids = constant_op.constant(dim_0_size.value, flat_ids.dtype)
        else:
          dim_0_sizes = []
          for p in xrange(np):
            if params[p].get_shape()[0].value is not None:
              dim_0_sizes.append(params[p].get_shape()[0].value)
            else:
              with ops.colocate_with(params[p]):
                dim_0_sizes.append(array_ops.shape(params[p])[0])
          num_total_ids = math_ops.reduce_sum(
              math_ops.cast(array_ops.stack(dim_0_sizes), flat_ids.dtype))
        ids_per_partition = num_total_ids // np
        extras = num_total_ids % np

        p_assignments = math_ops.maximum(
            flat_ids // (ids_per_partition + 1),
            (flat_ids - extras) // ids_per_partition)

        # Emulate a conditional using a boolean indicator tensor
        new_ids = array_ops.where(p_assignments < extras,
                                  flat_ids % (ids_per_partition + 1),
                                  (flat_ids - extras) % ids_per_partition)
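        # Illustrative: with num_total_ids = 13 and np = 5, ids_per_partition
        # is 2 and extras is 3, so partitions 0..2 receive 3 ids each and
        # partitions 3..4 receive 2 ids each, matching the "div" example in
        # the embedding_lookup docstring.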
      else:
        raise ValueError("Unrecognized partition strategy: " +
                         partition_strategy)

      # Cast partition assignments to int32 for use in dynamic_partition.
      # There really should not be more than 2^32 partitions.
      p_assignments = math_ops.cast(p_assignments, dtypes.int32)
      # Partition list of ids based on assignments into np separate lists
      gather_ids = data_flow_ops.dynamic_partition(new_ids, p_assignments, np)
      # Similarly, partition the original indices.
      pindices = data_flow_ops.dynamic_partition(original_indices,
                                                 p_assignments, np)
      # Do np separate lookups, finding embeddings for plist[p] in params[p]
      partitioned_result = []
      for p in xrange(np):
        pids = gather_ids[p]
        with ops.colocate_with(params[p]):
          result = _gather(params[p], pids)
          if transform_fn:
            # If transform_fn is provided, the clip_by_norm precedes
            # the transform and hence must be co-located. See below
            # for the counterpart if transform_fn is not provided.
            result = transform_fn(_clip(result, pids, max_norm))
        partitioned_result.append(result)
      # Stitch these back together
      ret = data_flow_ops.parallel_dynamic_stitch(
          pindices, partitioned_result, name=name)

      # Determine the static element shape.
      if transform_fn is None:
        element_shape_s = params[0].get_shape()[1:]
        for p in params[1:]:
          element_shape_s = element_shape_s.merge_with(p.get_shape()[1:])
      else:
        element_shape_s = ret.get_shape()[1:]

      # Compute the dynamic element shape.
      if element_shape_s.is_fully_defined():
        element_shape_d = element_shape_s
      elif transform_fn is None:
        # It's important that we compute params[0].shape on the right device
        # to avoid data motion.
        with ops.colocate_with(params[0]):
          params_shape = array_ops.shape(params[0])
        element_shape_d = params_shape[1:]
      else:
        element_shape_d = array_ops.shape(ret)[1:]

      # Reshape to reverse the flattening of ids.
      ret = array_ops.reshape(ret,
                              array_ops.concat(
                                  [array_ops.shape(ids), element_shape_d], 0))
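      # Illustrative shapes: if ids has shape [batch, k] and each embedding
      # row has shape [dim], this reshapes ret from [batch * k, dim] back to
      # [batch, k, dim].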

      # Normally the reshape is sufficient, but setting shape explicitly
      # teaches shape inference that params[1:].get_shape() matters
      # (in the case that transform_fn is None).
      ret.set_shape(ids.get_shape().concatenate(element_shape_s))
      if not transform_fn:
        # If transform_fn was provided, the clip_by_norm was done above.
        ret = _clip(ret, ids, max_norm)
      return ret


@tf_export("nn.embedding_lookup")
def embedding_lookup(
    params,
    ids,
    partition_strategy="mod",
    name=None,
    validate_indices=True,  # pylint: disable=unused-argument
    max_norm=None):
  """Looks up `ids` in a list of embedding tensors.

  This function is used to perform parallel lookups on the list of
  tensors in `params`.  It is a generalization of
  @{tf.gather}, where `params` is
  interpreted as a partitioning of a large embedding tensor.  `params` may be
  a `PartitionedVariable` as returned by using `tf.get_variable()` with a
  partitioner.

  If `len(params) > 1`, each element `id` of `ids` is partitioned between
  the elements of `params` according to the `partition_strategy`.
  In all strategies, if the id space does not evenly divide the number of
  partitions, each of the first `(max_id + 1) % len(params)` partitions will
  be assigned one more id.

  If `partition_strategy` is `"mod"`, we assign each id to partition
  `p = id % len(params)`. For instance,
  13 ids are split across 5 partitions as:
  `[[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]`

  If `partition_strategy` is `"div"`, we assign ids to partitions in a
  contiguous manner. In this case, 13 ids are split across 5 partitions as:
  `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`

  The results of the lookup are concatenated into a dense
  tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.

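  For example, a minimal usage sketch (the variable names here are
  illustrative, not part of the API):

    embedding_matrix = tf.get_variable("embeddings", [vocab_size, embed_dim])
    ids = tf.constant([0, 3, 7], dtype=tf.int64)
    vectors = tf.nn.embedding_lookup(embedding_matrix, ids)
    # vectors has shape [3, embed_dim]
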
  Args:
    params: A single tensor representing the complete embedding tensor,
      or a list of P tensors all of same shape except for the first dimension,
      representing sharded embedding tensors.  Alternatively, a
      `PartitionedVariable`, created by partitioning along dimension 0. Each
      element must be appropriately sized for the given `partition_strategy`.
    ids: A `Tensor` with type `int32` or `int64` containing the ids to be looked
      up in `params`.
    partition_strategy: A string specifying the partitioning strategy, relevant
      if `len(params) > 1`. Currently `"div"` and `"mod"` are supported. Default
      is `"mod"`.
    name: A name for the operation (optional).
    validate_indices: DEPRECATED. If this operation is assigned to CPU, values
      in `indices` are always validated to be within range.  If assigned to GPU,
      out-of-bound indices result in safe but unspecified behavior, which may
      include raising an error.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
      than this value.

  Returns:
    A `Tensor` with the same type as the tensors in `params`.

  Raises:
    ValueError: If `params` is empty.
  """
  return _embedding_lookup_and_transform(
      params=params,
      ids=ids,
      partition_strategy=partition_strategy,
      name=name,
      max_norm=max_norm,
      transform_fn=None)


@tf_export("nn.embedding_lookup_sparse")
def embedding_lookup_sparse(params,
                            sp_ids,
                            sp_weights,
                            partition_strategy="mod",
                            name=None,
                            combiner=None,
                            max_norm=None):
  """Computes embeddings for the given ids and weights.

  This op assumes that there is at least one id for each row in the dense tensor
  represented by sp_ids (i.e. there are no rows with empty features), and that
  all the indices of sp_ids are in canonical row-major order.

  It also assumes that all id values lie in the range [0, p0), where p0
  is the sum of the size of params along dimension 0.

  Args:
    params: A single tensor representing the complete embedding tensor,
      or a list of P tensors all of same shape except for the first dimension,
      representing sharded embedding tensors.  Alternatively, a
      `PartitionedVariable`, created by partitioning along dimension 0. Each
      element must be appropriately sized for the given `partition_strategy`.
    sp_ids: N x M SparseTensor of int64 ids (typically from FeatureValueToId),
      where N is typically batch size and M is arbitrary.
    sp_weights: either a SparseTensor of float / double weights, or None to
      indicate all weights should be taken to be 1. If specified, sp_weights
      must have exactly the same shape and indices as sp_ids.
    partition_strategy: A string specifying the partitioning strategy, relevant
      if `len(params) > 1`. Currently `"div"` and `"mod"` are supported. Default
      is `"mod"`. See `tf.nn.embedding_lookup` for more details.
    name: Optional name for the op.
    combiner: A string specifying the reduction op. Currently "mean", "sqrtn"
      and "sum" are supported.
      "sum" computes the weighted sum of the embedding results for each row.
      "mean" is the weighted sum divided by the total weight.
      "sqrtn" is the weighted sum divided by the square root of the sum of the
      squares of the weights.
    max_norm: If not `None`, each embedding is clipped if its l2 norm is larger
      than this value, before combining.

  Returns:
    A dense tensor representing the combined embeddings for the
    sparse ids. For each row in the dense tensor represented by sp_ids, the op
    looks up the embeddings for all ids in that row, multiplies them by the
    corresponding weight, and combines these embeddings as specified.

    In other words, if

      shape(combined params) = [p0, p1, ..., pm]

    and

      shape(sp_ids) = shape(sp_weights) = [d0, d1, ..., dn]

    then

      shape(output) = [d0, d1, ..., dn-1, p1, ..., pm].

    For instance, if params is a 10x20 matrix, and sp_ids / sp_weights are

      [0, 0]: id 1, weight 2.0
      [0, 1]: id 3, weight 0.5
      [1, 0]: id 0, weight 1.0
      [2, 3]: id 1, weight 3.0

    with `combiner`="mean", then the output will be a 3x20 matrix where

      output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
      output[1, :] = params[0, :] * 1.0
      output[2, :] = params[1, :] * 3.0

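    The inputs for that example could be built as follows (an illustrative
    sketch; the names here are hypothetical):

      params = tf.get_variable("embeddings", [10, 20])
      sp_ids = tf.SparseTensor(
          indices=[[0, 0], [0, 1], [1, 0], [2, 3]],
          values=tf.constant([1, 3, 0, 1], dtype=tf.int64),
          dense_shape=[3, 4])
      sp_weights = tf.SparseTensor(
          indices=sp_ids.indices,
          values=[2.0, 0.5, 1.0, 3.0],
          dense_shape=sp_ids.dense_shape)
      output = tf.nn.embedding_lookup_sparse(
          params, sp_ids, sp_weights, combiner="mean")
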
  Raises:
    TypeError: If sp_ids is not a SparseTensor, or if sp_weights is neither
      None nor SparseTensor.
    ValueError: If combiner is not one of {"mean", "sqrtn", "sum"}.
  """
  if combiner is None:
    logging.warn("The default value of combiner will change from \"mean\" "
                 "to \"sqrtn\" after 2016/11/01.")
    combiner = "mean"
  if combiner not in ("mean", "sqrtn", "sum"):
    raise ValueError("combiner must be one of 'mean', 'sqrtn' or 'sum'")
  if isinstance(params, variables.PartitionedVariable):
    params = list(params)  # Iterate to get the underlying Variables.
  if not isinstance(params, list):
    params = [params]
  if not isinstance(sp_ids, sparse_tensor.SparseTensor):
    raise TypeError("sp_ids must be SparseTensor")
  ignore_weights = sp_weights is None
  if not ignore_weights:
    if not isinstance(sp_weights, sparse_tensor.SparseTensor):
      raise TypeError("sp_weights must be either None or SparseTensor")
    sp_ids.values.get_shape().assert_is_compatible_with(
        sp_weights.values.get_shape())
    sp_ids.indices.get_shape().assert_is_compatible_with(
        sp_weights.indices.get_shape())
    sp_ids.dense_shape.get_shape().assert_is_compatible_with(
        sp_weights.dense_shape.get_shape())
    # TODO(yleon): Add enhanced node assertions to verify that sp_ids and
    # sp_weights have equal indices and shapes.

  with ops.name_scope(name, "embedding_lookup_sparse",
                      params + [sp_ids]) as name:
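    # The first column of the SparseTensor indices identifies the dense row
    # (i.e. the segment) that each sparse id belongs to.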
    segment_ids = sp_ids.indices[:, 0]
    if segment_ids.dtype != dtypes.int32:
      segment_ids = math_ops.cast(segment_ids, dtypes.int32)

    ids = sp_ids.values
    if ignore_weights:
      ids, idx = array_ops.unique(ids)
    else:
      idx = None

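    # When weights are ignored, each distinct id is looked up only once; the
    # sparse_segment_* ops below use `idx` to map every original entry back to
    # its deduplicated embedding row.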
    embeddings = embedding_lookup(
        params, ids, partition_strategy=partition_strategy, max_norm=max_norm)
    if not ignore_weights:
      weights = sp_weights.values
      if weights.dtype != embeddings.dtype:
        weights = math_ops.cast(weights, embeddings.dtype)

      # Reshape weights to allow broadcast
      ones = array_ops.fill(
          array_ops.expand_dims(array_ops.rank(embeddings) - 1, 0), 1)
      bcast_weights_shape = array_ops.concat([array_ops.shape(weights), ones],
                                             0)

      orig_weights_shape = weights.get_shape()
      weights = array_ops.reshape(weights, bcast_weights_shape)

      # Set the weight shape, since after reshaping to bcast_weights_shape,
      # the shape becomes None.
      if embeddings.get_shape().ndims is not None:
        weights.set_shape(
            orig_weights_shape.concatenate(
                [1 for _ in range(embeddings.get_shape().ndims - 1)]))

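      # Illustrative shapes: with 2-D embeddings, weights of shape [nnz] are
      # reshaped to [nnz, 1] above so they broadcast against embeddings of
      # shape [nnz, embed_dim] in the multiply below.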
      embeddings *= weights

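      # Per the docstring: "sum" computes sum_i(w_i * e_i) for each segment,
      # "mean" divides that by sum_i(w_i), and "sqrtn" divides it by
      # sqrt(sum_i(w_i ** 2)).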
      if combiner == "sum":
        embeddings = math_ops.segment_sum(embeddings, segment_ids, name=name)
      elif combiner == "mean":
        embeddings = math_ops.segment_sum(embeddings, segment_ids)
        weight_sum = math_ops.segment_sum(weights, segment_ids)
        embeddings = math_ops.div(embeddings, weight_sum, name=name)
      elif combiner == "sqrtn":
        embeddings = math_ops.segment_sum(embeddings, segment_ids)
        weights_squared = math_ops.pow(weights, 2)
        weight_sum = math_ops.segment_sum(weights_squared, segment_ids)
        weight_sum_sqrt = math_ops.sqrt(weight_sum)
        embeddings = math_ops.div(embeddings, weight_sum_sqrt, name=name)
      else:
        assert False, "Unrecognized combiner"
    else:
      assert idx is not None
      if combiner == "sum":
        embeddings = math_ops.sparse_segment_sum(
            embeddings, idx, segment_ids, name=name)
      elif combiner == "mean":
        embeddings = math_ops.sparse_segment_mean(
            embeddings, idx, segment_ids, name=name)
      elif combiner == "sqrtn":
        embeddings = math_ops.sparse_segment_sqrt_n(
            embeddings, idx, segment_ids, name=name)
      else:
        assert False, "Unrecognized combiner"

    return embeddings
    498