# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops related to candidate sampling."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import nn_ops


def _rank_resample(weights, biases, inputs, sampled_values, num_resampled,
                   resampling_temperature, partition_strategy):
  """A helper function for rank_sampled_softmax_loss.

  This computes, for each i in `sampled_values`,

      log(sum_j exp((w_i * x_j + b_i) / resampling_temperature))

  where w_i, b_i are the weight and bias of the i-th class, respectively,
  and j ranges over the rows of `inputs`. For efficiency, we rearrange the
  computation to

      log(sum_j exp(w_i * (x_j / resampling_temperature))) +
          b_i / resampling_temperature.

  This translates to the following batched computation using TensorFlow ops:

      reduce_logsumexp(matmul(embeddings,
                       transpose(inputs / resampling_temperature))) +
          biases / resampling_temperature

  The computation of the first term is colocated with the embeddings using
  `transform_fn` in `embedding_ops._embedding_lookup_and_transform`. The second
  term, not the bottleneck, is computed at the worker.

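  For illustration, ignoring weight partitioning, the same scores could be
  computed with dense ops as (a sketch, not the code path used below):

      scores = reduce_logsumexp(
          matmul(gather(weights, sampled_candidates),
                 inputs / resampling_temperature, transpose_b=True),
          axis=1) + gather(biases, sampled_candidates) / resampling_temperature
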
  Args:
    weights: From `rank_sampled_softmax_loss`.
    biases: From `rank_sampled_softmax_loss`.
    inputs: From `rank_sampled_softmax_loss`.
    sampled_values: A tuple of (`sampled_candidates`, `true_expected_count`,
        `sampled_expected_count`) returned by a `*_candidate_sampler` function.
    num_resampled: An `int`. This many values are selected from
        `sampled_values` using the adaptive resampling algorithm. The caller
        must ensure that `num_resampled` is less than the size of
        `sampled_values`.
    resampling_temperature: A scalar `Tensor` with the temperature parameter
        for the adaptive resampling algorithm.
    partition_strategy: From `rank_sampled_softmax_loss`.

  Returns:
    A tuple of (`resampled_candidates`, `true_expected_count`,
        `resampled_expected_count`), similar to `sampled_values` but sampled
        down to `num_resampled` values.
  """
  # This code supports passing a Tensor for num_resampled, but since it is only
  # called with an int, that's what we specify in the arg list. If this
  # function is ever externalized, we should change the doc to support Tensor.

  sampled, true_expected_count, sampled_expected_count = sampled_values

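  # The sampler outputs are constants with respect to training; stop_gradient
  # makes it explicit that no gradients flow back through them.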
  sampled = math_ops.cast(array_ops.stop_gradient(sampled), dtypes.int64)
  true_expected_count = array_ops.stop_gradient(true_expected_count)
  sampled_expected_count = array_ops.stop_gradient(sampled_expected_count)

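  # Divide the inputs by the temperature once, so the matmul inside
  # logsumexp_logit already includes the temperature scaling (see the
  # rearranged formula in the docstring above).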
  reweighted_inputs = inputs / resampling_temperature

  def logsumexp_logit(embeddings):
    return math_ops.reduce_logsumexp(
        math_ops.matmul(embeddings, reweighted_inputs, transpose_b=True),
        axis=1,
        keep_dims=False)

  # Calling this protected form of embedding_lookup allows co-locating
  # the logsumexp computation with the partitioned weights, which yields
  # a large speedup in practice.
  sampled_logits = embedding_ops._embedding_lookup_and_transform(  # pylint: disable=protected-access
      weights, sampled, partition_strategy, transform_fn=logsumexp_logit)
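  # sampled_logits has shape [num_sampled]: one adaptive score per sampled
  # candidate.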
  sampled_b = array_ops.reshape(
      embedding_ops.embedding_lookup(biases, sampled, partition_strategy), [-1])
  sampled_logits += sampled_b / resampling_temperature

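  # Keep the num_resampled candidates whose adaptive scores are highest; the
  # ordering within the kept set does not matter, so sorted=False.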
  _, resampled_indices = nn.top_k(sampled_logits, k=num_resampled, sorted=False)
  resampled = array_ops.gather(sampled, indices=resampled_indices)
  resampled_expected_count = array_ops.gather(
      sampled_expected_count, indices=resampled_indices)

  return resampled, true_expected_count, resampled_expected_count


def rank_sampled_softmax_loss(weights,
                              biases,
                              labels,
                              inputs,
                              num_sampled,
                              num_resampled,
                              num_classes,
                              num_true,
                              sampled_values,
                              resampling_temperature,
                              remove_accidental_hits,
                              partition_strategy,
                              name=None):
  """Computes softmax loss using rank-based adaptive resampling.

  This has been shown to improve rank loss after training compared to
  @{tf.nn.sampled_softmax_loss}. For a description of the algorithm and some
  experimental results, please see: [TAPAS: Two-pass Approximate Adaptive
  Sampling for Softmax](https://arxiv.org/abs/1707.03073).

  Sampling follows two phases:
  * In the first phase, `num_sampled` classes are selected using
    @{tf.nn.learned_unigram_candidate_sampler} or supplied `sampled_values`.
    The logits are calculated on those sampled classes. This phase is
    similar to @{tf.nn.sampled_softmax_loss}.
  * In the second phase, the `num_resampled` classes with highest predicted
    probability are kept. Probabilities are
    `LogSumExp(logits / resampling_temperature)`, where the sum is over
    `inputs`.

  The `resampling_temperature` parameter controls the "adaptiveness" of the
  resampling. At lower temperatures, resampling is more adaptive because it
  picks more candidates close to the predicted classes. A common strategy is
  to decrease the temperature as training proceeds.
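
  For example, one illustrative schedule (not prescribed by this op) halves
  the temperature every 10,000 training steps:

  ```python
  global_step = tf.train.get_or_create_global_step()
  resampling_temperature = tf.train.exponential_decay(
      1.0, global_step, decay_steps=10000, decay_rate=0.5)
  ```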

  See @{tf.nn.sampled_softmax_loss} for more documentation on sampling and
  for typical default values for some of the parameters.

  This operation is for training only. It is generally an underestimate of
  the full softmax loss.

  A common use case is to use this method for training, and calculate the full
  softmax loss for evaluation or inference. In this case, you must set
  `partition_strategy="div"` for the two losses to be consistent, as in the
  following example:

  ```python
  if mode == "train":
    loss = rank_sampled_softmax_loss(
        weights=weights,
        biases=biases,
        labels=labels,
        inputs=inputs,
        ...,
        partition_strategy="div")
  elif mode == "eval":
    logits = tf.matmul(inputs, tf.transpose(weights))
    logits = tf.nn.bias_add(logits, biases)
    labels_one_hot = tf.one_hot(labels, n_classes)
    loss = tf.nn.softmax_cross_entropy_with_logits(
        labels=labels_one_hot,
        logits=logits)
  ```

  Args:
    weights: A `Tensor` or `PartitionedVariable` of shape `[num_classes, dim]`,
        or a list of `Tensor` objects whose concatenation along dimension 0
        has shape [num_classes, dim]. The (possibly-sharded) class embeddings.
    biases: A `Tensor` or `PartitionedVariable` of shape `[num_classes]`.
        The (possibly-sharded) class biases.
    labels: A `Tensor` of type `int64` and shape `[batch_size,
        num_true]`. The target classes. Note that this format differs from
        the `labels` argument of `nn.softmax_cross_entropy_with_logits`.
    inputs: A `Tensor` of shape `[batch_size, dim]`. The forward
        activations of the input network.
    num_sampled: An `int`. The number of classes to randomly sample per batch.
    num_resampled: An `int`. The number of classes to select from the
        `num_sampled` classes using the adaptive resampling algorithm. Must be
        less than `num_sampled`.
    num_classes: An `int`. The number of possible classes.
    num_true: An `int`.  The number of target classes per training example.
    sampled_values: A tuple of (`sampled_candidates`, `true_expected_count`,
        `sampled_expected_count`) returned by a `*_candidate_sampler` function.
        If None, defaults to `nn.learned_unigram_candidate_sampler`.
    resampling_temperature: A scalar `Tensor` with the temperature parameter
        for the adaptive resampling algorithm.
    remove_accidental_hits: A `bool`. Whether to remove "accidental hits"
        where a sampled class equals one of the target classes.
    partition_strategy: A string specifying the partitioning strategy, relevant
        if `len(weights) > 1`. Currently `"div"` and `"mod"` are supported.
        See @{tf.nn.embedding_lookup} for more details.
    name: A name for the operation (optional).

  Returns:
    A `batch_size` 1-D tensor of per-example sampled softmax losses.

  Raises:
    ValueError: If `num_sampled <= num_resampled`, if
        `num_sampled > num_classes`, or if `partition_strategy` is neither
        `"div"` nor `"mod"`.
  """
  if num_sampled > num_classes:
    raise ValueError("num_sampled ({}) cannot be greater than num_classes ({})".
                     format(num_sampled, num_classes))
  if num_sampled <= num_resampled:
    raise ValueError("num_resampled ({}) must be less than num_sampled ({})".
                     format(num_resampled, num_sampled))
  if partition_strategy not in ("div", "mod"):
    raise ValueError(
        "unsupported partition_strategy ({})".format(partition_strategy))
  with ops.name_scope(name, "rank_sampled_softmax_loss", [
      weights, biases, labels, inputs, sampled_values, resampling_temperature
  ]) as name:
    if not sampled_values:
      sampled_values = nn.learned_unigram_candidate_sampler(
          true_classes=labels,
          num_true=num_true,
          num_sampled=num_sampled,
          unique=True,
          range_max=num_classes)
    # From sampled_values, select the top num_resampled values using the
    # adaptive rank resampling strategy.
    resampled_values = _rank_resample(weights, biases, inputs, sampled_values,
                                      num_resampled, resampling_temperature,
                                      partition_strategy)
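    # Compute the standard sampled softmax loss over only the resampled
    # candidates.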
    return nn.sampled_softmax_loss(
        weights=weights,
        biases=biases,
        labels=labels,
        inputs=inputs,
        num_sampled=num_resampled,
        num_classes=num_classes,
        num_true=num_true,
        sampled_values=resampled_values,
        remove_accidental_hits=remove_accidental_hits,
        partition_strategy=partition_strategy,
        name=name)


def sampled_sparse_softmax_loss(weights,
                                biases,
                                labels,
                                inputs,
                                num_sampled,
                                num_classes,
                                sampled_values=None,
                                remove_accidental_hits=True,
                                partition_strategy="mod",
                                name="sampled_sparse_softmax_loss"):
  """Computes and returns the sampled sparse softmax training loss.

  This is a faster way to train a softmax classifier over a huge number of
  classes.

  This operation is for training only.  It is generally an underestimate of
  the full softmax loss.

  A common use case is to use this method for training, and calculate the full
  softmax loss for evaluation or inference. In this case, you must set
  `partition_strategy="div"` for the two losses to be consistent, as in the
  following example:

  ```python
  if mode == "train":
    loss = tf.nn.sampled_sparse_softmax_loss(
        weights=weights,
        biases=biases,
        labels=labels,
        inputs=inputs,
        ...,
        partition_strategy="div")
  elif mode == "eval":
    logits = tf.matmul(inputs, tf.transpose(weights))
    logits = tf.nn.bias_add(logits, biases)
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=tf.squeeze(labels),
        logits=logits)
  ```

  See our [Candidate Sampling Algorithms Reference]
  (https://www.tensorflow.org/extras/candidate_sampling.pdf)

  Also see Section 3 of [Jean et al., 2014](http://arxiv.org/abs/1412.2007)
  ([pdf](http://arxiv.org/pdf/1412.2007.pdf)) for the math.

  Args:
    weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor`
        objects whose concatenation along dimension 0 has shape
        [num_classes, dim].  The (possibly-sharded) class embeddings.
    biases: A `Tensor` of shape `[num_classes]`.  The class biases.
    labels: A `Tensor` of type `int64` and shape `[batch_size, 1]`.
        The index of the single target class for each row of logits.  Note that
        this format differs from the `labels` argument of
        `nn.sparse_softmax_cross_entropy_with_logits`.
    inputs: A `Tensor` of shape `[batch_size, dim]`.  The forward
        activations of the input network.
    num_sampled: An `int`.  The number of classes to randomly sample per batch.
    num_classes: An `int`. The number of possible classes.
    sampled_values: A tuple of (`sampled_candidates`, `true_expected_count`,
        `sampled_expected_count`) returned by a `*_candidate_sampler` function.
        If None, defaults to `log_uniform_candidate_sampler`.
    remove_accidental_hits: A `bool`. Whether to remove "accidental hits"
        where a sampled class equals one of the target classes. Default is
        True.
    partition_strategy: A string specifying the partitioning strategy, relevant
        if `len(weights) > 1`. Currently `"div"` and `"mod"` are supported.
        Default is `"mod"`. See `tf.nn.embedding_lookup` for more details.
    name: A name for the operation (optional).

  Returns:
    A `batch_size` 1-D tensor of per-example sampled softmax losses.

  """
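  # _compute_sampled_logits returns logits of shape
  # [batch_size, num_true + num_sampled] with the true logits first; here
  # num_true is 1.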
  logits, _ = nn_impl._compute_sampled_logits(
      weights=weights,
      biases=biases,
      labels=labels,
      inputs=inputs,
      num_sampled=num_sampled,
      num_classes=num_classes,
      num_true=1,
      sampled_values=sampled_values,
      subtract_log_q=True,
      remove_accidental_hits=remove_accidental_hits,
      partition_strategy=partition_strategy,
      name=name)

  # There is only one true label. _compute_sampled_logits puts the true logit
  # at index 0.
  labels = array_ops.zeros([array_ops.shape(logits)[0], 1], dtype=dtypes.int64)

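  # sparse_softmax_cross_entropy_with_logits expects labels of shape
  # [batch_size], so squeeze out the trailing dimension of size 1.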
  sampled_losses = nn_ops.sparse_softmax_cross_entropy_with_logits(
      labels=array_ops.squeeze(labels), logits=logits)
  # sampled_losses is a [batch_size] tensor.
  return sampled_losses