      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Loss functions to be used by LayerCollection."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import abc
     22 
     23 import six
     24 
     25 from tensorflow.contrib.distributions.python.ops import onehot_categorical
     26 from tensorflow.python.framework import tensor_shape
     27 from tensorflow.python.ops import array_ops
     28 from tensorflow.python.ops import math_ops
     29 from tensorflow.python.ops.distributions import bernoulli
     30 from tensorflow.python.ops.distributions import categorical
     31 from tensorflow.python.ops.distributions import normal
     32 
     33 
     34 @six.add_metaclass(abc.ABCMeta)
     35 class LossFunction(object):
     36   """Abstract base class for loss functions.
     37 
     38   Note that unlike typical loss functions used in neural networks these are
     39   summed and not averaged across cases in the batch, since this is what the
     40   users of this class (FisherEstimator and MatrixVectorProductComputer) will
  be expecting. The implication of this is that you may want to
     42   normalize things like Fisher-vector products by the batch size when you
     43   use this class.  It depends on the use case.
     44   """
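  # Editorial note (a sketch, not part of the original module): because these
  # losses are summed rather than averaged over the batch, a user computing an
  # average Fisher- or Hessian-vector product would divide by the batch size,
  # e.g. avg_hvp = loss.multiply_hessian(vector) / batch_size, where `loss`,
  # `vector` and `batch_size` are hypothetical user-side names.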
     45 
     46   @abc.abstractproperty
     47   def targets(self):
     48     """The targets being predicted by the model.
     49 
     50     Returns:
     51       None or Tensor of appropriate shape for calling self._evaluate() on.
     52     """
     53     pass
     54 
     55   @abc.abstractproperty
     56   def inputs(self):
     57     """The inputs to the loss function (excluding the targets)."""
     58     pass
     59 
     60   @property
     61   def input_minibatches(self):
     62     """A `list` of inputs to the loss function, separated by minibatch.
     63 
     64     Typically there will be one minibatch per tower in a multi-tower setup.
    Returns a list consisting of `self.inputs` by default; `LossFunction`s
    that support registering multiple minibatches should override this method.
     67 
     68     Returns:
      A `list` of `Tensor`s containing the inputs to the loss function, one
      entry per registered minibatch.
     70     """
     71     return [self.inputs]
     72 
     73   @property
     74   def num_registered_minibatches(self):
     75     """Number of minibatches registered for this LossFunction.
     76 
     77     Typically equal to the number of towers in a multi-tower setup.
     78 
     79     Returns:
     80       An `int` representing the number of registered minibatches.
     81     """
     82     return len(self.input_minibatches)
     83 
     84   def evaluate(self):
     85     """Evaluate the loss function on the targets."""
     86     if self.targets is not None:
     87       # We treat the targets as "constant".  It's only the inputs that get
     88       # "back-propped" through.
     89       return self._evaluate(array_ops.stop_gradient(self.targets))
     90     else:
     91       raise Exception("Cannot evaluate losses with unspecified targets.")
     92 
     93   @abc.abstractmethod
     94   def _evaluate(self, targets):
     95     """Evaluates the negative log probability of the targets.
     96 
     97     Args:
     98       targets: Tensor that distribution can calculate log_prob() of.
     99 
    100     Returns:
    101       negative log probability of each target, summed across all targets.
    102     """
    103     pass
    104 
    105   @abc.abstractmethod
    106   def multiply_hessian(self, vector):
    107     """Right-multiply a vector by the Hessian.
    108 
    109     Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives)
    110     of the loss function with respect to its inputs.
    111 
    112     Args:
    113       vector: The vector to multiply.  Must be the same shape(s) as the
    114         'inputs' property.
    115 
    116     Returns:
    117       The vector right-multiplied by the Hessian.  Will be of the same shape(s)
    118       as the 'inputs' property.
    119     """
    120     pass
    121 
    122   @abc.abstractmethod
    123   def multiply_hessian_factor(self, vector):
    124     """Right-multiply a vector by a factor B of the Hessian.
    125 
    126     Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives)
    127     of the loss function with respect to its inputs.  Typically this will be
    128     block-diagonal across different cases in the batch, since the loss function
    129     is typically summed across cases.
    130 
    131     Note that B can be any matrix satisfying B * B^T = H where H is the Hessian,
    132     but will agree with the one used in the other methods of this class.
    133 
    134     Args:
    135       vector: The vector to multiply.  Must be of the shape given by the
    136         'hessian_factor_inner_shape' property.
    137 
    138     Returns:
    139       The vector right-multiplied by B.  Will be of the same shape(s) as the
    140       'inputs' property.
    141     """
    142     pass
    143 
    144   @abc.abstractmethod
    145   def multiply_hessian_factor_transpose(self, vector):
    146     """Right-multiply a vector by the transpose of a factor B of the Hessian.
    147 
    148     Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives)
    149     of the loss function with respect to its inputs.  Typically this will be
    150     block-diagonal across different cases in the batch, since the loss function
    151     is typically summed across cases.
    152 
    153     Note that B can be any matrix satisfying B * B^T = H where H is the Hessian,
    154     but will agree with the one used in the other methods of this class.
    155 
    156     Args:
    157       vector: The vector to multiply.  Must be the same shape(s) as the
    158         'inputs' property.
    159 
    160     Returns:
    161       The vector right-multiplied by B^T.  Will be of the shape given by the
    162       'hessian_factor_inner_shape' property.
    163     """
    164     pass
    165 
    166   @abc.abstractmethod
    167   def multiply_hessian_factor_replicated_one_hot(self, index):
    168     """Right-multiply a replicated-one-hot vector by a factor B of the Hessian.
    169 
    170     Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives)
    171     of the loss function with respect to its inputs.  Typically this will be
    172     block-diagonal across different cases in the batch, since the loss function
    173     is typically summed across cases.
    174 
    175     A 'replicated-one-hot' vector means a tensor which, for each slice along the
    176     batch dimension (assumed to be dimension 0), is 1.0 in the entry
    177     corresponding to the given index and 0 elsewhere.
    178 
    179     Note that B can be any matrix satisfying B * B^T = H where H is the Hessian,
    180     but will agree with the one used in the other methods of this class.
    181 
    182     Args:
      index: A tuple representing the index of the entry in each slice that
        is 1.0. Note that len(index) must be equal to the number of elements
        of the 'hessian_factor_inner_shape' tensor minus one.
    186 
    187     Returns:
      The vector right-multiplied by B. Will be of the same shape(s) as the
    189       'inputs' property.
    190     """
    191     pass
    192 
    193   @abc.abstractproperty
    194   def hessian_factor_inner_shape(self):
    195     """The shape of the tensor returned by multiply_hessian_factor."""
    196     pass
    197 
    198   @abc.abstractproperty
    199   def hessian_factor_inner_static_shape(self):
    200     """Static version of hessian_factor_inner_shape."""
    201     pass
    202 
    203 
    204 @six.add_metaclass(abc.ABCMeta)
    205 class NegativeLogProbLoss(LossFunction):
    206   """Abstract base class for loss functions that are negative log probs."""
    207 
    208   def __init__(self, seed=None):
    209     self._default_seed = seed
    210     super(NegativeLogProbLoss, self).__init__()
    211 
    212   @property
    213   def inputs(self):
    214     return self.params
    215 
    216   @abc.abstractproperty
    217   def params(self):
    218     """Parameters to the underlying distribution."""
    219     pass
    220 
    221   @abc.abstractmethod
    222   def multiply_fisher(self, vector):
    223     """Right-multiply a vector by the Fisher.
    224 
    225     Args:
    226       vector: The vector to multiply.  Must be the same shape(s) as the
    227         'inputs' property.
    228 
    229     Returns:
    230       The vector right-multiplied by the Fisher.  Will be of the same shape(s)
    231       as the 'inputs' property.
    232     """
    233     pass
    234 
    235   @abc.abstractmethod
    236   def multiply_fisher_factor(self, vector):
    237     """Right-multiply a vector by a factor B of the Fisher.
    238 
    239     Here the 'Fisher' is the Fisher information matrix (i.e. expected outer-
    240     product of gradients) with respect to the parameters of the underlying
    probability distribution (whose log-prob defines the loss). Typically this
    242     will be block-diagonal across different cases in the batch, since the
    243     distribution is usually (but not always) conditionally iid across different
    244     cases.
    245 
    246     Note that B can be any matrix satisfying B * B^T = F where F is the Fisher,
    247     but will agree with the one used in the other methods of this class.
    248 
    249     Args:
    250       vector: The vector to multiply.  Must be of the shape given by the
    251         'fisher_factor_inner_shape' property.
    252 
    253     Returns:
    254       The vector right-multiplied by B. Will be of the same shape(s) as the
    255       'inputs' property.
    256     """
    257     pass
    258 
    259   @abc.abstractmethod
    260   def multiply_fisher_factor_transpose(self, vector):
    261     """Right-multiply a vector by the transpose of a factor B of the Fisher.
    262 
    263     Here the 'Fisher' is the Fisher information matrix (i.e. expected outer-
    264     product of gradients) with respect to the parameters of the underlying
    probability distribution (whose log-prob defines the loss). Typically this
    266     will be block-diagonal across different cases in the batch, since the
    267     distribution is usually (but not always) conditionally iid across different
    268     cases.
    269 
    270     Note that B can be any matrix satisfying B * B^T = F where F is the Fisher,
    271     but will agree with the one used in the other methods of this class.
    272 
    273     Args:
    274       vector: The vector to multiply.  Must be the same shape(s) as the
    275         'inputs' property.
    276 
    277     Returns:
    278       The vector right-multiplied by B^T.  Will be of the shape given by the
    279       'fisher_factor_inner_shape' property.
    280     """
    281     pass
    282 
    283   @abc.abstractmethod
    284   def multiply_fisher_factor_replicated_one_hot(self, index):
    285     """Right-multiply a replicated-one-hot vector by a factor B of the Fisher.
    286 
    287     Here the 'Fisher' is the Fisher information matrix (i.e. expected outer-
    288     product of gradients) with respect to the parameters of the underlying
    probability distribution (whose log-prob defines the loss). Typically this
    290     will be block-diagonal across different cases in the batch, since the
    291     distribution is usually (but not always) conditionally iid across different
    292     cases.
    293 
    294     A 'replicated-one-hot' vector means a tensor which, for each slice along the
    295     batch dimension (assumed to be dimension 0), is 1.0 in the entry
    296     corresponding to the given index and 0 elsewhere.
    297 
    Note that B can be any matrix satisfying B * B^T = F where F is the Fisher,
    299     but will agree with the one used in the other methods of this class.
    300 
    301     Args:
      index: A tuple representing the index of the entry in each slice that
        is 1.0. Note that len(index) must be equal to the number of elements
        of the 'fisher_factor_inner_shape' tensor minus one.
    305 
    306     Returns:
    307       The vector right-multiplied by B. Will be of the same shape(s) as the
    308       'inputs' property.
    309     """
    310     pass
    311 
    312   @abc.abstractproperty
    313   def fisher_factor_inner_shape(self):
    314     """The shape of the tensor returned by multiply_fisher_factor."""
    315     pass
    316 
    317   @abc.abstractproperty
    318   def fisher_factor_inner_static_shape(self):
    319     """Static version of fisher_factor_inner_shape."""
    320     pass
    321 
    322   @abc.abstractmethod
    323   def sample(self, seed):
    324     """Sample 'targets' from the underlying distribution."""
    325     pass
    326 
    327   def evaluate_on_sample(self, seed=None):
    """Evaluates the negative log probability on a random sample.
    329 
    330     Args:
    331       seed: int or None. Random seed for this draw from the distribution.
    332 
    333     Returns:
      Negative log probability of sampled targets, summed across examples.
    335     """
    336     if seed is None:
    337       seed = self._default_seed
    338     # We treat the targets as "constant".  It's only the inputs that get
    339     # "back-propped" through.
    340     return self._evaluate(array_ops.stop_gradient(self.sample(seed)))
    341 
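# Editorial sketch (not part of the original module): for any concrete
# NegativeLogProbLoss, the factor methods are consistent with multiply_fisher
# in the sense that F v = B (B^T v); for a hypothetical `loss` instance and a
# suitably shaped `vector`, one would expect
#   loss.multiply_fisher(vector)
# to equal
#   loss.multiply_fisher_factor(loss.multiply_fisher_factor_transpose(vector))
# since B satisfies B * B^T = F and the same B is used by both methods.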
    342 
    343 # TODO(jamesmartens): should this just inherit from object to avoid "diamond"
    344 # inheritance, or is there a better way?
    345 class NaturalParamsNegativeLogProbLoss(NegativeLogProbLoss):
    346   """Base class for neg log prob losses whose inputs are 'natural' parameters.
    347 
    348   Note that the Hessian and Fisher for natural parameters of exponential-
    349   family models are the same, hence the purpose of this class.
    350   See here: https://arxiv.org/abs/1412.1193
    351 
    352   'Natural parameters' are defined for exponential-family models. See for
    353   example: https://en.wikipedia.org/wiki/Exponential_family
    354   """
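  # Editorial note (a sketch, not from the original module): for an
  # exponential family with natural parameter eta, -log p(x | eta) =
  # A(eta) - eta^T T(x) plus a term constant in eta, so the Hessian w.r.t. eta
  # is the Hessian of the log-partition function A(eta), which equals
  # Cov[T(x)], i.e. the Fisher. Hence each multiply_hessian* method below
  # delegates to its multiply_fisher* counterpart.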
    355 
    356   def multiply_hessian(self, vector):
    357     return self.multiply_fisher(vector)
    358 
    359   def multiply_hessian_factor(self, vector):
    360     return self.multiply_fisher_factor(vector)
    361 
    362   def multiply_hessian_factor_transpose(self, vector):
    363     return self.multiply_fisher_factor_transpose(vector)
    364 
    365   def multiply_hessian_factor_replicated_one_hot(self, index):
    366     return self.multiply_fisher_factor_replicated_one_hot(index)
    367 
    368   @property
    369   def hessian_factor_inner_shape(self):
    370     return self.fisher_factor_inner_shape
    371 
    372   @property
    373   def hessian_factor_inner_static_shape(self):
    return self.fisher_factor_inner_static_shape
    375 
    376 
    377 class DistributionNegativeLogProbLoss(NegativeLogProbLoss):
    378   """Base class for neg log prob losses that use the TF Distribution classes."""
    379 
    380   def __init__(self, seed=None):
    381     super(DistributionNegativeLogProbLoss, self).__init__(seed=seed)
    382 
    383   @abc.abstractproperty
    384   def dist(self):
    385     """The underlying tf.distributions.Distribution."""
    386     pass
    387 
    388   def _evaluate(self, targets):
    389     return -math_ops.reduce_sum(self.dist.log_prob(targets))
    390 
    391   def sample(self, seed):
    392     return self.dist.sample(seed=seed)
    393 
    394 
    395 class NormalMeanNegativeLogProbLoss(DistributionNegativeLogProbLoss,
    396                                     NaturalParamsNegativeLogProbLoss):
    397   """Neg log prob loss for a normal distribution parameterized by a mean vector.
    398 
    399 
    400   Note that the covariance is treated as a constant 'var' times the identity.
  Also note that the Fisher for such a normal distribution with respect to the
  mean parameter is given by:
    403 
    404      F = (1/var) * I
    405 
    406   See for example https://www.ii.pwr.edu.pl/~tomczak/PDF/[JMT]Fisher_inf.pdf.
    407   """
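  # Editorial note (a sketch, not from the original module): per case the
  # negative log prob is ||target - mean||^2 / (2 * var) plus a constant, so
  # the Hessian of the summed loss with respect to the mean is (1/var) * I,
  # which for this exponential family coincides with the Fisher used below.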
    408 
    409   def __init__(self, mean, var=0.5, targets=None, seed=None):
    410     self._mean = mean
    411     self._var = var
    412     self._targets = targets
    413     super(NormalMeanNegativeLogProbLoss, self).__init__(seed=seed)
    414 
    415   @property
    416   def targets(self):
    417     return self._targets
    418 
    419   @property
    420   def dist(self):
    421     return normal.Normal(loc=self._mean, scale=math_ops.sqrt(self._var))
    422 
    423   @property
    424   def params(self):
    425     return self._mean
    426 
    427   def multiply_fisher(self, vector):
    428     return (1. / self._var) * vector
    429 
    430   def multiply_fisher_factor(self, vector):
    431     return self._var**-0.5 * vector
    432 
    433   def multiply_fisher_factor_transpose(self, vector):
    434     return self.multiply_fisher_factor(vector)  # it's symmetric in this case
    435 
    436   def multiply_fisher_factor_replicated_one_hot(self, index):
    437     assert len(index) == 1, "Length of index was {}".format(len(index))
    438     ones_slice = array_ops.expand_dims(
    439         array_ops.ones(array_ops.shape(self._mean)[:1], dtype=self._mean.dtype),
    440         axis=-1)
    441     output_slice = self._var**-0.5 * ones_slice
    442     return insert_slice_in_zeros(output_slice, 1, int(self._mean.shape[1]),
    443                                  index[0])
    444 
    445   @property
    446   def fisher_factor_inner_shape(self):
    447     return array_ops.shape(self._mean)
    448 
    449   @property
    450   def fisher_factor_inner_static_shape(self):
    451     return self._mean.shape
    452 
    453 
    454 class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss):
    455   """Negative log prob loss for a normal distribution with mean and variance.
    456 
    457   This class parameterizes a multivariate normal distribution with n independent
    458   dimensions. Unlike `NormalMeanNegativeLogProbLoss`, this class does not
    459   assume the variance is held constant. The Fisher Information for n = 1
    460   is given by,
    461 
    462   F = [[1 / variance,                0],
    463        [           0, 0.5 / variance^2]]
    464 
    465   where the parameters of the distribution are concatenated into a single
    466   vector as [mean, variance]. For n > 1, the mean parameter vector is
    467   concatenated with the variance parameter vector.
    468 
    469   See https://www.ii.pwr.edu.pl/~tomczak/PDF/[JMT]Fisher_inf.pdf for derivation.
    470   """
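  # Editorial note (a sketch, not from the original module): the factor's
  # inner dimension concatenates the mean and variance blocks, so
  # multiply_fisher_factor splits a [batch_size, 2*n] vector into a mean half
  # scaled by 1/sqrt(variance) and a variance half scaled by
  # 1/(sqrt(2)*variance), i.e. the square roots of the diagonal Fisher blocks
  # shown above.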
    471 
    472   def __init__(self, mean, variance, targets=None, seed=None):
    473     assert len(mean.shape) == 2, "Expect 2D mean tensor."
    474     assert len(variance.shape) == 2, "Expect 2D variance tensor."
    475     self._mean = mean
    476     self._variance = variance
    477     self._scale = math_ops.sqrt(variance)
    478     self._targets = targets
    479     super(NormalMeanVarianceNegativeLogProbLoss, self).__init__(seed=seed)
    480 
    481   @property
    482   def targets(self):
    483     return self._targets
    484 
    485   @property
    486   def dist(self):
    487     return normal.Normal(loc=self._mean, scale=self._scale)
    488 
    489   @property
    490   def params(self):
    491     return self._mean, self._variance
    492 
    493   def _concat(self, mean, variance):
    494     return array_ops.concat([mean, variance], axis=-1)
    495 
    496   def _split(self, params):
    497     return array_ops.split(params, 2, axis=-1)
    498 
    499   @property
    500   def _fisher_mean(self):
    501     return 1. / self._variance
    502 
    503   @property
    504   def _fisher_mean_factor(self):
    505     return 1. / self._scale
    506 
    507   @property
    508   def _fisher_var(self):
    509     return 1. / (2 * math_ops.square(self._variance))
    510 
    511   @property
    512   def _fisher_var_factor(self):
    513     return 1. / (math_ops.sqrt(2.) * self._variance)
    514 
    515   def multiply_fisher(self, vecs):
    516     mean_vec, var_vec = vecs
    517     return (self._fisher_mean * mean_vec, self._fisher_var * var_vec)
    518 
    519   def multiply_fisher_factor(self, vecs):
    520     mean_vec, var_vec = self._split(vecs)
    521     return (self._fisher_mean_factor * mean_vec,
    522             self._fisher_var_factor * var_vec)
    523 
    524   def multiply_fisher_factor_transpose(self, vecs):
    525     mean_vec, var_vec = vecs
    526     return self._concat(self._fisher_mean_factor * mean_vec,
    527                         self._fisher_var_factor * var_vec)
    528 
    529   def multiply_fisher_factor_replicated_one_hot(self, index):
    530     assert len(index) == 1, "Length of index was {}".format(len(index))
    531     index = index[0]
    532 
    533     if index < int(self._mean.shape[-1]):
    534       # Index corresponds to mean parameter.
    535       mean_slice = self._fisher_mean_factor[:, index]
    536       mean_slice = array_ops.expand_dims(mean_slice, axis=-1)
    537       mean_output = insert_slice_in_zeros(mean_slice, 1, int(
    538           self._mean.shape[1]), index)
    539       var_output = array_ops.zeros_like(mean_output)
    540     else:
    541       index -= int(self._mean.shape[-1])
    542       # Index corresponds to variance parameter.
    543       var_slice = self._fisher_var_factor[:, index]
    544       var_slice = array_ops.expand_dims(var_slice, axis=-1)
    545       var_output = insert_slice_in_zeros(var_slice, 1,
    546                                          int(self._variance.shape[1]), index)
    547       mean_output = array_ops.zeros_like(var_output)
    548 
    549     return mean_output, var_output
    550 
    551   @property
    552   def fisher_factor_inner_shape(self):
    553     return array_ops.concat(
    554         [
    555             array_ops.shape(self._mean)[:-1],
    556             2 * array_ops.shape(self._mean)[-1:]
    557         ],
    558         axis=0)
    559 
    560   @property
    561   def fisher_factor_inner_static_shape(self):
    562     shape = self._mean.shape.as_list()
    return tensor_shape.TensorShape(shape[:-1] + [2 * shape[-1]])
    564 
    565   def multiply_hessian(self, vector):
    566     raise NotImplementedError()
    567 
    568   def multiply_hessian_factor(self, vector):
    569     raise NotImplementedError()
    570 
    571   def multiply_hessian_factor_transpose(self, vector):
    572     raise NotImplementedError()
    573 
    574   def multiply_hessian_factor_replicated_one_hot(self, index):
    575     raise NotImplementedError()
    576 
    577   @property
    578   def hessian_factor_inner_shape(self):
    579     raise NotImplementedError()
    580 
    581   @property
    582   def hessian_factor_inner_static_shape(self):
    583     raise NotImplementedError()
    584 
    585 
    586 class CategoricalLogitsNegativeLogProbLoss(DistributionNegativeLogProbLoss,
    587                                            NaturalParamsNegativeLogProbLoss):
    588   """Neg log prob loss for a categorical distribution parameterized by logits.
    589 
    590 
    591   Note that the Fisher (for a single case) of a categorical distribution, with
    592   respect to the natural parameters (i.e. the logits), is given by:
    593 
    594   F = diag(p) - p*p^T
    595 
    596   where p = softmax(logits).  F can be factorized as F = B * B^T where
    597 
    598   B = diag(q) - p*q^T
    599 
    600   where q is the entry-wise square root of p. This is easy to verify using the
    601   fact that q^T*q = 1.
    602   """
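  # Editorial sketch of the verification mentioned above (not part of the
  # original module). With q = sqrt(p) elementwise, diag(q) * q = p and
  # q^T * q = sum(p) = 1, so
  #   B * B^T = (diag(q) - p*q^T) * (diag(q) - q*p^T)
  #           = diag(q)^2 - diag(q)*q*p^T - p*q^T*diag(q) + p*(q^T*q)*p^T
  #           = diag(p) - p*p^T - p*p^T + p*p^T
  #           = diag(p) - p*p^T = F.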
    603 
    604   def __init__(self, logits, targets=None, seed=None):
    605     """Instantiates a CategoricalLogitsNegativeLogProbLoss.
    606 
    607     Args:
    608       logits: Tensor of shape [batch_size, output_size]. Parameters for
    609         underlying distribution.
      targets: None or Tensor of shape [batch_size]. Each element contains an
        index in [0, output_size).
    612       seed: int or None. Default random seed when sampling.
    613     """
    614     self._logits_components = []
    615     self._targets_components = []
    616     self.register_additional_minibatch(logits, targets=targets)
    617     super(CategoricalLogitsNegativeLogProbLoss, self).__init__(seed=seed)
    618 
    619   def register_additional_minibatch(self, logits, targets=None):
    """Register an additional minibatch's worth of parameters.
    621 
    622     Args:
    623       logits: Tensor of shape [batch_size, output_size]. Parameters for
    624         underlying distribution.
      targets: None or Tensor of shape [batch_size]. Each element contains an
        index in [0, output_size).
    627     """
    628     self._logits_components.append(logits)
    629     self._targets_components.append(targets)
    630 
    631   @property
    632   def _logits(self):
    633     return array_ops.concat(self._logits_components, axis=0)
    634 
    635   @property
    636   def input_minibatches(self):
    637     return self._logits_components
    638 
    639   @property
    640   def targets(self):
    641     if all(target is None for target in self._targets_components):
    642       return None
    643     return array_ops.concat(self._targets_components, axis=0)
    644 
    645   @property
    646   def dist(self):
    647     return categorical.Categorical(logits=self._logits)
    648 
    649   @property
    650   def _probs(self):
    651     return self.dist.probs
    652 
    653   @property
    654   def _sqrt_probs(self):
    655     return math_ops.sqrt(self._probs)
    656 
    657   @property
    658   def params(self):
    659     return self._logits
    660 
    661   def multiply_fisher(self, vector):
    662     probs = self._probs
    663     return vector * probs - probs * math_ops.reduce_sum(
    664         vector * probs, axis=-1, keep_dims=True)
    665 
    666   def multiply_fisher_factor(self, vector):
    667     probs = self._probs
    668     sqrt_probs = self._sqrt_probs
    669     return sqrt_probs * vector - probs * math_ops.reduce_sum(
    670         sqrt_probs * vector, axis=-1, keep_dims=True)
    671 
    672   def multiply_fisher_factor_transpose(self, vector):
    673     probs = self._probs
    674     sqrt_probs = self._sqrt_probs
    675     return sqrt_probs * vector - sqrt_probs * math_ops.reduce_sum(
    676         probs * vector, axis=-1, keep_dims=True)
    677 
    678   def multiply_fisher_factor_replicated_one_hot(self, index):
    679     assert len(index) == 1, "Length of index was {}".format(len(index))
    680     probs = self._probs
    681     sqrt_probs = self._sqrt_probs
    682     sqrt_probs_slice = array_ops.expand_dims(sqrt_probs[:, index[0]], -1)
    683     padded_slice = insert_slice_in_zeros(sqrt_probs_slice, 1,
    684                                          int(sqrt_probs.shape[1]), index[0])
    685     return padded_slice - probs * sqrt_probs_slice
    686 
    687   @property
    688   def fisher_factor_inner_shape(self):
    689     return array_ops.shape(self._logits)
    690 
    691   @property
    692   def fisher_factor_inner_static_shape(self):
    693     return self._logits.shape
    694 
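# Editorial usage sketch (hypothetical names, not part of the original module):
# with one tower per device, additional minibatches can be registered on a
# single loss object, e.g.
#   loss = CategoricalLogitsNegativeLogProbLoss(logits_tower0,
#                                               targets=targets_tower0)
#   loss.register_additional_minibatch(logits_tower1, targets=targets_tower1)
# after which loss.num_registered_minibatches == 2 and loss.evaluate() sums the
# negative log probability over the examples from both towers.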
    695 
    696 class MultiBernoulliNegativeLogProbLoss(DistributionNegativeLogProbLoss,
    697                                         NaturalParamsNegativeLogProbLoss):
    698   """Neg log prob loss for multiple Bernoulli distributions param'd by logits.
    699 
  Represents N independent Bernoulli distributions, where N is the size of the
  last dimension of 'logits'. Its Fisher Information matrix is given by,
    702 
    703   F = diag(p * (1-p))
    704   p = sigmoid(logits)
    705 
    706   As F is diagonal with positive entries, its factor B is,
    707 
    708   B = diag(sqrt(p * (1-p)))
    709   """
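  # Editorial note (a sketch, not from the original module): for a Bernoulli
  # distribution in its natural (logit) parameterization, p = sigmoid(logit)
  # and the Fisher for the logit is the variance of the sufficient statistic,
  # Var[x] = p * (1 - p), which gives the diagonal F above; the N
  # distributions are independent, so there are no off-diagonal terms.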
    710 
    711   def __init__(self, logits, targets=None, seed=None):
    712     self._logits = logits
    713     self._targets = targets
    714     super(MultiBernoulliNegativeLogProbLoss, self).__init__(seed=seed)
    715 
    716   @property
    717   def targets(self):
    718     return self._targets
    719 
    720   @property
    721   def dist(self):
    722     return bernoulli.Bernoulli(logits=self._logits)
    723 
    724   @property
    725   def _probs(self):
    726     return self.dist.probs
    727 
    728   @property
    729   def params(self):
    730     return self._logits
    731 
    732   def multiply_fisher(self, vector):
    733     return self._probs * (1 - self._probs) * vector
    734 
    735   def multiply_fisher_factor(self, vector):
    736     return math_ops.sqrt(self._probs * (1 - self._probs)) * vector
    737 
    738   def multiply_fisher_factor_transpose(self, vector):
    739     return self.multiply_fisher_factor(vector)  # it's symmetric in this case
    740 
    741   def multiply_fisher_factor_replicated_one_hot(self, index):
    742     assert len(index) == 1, "Length of index was {}".format(len(index))
    743     probs_slice = array_ops.expand_dims(self._probs[:, index[0]], -1)
    744     output_slice = math_ops.sqrt(probs_slice * (1 - probs_slice))
    745     return insert_slice_in_zeros(output_slice, 1, int(self._logits.shape[1]),
    746                                  index[0])
    747 
    748   @property
    749   def fisher_factor_inner_shape(self):
    750     return array_ops.shape(self._logits)
    751 
    752   @property
    753   def fisher_factor_inner_static_shape(self):
    754     return self._logits.shape
    755 
    756 
    757 def insert_slice_in_zeros(slice_to_insert, dim, dim_size, position):
    758   """Inserts slice into a larger tensor of zeros.
    759 
    760   Forms a new tensor which is the same shape as slice_to_insert, except that
    761   the dimension given by 'dim' is expanded to the size given by 'dim_size'.
    762   'position' determines the position (index) at which to insert the slice within
    763   that dimension.
    764 
    765   Assumes slice_to_insert.shape[dim] = 1.
    766 
    767   Args:
    768     slice_to_insert: The slice to insert.
    dim: The dimension to expand with zeros.
    770     dim_size: The new size of the 'dim' dimension.
    771     position: The position of 'slice_to_insert' in the new tensor.
    772 
    773   Returns:
    774     The new tensor.
    775 
    776   Raises:
    777     ValueError: If the slice's shape at the given dim is not 1.
    778   """
    779   slice_shape = slice_to_insert.shape
    780   if slice_shape[dim] != 1:
    raise ValueError("Expected slice_to_insert to have size 1 in dimension "
                     "{}, but it was {}".format(dim, slice_to_insert.shape[dim]))
    783 
    784   before = [0] * int(len(slice_shape))
    785   after = before[:]
    786   before[dim] = position
    787   after[dim] = dim_size - position - 1
    788 
    789   return array_ops.pad(slice_to_insert, list(zip(before, after)))
    790 
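# Editorial example (a sketch, not part of the original module): given a
# slice_to_insert of shape [batch_size, 1] with entries s_i, the call
#   insert_slice_in_zeros(slice_to_insert, 1, 4, 2)
# returns a [batch_size, 4] tensor whose column 2 holds s_i and whose other
# columns are zero, i.e. each row looks like [0, 0, s_i, 0].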
    791 
    792 class OnehotCategoricalLogitsNegativeLogProbLoss(
    793     CategoricalLogitsNegativeLogProbLoss):
    794   """Neg log prob loss for a categorical distribution with onehot targets.
    795 
    796   Identical to CategoricalLogitsNegativeLogProbLoss except that the underlying
    797   distribution is OneHotCategorical as opposed to Categorical.
    798   """
    799 
    800   @property
    801   def dist(self):
    802     return onehot_categorical.OneHotCategorical(logits=self._logits)
    803