# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Loss functions to be used by LayerCollection."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc

import six

from tensorflow.contrib.distributions.python.ops import onehot_categorical
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import bernoulli
from tensorflow.python.ops.distributions import categorical
from tensorflow.python.ops.distributions import normal


@six.add_metaclass(abc.ABCMeta)
class LossFunction(object):
  """Abstract base class for loss functions.

  Note that unlike typical loss functions used in neural networks these are
  summed and not averaged across cases in the batch, since this is what the
  users of this class (FisherEstimator and MatrixVectorProductComputer) will
  be expecting. The implication of this is that you may want to normalize
  things like Fisher-vector products by the batch size when you use this
  class. It depends on the use case.
  """

  @abc.abstractproperty
  def targets(self):
    """The targets being predicted by the model.

    Returns:
      None or Tensor of appropriate shape for calling self._evaluate() on.
    """
    pass

  @abc.abstractproperty
  def inputs(self):
    """The inputs to the loss function (excluding the targets)."""
    pass

  @property
  def input_minibatches(self):
    """A `list` of inputs to the loss function, separated by minibatch.

    Typically there will be one minibatch per tower in a multi-tower setup.
    Returns a list consisting of `self.inputs` by default; `LossFunction`s
    supporting registering multiple minibatches should override this method.

    Returns:
      A `list` of `Tensor`s representing the registered minibatches of inputs.
    """
    return [self.inputs]

  @property
  def num_registered_minibatches(self):
    """Number of minibatches registered for this LossFunction.

    Typically equal to the number of towers in a multi-tower setup.

    Returns:
      An `int` representing the number of registered minibatches.
    """
    return len(self.input_minibatches)

  def evaluate(self):
    """Evaluate the loss function on the targets."""
    if self.targets is not None:
      # We treat the targets as "constant". It's only the inputs that get
      # "back-propped" through.
      return self._evaluate(array_ops.stop_gradient(self.targets))
    else:
      raise Exception("Cannot evaluate losses with unspecified targets.")

  @abc.abstractmethod
  def _evaluate(self, targets):
    """Evaluates the negative log probability of the targets.

    Args:
      targets: Tensor that distribution can calculate log_prob() of.

    Returns:
      Negative log probability of each target, summed across all targets.
    """
    pass
  @abc.abstractmethod
  def multiply_hessian(self, vector):
    """Right-multiply a vector by the Hessian.

    Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives)
    of the loss function with respect to its inputs.

    Args:
      vector: The vector to multiply. Must be the same shape(s) as the
        'inputs' property.

    Returns:
      The vector right-multiplied by the Hessian. Will be of the same shape(s)
      as the 'inputs' property.
    """
    pass

  @abc.abstractmethod
  def multiply_hessian_factor(self, vector):
    """Right-multiply a vector by a factor B of the Hessian.

    Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives)
    of the loss function with respect to its inputs. Typically this will be
    block-diagonal across different cases in the batch, since the loss function
    is typically summed across cases.

    Note that B can be any matrix satisfying B * B^T = H where H is the
    Hessian, but will agree with the one used in the other methods of this
    class.

    Args:
      vector: The vector to multiply. Must be of the shape given by the
        'hessian_factor_inner_shape' property.

    Returns:
      The vector right-multiplied by B. Will be of the same shape(s) as the
      'inputs' property.
    """
    pass

  @abc.abstractmethod
  def multiply_hessian_factor_transpose(self, vector):
    """Right-multiply a vector by the transpose of a factor B of the Hessian.

    Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives)
    of the loss function with respect to its inputs. Typically this will be
    block-diagonal across different cases in the batch, since the loss function
    is typically summed across cases.

    Note that B can be any matrix satisfying B * B^T = H where H is the
    Hessian, but will agree with the one used in the other methods of this
    class.

    Args:
      vector: The vector to multiply. Must be the same shape(s) as the
        'inputs' property.

    Returns:
      The vector right-multiplied by B^T. Will be of the shape given by the
      'hessian_factor_inner_shape' property.
    """
    pass

  @abc.abstractmethod
  def multiply_hessian_factor_replicated_one_hot(self, index):
    """Right-multiply a replicated-one-hot vector by a factor B of the Hessian.

    Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives)
    of the loss function with respect to its inputs. Typically this will be
    block-diagonal across different cases in the batch, since the loss function
    is typically summed across cases.

    A 'replicated-one-hot' vector means a tensor which, for each slice along
    the batch dimension (assumed to be dimension 0), is 1.0 in the entry
    corresponding to the given index and 0 elsewhere.

    Note that B can be any matrix satisfying B * B^T = H where H is the
    Hessian, but will agree with the one used in the other methods of this
    class.

    Args:
      index: A tuple representing the index of the entry in each slice that
        is 1.0. Note that len(index) must be equal to the number of elements
        of the 'hessian_factor_inner_shape' tensor minus one.

    Returns:
      The vector right-multiplied by B. Will be of the same shape(s) as the
      'inputs' property.
    """
    pass

  @abc.abstractproperty
  def hessian_factor_inner_shape(self):
    """The shape of the tensor returned by multiply_hessian_factor."""
    pass

  @abc.abstractproperty
  def hessian_factor_inner_static_shape(self):
    """Static version of hessian_factor_inner_shape."""
    pass
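

# A minimal sketch (illustrative only) of how the abstract methods above
# relate: since H = B * B^T, a Hessian-vector product can be composed from
# the factor operations. `loss` and `vector` below are hypothetical
# placeholders for a concrete LossFunction instance and a vector of the
# shape(s) given by its 'inputs' property:
#
#   hv = loss.multiply_hessian_factor(
#       loss.multiply_hessian_factor_transpose(vector))
#
# which should agree with loss.multiply_hessian(vector).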


@six.add_metaclass(abc.ABCMeta)
class NegativeLogProbLoss(LossFunction):
  """Abstract base class for loss functions that are negative log probs."""

  def __init__(self, seed=None):
    self._default_seed = seed
    super(NegativeLogProbLoss, self).__init__()

  @property
  def inputs(self):
    return self.params

  @abc.abstractproperty
  def params(self):
    """Parameters to the underlying distribution."""
    pass

  @abc.abstractmethod
  def multiply_fisher(self, vector):
    """Right-multiply a vector by the Fisher.

    Args:
      vector: The vector to multiply. Must be the same shape(s) as the
        'inputs' property.

    Returns:
      The vector right-multiplied by the Fisher. Will be of the same shape(s)
      as the 'inputs' property.
    """
    pass

  @abc.abstractmethod
  def multiply_fisher_factor(self, vector):
    """Right-multiply a vector by a factor B of the Fisher.

    Here the 'Fisher' is the Fisher information matrix (i.e. expected outer-
    product of gradients) with respect to the parameters of the underlying
    probability distribution (whose log-prob defines the loss). Typically this
    will be block-diagonal across different cases in the batch, since the
    distribution is usually (but not always) conditionally iid across different
    cases.

    Note that B can be any matrix satisfying B * B^T = F where F is the Fisher,
    but will agree with the one used in the other methods of this class.

    Args:
      vector: The vector to multiply. Must be of the shape given by the
        'fisher_factor_inner_shape' property.

    Returns:
      The vector right-multiplied by B. Will be of the same shape(s) as the
      'inputs' property.
    """
    pass

  @abc.abstractmethod
  def multiply_fisher_factor_transpose(self, vector):
    """Right-multiply a vector by the transpose of a factor B of the Fisher.

    Here the 'Fisher' is the Fisher information matrix (i.e. expected outer-
    product of gradients) with respect to the parameters of the underlying
    probability distribution (whose log-prob defines the loss). Typically this
    will be block-diagonal across different cases in the batch, since the
    distribution is usually (but not always) conditionally iid across different
    cases.

    Note that B can be any matrix satisfying B * B^T = F where F is the Fisher,
    but will agree with the one used in the other methods of this class.

    Args:
      vector: The vector to multiply. Must be the same shape(s) as the
        'inputs' property.

    Returns:
      The vector right-multiplied by B^T. Will be of the shape given by the
      'fisher_factor_inner_shape' property.
    """
    pass

  @abc.abstractmethod
  def multiply_fisher_factor_replicated_one_hot(self, index):
    """Right-multiply a replicated-one-hot vector by a factor B of the Fisher.

    Here the 'Fisher' is the Fisher information matrix (i.e. expected outer-
    product of gradients) with respect to the parameters of the underlying
    probability distribution (whose log-prob defines the loss). Typically this
    will be block-diagonal across different cases in the batch, since the
    distribution is usually (but not always) conditionally iid across different
    cases.

    A 'replicated-one-hot' vector means a tensor which, for each slice along
    the batch dimension (assumed to be dimension 0), is 1.0 in the entry
    corresponding to the given index and 0 elsewhere.

    Note that B can be any matrix satisfying B * B^T = F where F is the Fisher,
    but will agree with the one used in the other methods of this class.

    Args:
      index: A tuple representing the index of the entry in each slice that
        is 1.0. Note that len(index) must be equal to the number of elements
        of the 'fisher_factor_inner_shape' tensor minus one.

    Returns:
      The vector right-multiplied by B. Will be of the same shape(s) as the
      'inputs' property.
    """
    pass

  @abc.abstractproperty
  def fisher_factor_inner_shape(self):
    """The shape of the tensor returned by multiply_fisher_factor."""
    pass

  @abc.abstractproperty
  def fisher_factor_inner_static_shape(self):
    """Static version of fisher_factor_inner_shape."""
    pass

  @abc.abstractmethod
  def sample(self, seed):
    """Sample 'targets' from the underlying distribution."""
    pass

  def evaluate_on_sample(self, seed=None):
    """Evaluates the negative log probability on a random sample.

    Args:
      seed: int or None. Random seed for this draw from the distribution.

    Returns:
      Negative log probability of sampled targets, summed across examples.
    """
    if seed is None:
      seed = self._default_seed
    # We treat the targets as "constant". It's only the inputs that get
    # "back-propped" through.
    return self._evaluate(array_ops.stop_gradient(self.sample(seed)))


# TODO(jamesmartens): should this just inherit from object to avoid "diamond"
# inheritance, or is there a better way?
class NaturalParamsNegativeLogProbLoss(NegativeLogProbLoss):
  """Base class for neg log prob losses whose inputs are 'natural' parameters.

  Note that the Hessian and Fisher for natural parameters of exponential-
  family models are the same, hence the purpose of this class.
  See here: https://arxiv.org/abs/1412.1193

  'Natural parameters' are defined for exponential-family models. See for
  example: https://en.wikipedia.org/wiki/Exponential_family
  """

  def multiply_hessian(self, vector):
    return self.multiply_fisher(vector)

  def multiply_hessian_factor(self, vector):
    return self.multiply_fisher_factor(vector)

  def multiply_hessian_factor_transpose(self, vector):
    return self.multiply_fisher_factor_transpose(vector)

  def multiply_hessian_factor_replicated_one_hot(self, index):
    return self.multiply_fisher_factor_replicated_one_hot(index)

  @property
  def hessian_factor_inner_shape(self):
    return self.fisher_factor_inner_shape

  @property
  def hessian_factor_inner_static_shape(self):
    return self.fisher_factor_inner_static_shape


class DistributionNegativeLogProbLoss(NegativeLogProbLoss):
  """Base class for neg log prob losses that use the TF Distribution classes."""

  def __init__(self, seed=None):
    super(DistributionNegativeLogProbLoss, self).__init__(seed=seed)

  @abc.abstractproperty
  def dist(self):
    """The underlying tf.distributions.Distribution."""
    pass

  def _evaluate(self, targets):
    return -math_ops.reduce_sum(self.dist.log_prob(targets))

  def sample(self, seed):
    return self.dist.sample(seed=seed)


class NormalMeanNegativeLogProbLoss(DistributionNegativeLogProbLoss,
                                    NaturalParamsNegativeLogProbLoss):
  """Neg log prob loss for a normal distribution parameterized by a mean vector.

  Note that the covariance is treated as a constant 'var' times the identity.
  Also note that the Fisher for such a normal distribution with respect to the
  mean parameter is given by:

    F = (1/var) * I

  See for example https://www.ii.pwr.edu.pl/~tomczak/PDF/[JMT]Fisher_inf.pdf.
  """

  def __init__(self, mean, var=0.5, targets=None, seed=None):
    self._mean = mean
    self._var = var
    self._targets = targets
    super(NormalMeanNegativeLogProbLoss, self).__init__(seed=seed)

  @property
  def targets(self):
    return self._targets

  @property
  def dist(self):
    return normal.Normal(loc=self._mean, scale=math_ops.sqrt(self._var))

  @property
  def params(self):
    return self._mean

  def multiply_fisher(self, vector):
    return (1. / self._var) * vector

  def multiply_fisher_factor(self, vector):
    return self._var**-0.5 * vector

  def multiply_fisher_factor_transpose(self, vector):
    return self.multiply_fisher_factor(vector)  # it's symmetric in this case

  def multiply_fisher_factor_replicated_one_hot(self, index):
    assert len(index) == 1, "Length of index was {}".format(len(index))
    ones_slice = array_ops.expand_dims(
        array_ops.ones(
            array_ops.shape(self._mean)[:1], dtype=self._mean.dtype),
        axis=-1)
    output_slice = self._var**-0.5 * ones_slice
    return insert_slice_in_zeros(output_slice, 1, int(self._mean.shape[1]),
                                 index[0])

  @property
  def fisher_factor_inner_shape(self):
    return array_ops.shape(self._mean)

  @property
  def fisher_factor_inner_static_shape(self):
    return self._mean.shape
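

# A minimal numeric sketch (illustrative only, not part of this module's API)
# of the identity in the class docstring above: the factor B = var**-0.5 * I
# used by NormalMeanNegativeLogProbLoss.multiply_fisher_factor satisfies
# B * B^T = (1/var) * I = F. The helper name and its numpy-based check are
# assumptions made for illustration.
def _sketch_normal_mean_fisher_factor(var=0.5, dim=3):
  import numpy as np  # Local import; used only by this illustrative sketch.
  fisher = (1. / var) * np.eye(dim)
  factor = var**-0.5 * np.eye(dim)
  # The factor squares to the Fisher, as multiply_fisher_factor assumes.
  np.testing.assert_allclose(factor.dot(factor.T), fisher)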


class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss):
  """Negative log prob loss for a normal distribution with mean and variance.

  This class parameterizes a multivariate normal distribution with n
  independent dimensions. Unlike `NormalMeanNegativeLogProbLoss`, this class
  does not assume the variance is held constant. The Fisher Information for
  n = 1 is given by,

    F = [[1 / variance,                0],
         [           0, 0.5 / variance^2]]

  where the parameters of the distribution are concatenated into a single
  vector as [mean, variance]. For n > 1, the mean parameter vector is
  concatenated with the variance parameter vector.

  See https://www.ii.pwr.edu.pl/~tomczak/PDF/[JMT]Fisher_inf.pdf for
  derivation.
  """

  def __init__(self, mean, variance, targets=None, seed=None):
    assert len(mean.shape) == 2, "Expected 2D mean tensor."
    assert len(variance.shape) == 2, "Expected 2D variance tensor."
    self._mean = mean
    self._variance = variance
    self._scale = math_ops.sqrt(variance)
    self._targets = targets
    super(NormalMeanVarianceNegativeLogProbLoss, self).__init__(seed=seed)

  @property
  def targets(self):
    return self._targets

  @property
  def dist(self):
    return normal.Normal(loc=self._mean, scale=self._scale)

  @property
  def params(self):
    return self._mean, self._variance

  def _concat(self, mean, variance):
    return array_ops.concat([mean, variance], axis=-1)

  def _split(self, params):
    return array_ops.split(params, 2, axis=-1)

  @property
  def _fisher_mean(self):
    return 1. / self._variance

  @property
  def _fisher_mean_factor(self):
    return 1. / self._scale

  @property
  def _fisher_var(self):
    return 1. / (2 * math_ops.square(self._variance))

  @property
  def _fisher_var_factor(self):
    return 1. / (math_ops.sqrt(2.) * self._variance)

  def multiply_fisher(self, vecs):
    mean_vec, var_vec = vecs
    return (self._fisher_mean * mean_vec, self._fisher_var * var_vec)

  def multiply_fisher_factor(self, vecs):
    mean_vec, var_vec = self._split(vecs)
    return (self._fisher_mean_factor * mean_vec,
            self._fisher_var_factor * var_vec)

  def multiply_fisher_factor_transpose(self, vecs):
    mean_vec, var_vec = vecs
    return self._concat(self._fisher_mean_factor * mean_vec,
                        self._fisher_var_factor * var_vec)

  def multiply_fisher_factor_replicated_one_hot(self, index):
    assert len(index) == 1, "Length of index was {}".format(len(index))
    index = index[0]

    if index < int(self._mean.shape[-1]):
      # Index corresponds to mean parameter.
      mean_slice = self._fisher_mean_factor[:, index]
      mean_slice = array_ops.expand_dims(mean_slice, axis=-1)
      mean_output = insert_slice_in_zeros(mean_slice, 1,
                                          int(self._mean.shape[1]), index)
      var_output = array_ops.zeros_like(mean_output)
    else:
      index -= int(self._mean.shape[-1])
      # Index corresponds to variance parameter.
      var_slice = self._fisher_var_factor[:, index]
      var_slice = array_ops.expand_dims(var_slice, axis=-1)
      var_output = insert_slice_in_zeros(var_slice, 1,
                                         int(self._variance.shape[1]), index)
      mean_output = array_ops.zeros_like(var_output)

    return mean_output, var_output

  @property
  def fisher_factor_inner_shape(self):
    return array_ops.concat(
        [
            array_ops.shape(self._mean)[:-1],
            2 * array_ops.shape(self._mean)[-1:]
        ],
        axis=0)

  @property
  def fisher_factor_inner_static_shape(self):
    shape = self._mean.shape.as_list()
    return tensor_shape.TensorShape(shape[:-1] + [2 * shape[-1]])

  def multiply_hessian(self, vector):
    raise NotImplementedError()

  def multiply_hessian_factor(self, vector):
    raise NotImplementedError()

  def multiply_hessian_factor_transpose(self, vector):
    raise NotImplementedError()

  def multiply_hessian_factor_replicated_one_hot(self, index):
    raise NotImplementedError()

  @property
  def hessian_factor_inner_shape(self):
    raise NotImplementedError()

  @property
  def hessian_factor_inner_static_shape(self):
    raise NotImplementedError()
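

# A minimal numeric sketch (illustrative only, not part of this module's API)
# of the n = 1 Fisher block in the class docstring above: the per-parameter
# factors 1/scale and 1/(sqrt(2)*variance) used by multiply_fisher_factor
# square to the diagonal entries 1/variance and 0.5/variance^2. The helper
# name and its numpy-based check are assumptions made for illustration.
def _sketch_normal_mean_variance_fisher_factor(variance=0.7):
  import numpy as np  # Local import; used only by this illustrative sketch.
  fisher = np.diag([1. / variance, 0.5 / variance**2])
  factor = np.diag([1. / np.sqrt(variance), 1. / (np.sqrt(2.) * variance)])
  # B * B^T recovers the block-diagonal Fisher for [mean, variance].
  np.testing.assert_allclose(factor.dot(factor.T), fisher)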


class CategoricalLogitsNegativeLogProbLoss(DistributionNegativeLogProbLoss,
                                           NaturalParamsNegativeLogProbLoss):
  """Neg log prob loss for a categorical distribution parameterized by logits.

  Note that the Fisher (for a single case) of a categorical distribution, with
  respect to the natural parameters (i.e. the logits), is given by:

    F = diag(p) - p*p^T

  where p = softmax(logits). F can be factorized as F = B * B^T where

    B = diag(q) - p*q^T

  where q is the entry-wise square root of p. This is easy to verify using the
  fact that q^T*q = 1.
  """

  def __init__(self, logits, targets=None, seed=None):
    """Instantiates a CategoricalLogitsNegativeLogProbLoss.

    Args:
      logits: Tensor of shape [batch_size, output_size]. Parameters for
        underlying distribution.
      targets: None or Tensor of shape [batch_size]. Each element contains an
        index in [0, output_size).
      seed: int or None. Default random seed when sampling.
    """
    self._logits_components = []
    self._targets_components = []
    self.register_additional_minibatch(logits, targets=targets)
    super(CategoricalLogitsNegativeLogProbLoss, self).__init__(seed=seed)

  def register_additional_minibatch(self, logits, targets=None):
    """Registers an additional minibatch's worth of parameters.

    Args:
      logits: Tensor of shape [batch_size, output_size]. Parameters for
        underlying distribution.
      targets: None or Tensor of shape [batch_size]. Each element contains an
        index in [0, output_size).
    """
    self._logits_components.append(logits)
    self._targets_components.append(targets)

  @property
  def _logits(self):
    return array_ops.concat(self._logits_components, axis=0)

  @property
  def input_minibatches(self):
    return self._logits_components

  @property
  def targets(self):
    if all(target is None for target in self._targets_components):
      return None
    return array_ops.concat(self._targets_components, axis=0)

  @property
  def dist(self):
    return categorical.Categorical(logits=self._logits)

  @property
  def _probs(self):
    return self.dist.probs

  @property
  def _sqrt_probs(self):
    return math_ops.sqrt(self._probs)

  @property
  def params(self):
    return self._logits

  def multiply_fisher(self, vector):
    probs = self._probs
    return vector * probs - probs * math_ops.reduce_sum(
        vector * probs, axis=-1, keep_dims=True)

  def multiply_fisher_factor(self, vector):
    probs = self._probs
    sqrt_probs = self._sqrt_probs
    return sqrt_probs * vector - probs * math_ops.reduce_sum(
        sqrt_probs * vector, axis=-1, keep_dims=True)

  def multiply_fisher_factor_transpose(self, vector):
    probs = self._probs
    sqrt_probs = self._sqrt_probs
    return sqrt_probs * vector - sqrt_probs * math_ops.reduce_sum(
        probs * vector, axis=-1, keep_dims=True)

  def multiply_fisher_factor_replicated_one_hot(self, index):
    assert len(index) == 1, "Length of index was {}".format(len(index))
    probs = self._probs
    sqrt_probs = self._sqrt_probs
    sqrt_probs_slice = array_ops.expand_dims(sqrt_probs[:, index[0]], -1)
    padded_slice = insert_slice_in_zeros(sqrt_probs_slice, 1,
                                         int(sqrt_probs.shape[1]), index[0])
    return padded_slice - probs * sqrt_probs_slice

  @property
  def fisher_factor_inner_shape(self):
    return array_ops.shape(self._logits)

  @property
  def fisher_factor_inner_static_shape(self):
    return self._logits.shape
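

# A minimal numeric sketch (illustrative only, not part of this module's API)
# verifying the factorization stated in the class docstring above:
# B = diag(q) - p*q^T satisfies B * B^T = diag(p) - p*p^T = F because
# q^T*q = sum(p) = 1. The helper name and its numpy-based check are
# assumptions made for illustration.
def _sketch_categorical_fisher_factor(logits=(0.5, -1.0, 2.0)):
  import numpy as np  # Local import; used only by this illustrative sketch.
  logits = np.asarray(logits)
  p = np.exp(logits) / np.sum(np.exp(logits))  # Softmax probabilities.
  q = np.sqrt(p)
  fisher = np.diag(p) - np.outer(p, p)
  factor = np.diag(q) - np.outer(p, q)
  np.testing.assert_allclose(factor.dot(factor.T), fisher, atol=1e-12)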


class MultiBernoulliNegativeLogProbLoss(DistributionNegativeLogProbLoss,
                                        NaturalParamsNegativeLogProbLoss):
  """Neg log prob loss for multiple Bernoulli distributions param'd by logits.

  Represents N independent Bernoulli distributions where N = len(logits). Its
  Fisher Information matrix is given by,

    F = diag(p * (1-p))
    p = sigmoid(logits)

  As F is diagonal with positive entries, its factor B is,

    B = diag(sqrt(p * (1-p)))
  """

  def __init__(self, logits, targets=None, seed=None):
    self._logits = logits
    self._targets = targets
    super(MultiBernoulliNegativeLogProbLoss, self).__init__(seed=seed)

  @property
  def targets(self):
    return self._targets

  @property
  def dist(self):
    return bernoulli.Bernoulli(logits=self._logits)

  @property
  def _probs(self):
    return self.dist.probs

  @property
  def params(self):
    return self._logits

  def multiply_fisher(self, vector):
    return self._probs * (1 - self._probs) * vector

  def multiply_fisher_factor(self, vector):
    return math_ops.sqrt(self._probs * (1 - self._probs)) * vector

  def multiply_fisher_factor_transpose(self, vector):
    return self.multiply_fisher_factor(vector)  # it's symmetric in this case

  def multiply_fisher_factor_replicated_one_hot(self, index):
    assert len(index) == 1, "Length of index was {}".format(len(index))
    probs_slice = array_ops.expand_dims(self._probs[:, index[0]], -1)
    output_slice = math_ops.sqrt(probs_slice * (1 - probs_slice))
    return insert_slice_in_zeros(output_slice, 1, int(self._logits.shape[1]),
                                 index[0])

  @property
  def fisher_factor_inner_shape(self):
    return array_ops.shape(self._logits)

  @property
  def fisher_factor_inner_static_shape(self):
    return self._logits.shape
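

# A minimal Monte Carlo sketch (illustrative only, not part of this module's
# API) of the diagonal Fisher in the class docstring above: the gradient of
# the Bernoulli negative log prob w.r.t. a logit is (p - t), and with
# t ~ Bernoulli(p) its expected square is p * (1-p). The helper name, sample
# size, and tolerance are assumptions made for illustration.
def _sketch_bernoulli_fisher(logit=0.3, num_samples=200000, seed=0):
  import numpy as np  # Local import; used only by this illustrative sketch.
  rng = np.random.RandomState(seed)
  p = 1. / (1. + np.exp(-logit))  # Sigmoid.
  targets = rng.binomial(1, p, size=num_samples)
  fisher_mc = np.mean((p - targets)**2)
  np.testing.assert_allclose(fisher_mc, p * (1 - p), atol=5e-3)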


def insert_slice_in_zeros(slice_to_insert, dim, dim_size, position):
  """Inserts slice into a larger tensor of zeros.

  Forms a new tensor which is the same shape as slice_to_insert, except that
  the dimension given by 'dim' is expanded to the size given by 'dim_size'.
  'position' determines the position (index) at which to insert the slice
  within that dimension.

  Assumes slice_to_insert.shape[dim] = 1.

  Args:
    slice_to_insert: The slice to insert.
    dim: The dimension to expand with zeros.
    dim_size: The new size of the 'dim' dimension.
    position: The position of 'slice_to_insert' in the new tensor.

  Returns:
    The new tensor.

  Raises:
    ValueError: If the slice's shape at the given dim is not 1.
  """
  slice_shape = slice_to_insert.shape
  if slice_shape[dim] != 1:
    raise ValueError("Expected slice_to_insert.shape to have {} dim of 1, but "
                     "was {}".format(dim, slice_to_insert.shape[dim]))

  before = [0] * int(len(slice_shape))
  after = before[:]
  before[dim] = position
  after[dim] = dim_size - position - 1

  return array_ops.pad(slice_to_insert, list(zip(before, after)))
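

# A minimal usage sketch (illustrative only, not part of this module's API)
# of insert_slice_in_zeros, emulated with numpy padding so it runs without a
# TensorFlow session: a [2, 1] slice placed at position 2 of a dimension
# expanded to size 4 gives a [2, 4] array that is zero outside that column.
# The helper name is an assumption made for illustration.
def _sketch_insert_slice_in_zeros():
  import numpy as np  # Local import; used only by this illustrative sketch.
  slice_to_insert = np.array([[1.], [2.]])
  # Mirrors the array_ops.pad call above with dim=1, dim_size=4, position=2:
  # 2 zero columns before the slice and 4 - 2 - 1 = 1 after it.
  padded = np.pad(slice_to_insert, [(0, 0), (2, 1)], mode="constant")
  np.testing.assert_array_equal(padded,
                                [[0., 0., 1., 0.], [0., 0., 2., 0.]])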


class OnehotCategoricalLogitsNegativeLogProbLoss(
    CategoricalLogitsNegativeLogProbLoss):
  """Neg log prob loss for a categorical distribution with onehot targets.

  Identical to CategoricalLogitsNegativeLogProbLoss except that the underlying
  distribution is OneHotCategorical as opposed to Categorical.
  """

  @property
  def dist(self):
    return onehot_categorical.OneHotCategorical(logits=self._logits)
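

# A minimal Monte Carlo sketch (illustrative only, not part of this module's
# API) of the definition used throughout this file: the Fisher is the
# expected outer product of the gradient of the negative log prob. For a
# categorical with probabilities p, that gradient w.r.t. the logits is
# (p - onehot(target)), and averaging its outer products over sampled targets
# recovers F = diag(p) - p*p^T. The helper name, sample size, and tolerance
# are assumptions made for illustration.
def _sketch_monte_carlo_categorical_fisher(logits=(0.2, -0.3, 0.1),
                                           num_samples=200000, seed=0):
  import numpy as np  # Local import; used only by this illustrative sketch.
  rng = np.random.RandomState(seed)
  logits = np.asarray(logits)
  p = np.exp(logits) / np.sum(np.exp(logits))  # Softmax probabilities.
  fisher_exact = np.diag(p) - np.outer(p, p)
  samples = rng.choice(len(p), size=num_samples, p=p)
  grads = p[None, :] - np.eye(len(p))[samples]
  fisher_mc = grads.T.dot(grads) / num_samples
  np.testing.assert_allclose(fisher_mc, fisher_exact, atol=5e-3)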