Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for tf.contrib.kfac.loss_functions."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import numpy as np
     22 
     23 from tensorflow.contrib.kfac.python.ops import loss_functions
     24 from tensorflow.python.framework import constant_op
     25 from tensorflow.python.framework import ops
     26 from tensorflow.python.ops import array_ops
     27 from tensorflow.python.ops import random_ops
     28 from tensorflow.python.platform import test
     29 
     30 
     31 class InsertSliceInZerosTest(test.TestCase):
     32 
     33   def testBadShape(self):
     34     bad_shaped_ones = array_ops.ones(shape=[1, 3])  # n.b. shape[1] != 1
     35     with self.assertRaises(ValueError):
     36       loss_functions.insert_slice_in_zeros(bad_shaped_ones, 1, 42, 17)
     37 
     38   def test3d(self):
     39     input_tensor = constant_op.constant([[[1, 2]], [[3, 4]]])
     40     expected_output_array = [[[1, 2], [0, 0]], [[3, 4], [0, 0]]]
     41     op = loss_functions.insert_slice_in_zeros(input_tensor, 1, 2, 0)
     42     with self.test_session() as sess:
     43       actual_output_array = sess.run(op)
     44     self.assertAllEqual(expected_output_array, actual_output_array)
     45 
     46 
     47 class CategoricalLogitsNegativeLogProbLossTest(test.TestCase):
     48 
     49   def testSample(self):
     50     """Ensure samples can be drawn."""
     51     with ops.Graph().as_default(), self.test_session() as sess:
     52       logits = np.asarray([
     53           [0., 0., 0.],  #
     54           [1., -1., 0.]
     55       ]).astype(np.float32)
     56       loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(
     57           array_ops.constant(logits))
     58       sample = loss.sample(42)
     59       sample = sess.run(sample)
     60       self.assertEqual(sample.shape, (2,))
     61 
     62   def testEvaluateOnTargets(self):
     63     """Ensure log probability can be evaluated correctly."""
     64     with ops.Graph().as_default(), self.test_session() as sess:
     65       logits = np.asarray([
     66           [0., 0., 0.],  #
     67           [1., -1., 0.]
     68       ]).astype(np.float32)
     69       targets = np.asarray([2, 1]).astype(np.int32)
     70       loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(
     71           array_ops.constant(logits), targets=array_ops.constant(targets))
     72       neg_log_prob = loss.evaluate()
     73       neg_log_prob = sess.run(neg_log_prob)
     74 
     75       # Calculate explicit log probability of targets.
     76       probs = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True)
     77       log_probs = np.log([
     78           probs[0, targets[0]],  #
     79           probs[1, targets[1]]
     80       ])
     81       expected_log_prob = np.sum(log_probs)
     82 
     83       self.assertAllClose(neg_log_prob, -expected_log_prob)
     84 
     85   def testEvaluateOnSample(self):
     86     """Ensure log probability of a sample can be drawn."""
     87     with ops.Graph().as_default(), self.test_session() as sess:
     88       logits = np.asarray([
     89           [0., 0., 0.],  #
     90           [1., -1., 0.]
     91       ]).astype(np.float32)
     92       loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(
     93           array_ops.constant(logits))
     94       neg_log_prob = loss.evaluate_on_sample(42)
     95 
     96       # Simply ensure this doesn't crash. As the output is random, it's
     97       # difficult to say if the output is correct or not...
     98       neg_log_prob = sess.run(neg_log_prob)
     99 
    100   def testMultiMinibatchRegistration(self):
    101     """Ensure this loss function supports registering multiple minibatches."""
    102     with ops.Graph().as_default():
    103       tower_logits = []
    104       loss = None
    105       num_towers = 5
    106       for _ in range(num_towers):
    107         logits = random_ops.random_uniform(shape=[2, 3])
    108         tower_logits.append(logits)
    109         if loss is None:
    110           loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits)
    111         else:
    112           loss.register_additional_minibatch(logits)
    113       self.assertListEqual(loss.input_minibatches, tower_logits)
    114       self.assertEqual(loss.num_registered_minibatches, num_towers)
    115 
    116   def testMultiplyFisherSingleVector(self):
    117     with ops.Graph().as_default(), self.test_session() as sess:
    118       logits = np.array([1., 2., 3.])
    119       loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits)
    120 
    121       # the LossFunction.multiply_fisher docstring only says it supports the
    122       # case where the vector is the same shape as the input natural parameters
    123       # (i.e. the logits here), but here we also test leading dimensions
    124       vector = np.array([1., 2., 3.])
    125       vectors = [vector, vector.reshape(1, -1), np.stack([vector] * 4)]
    126 
    127       probs = np.exp(logits - np.logaddexp.reduce(logits))
    128       fisher = np.diag(probs) - np.outer(probs, probs)
    129 
    130       for vector in vectors:
    131         result = loss.multiply_fisher(vector)
    132         expected_result = np.dot(vector, fisher)
    133         self.assertAllClose(expected_result, sess.run(result))
    134 
    135   def testMultiplyFisherBatch(self):
    136     with ops.Graph().as_default(), self.test_session() as sess:
    137       logits = np.array([[1., 2., 3.], [4., 6., 8.]])
    138       loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits)
    139 
    140       vector = np.array([[1., 2., 3.], [5., 3., 1.]])
    141 
    142       na = np.newaxis
    143       probs = np.exp(logits - np.logaddexp.reduce(logits, axis=-1,
    144                                                   keepdims=True))
    145       fishers = probs[..., na] * np.eye(3) - probs[..., na] * probs[..., na, :]
    146 
    147       result = loss.multiply_fisher(vector)
    148       expected_result = np.matmul(vector[..., na, :], fishers)[..., 0, :]
    149       self.assertEqual(sess.run(result).shape, logits.shape)
    150       self.assertAllClose(expected_result, sess.run(result))
    151 
    152 
    153 class OnehotCategoricalLogitsNegativeLogProbLossTest(test.TestCase):
    154 
    155   def testSample(self):
    156     """Ensure samples can be drawn."""
    157     with ops.Graph().as_default(), self.test_session() as sess:
    158       logits = np.asarray([
    159           [0., 0., 0.],  #
    160           [1., -1., 0.]
    161       ]).astype(np.float32)
    162       loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss(
    163           array_ops.constant(logits))
    164       sample = loss.sample(42)
    165       sample = sess.run(sample)
    166       self.assertEqual(sample.shape, (2, 3))
    167 
    168   def testEvaluateOnTargets(self):
    169     """Ensure log probability can be evaluated correctly."""
    170     with ops.Graph().as_default(), self.test_session() as sess:
    171       logits = np.asarray([
    172           [0., 0., 0.],  #
    173           [1., -1., 0.]
    174       ]).astype(np.float32)
    175       targets = np.asarray([2, 1]).astype(np.int32)
    176       loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss(
    177           array_ops.constant(logits), targets=array_ops.one_hot(targets, 3))
    178       neg_log_prob = loss.evaluate()
    179       neg_log_prob = sess.run(neg_log_prob)
    180 
    181       # Calculate explicit log probability of targets.
    182       probs = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True)
    183       log_probs = np.log([
    184           probs[0, targets[0]],  #
    185           probs[1, targets[1]]
    186       ])
    187       expected_log_prob = np.sum(log_probs)
    188 
    189       self.assertAllClose(neg_log_prob, -expected_log_prob)
    190 
    191   def testEvaluateOnSample(self):
    192     """Ensure log probability of a sample can be drawn."""
    193     with ops.Graph().as_default(), self.test_session() as sess:
    194       logits = np.asarray([
    195           [0., 0., 0.],  #
    196           [1., -1., 0.]
    197       ]).astype(np.float32)
    198       loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss(
    199           array_ops.constant(logits))
    200       neg_log_prob = loss.evaluate_on_sample(42)
    201 
    202       # Simply ensure this doesn't crash. As the output is random, it's
    203       # difficult to say if the output is correct or not...
    204       neg_log_prob = sess.run(neg_log_prob)
    205 
    206   def testMultiMinibatchRegistration(self):
    207     """Ensure this loss function supports registering multiple minibatches."""
    208     with ops.Graph().as_default():
    209       tower_logits = []
    210       loss = None
    211       num_towers = 5
    212       for _ in range(num_towers):
    213         logits = random_ops.random_uniform(shape=[2, 3])
    214         tower_logits.append(logits)
    215         if loss is None:
    216           loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss(
    217               logits)
    218         else:
    219           loss.register_additional_minibatch(logits)
    220       self.assertListEqual(loss.input_minibatches, tower_logits)
    221       self.assertEqual(loss.num_registered_minibatches, num_towers)
    222 
    223 
if __name__ == "__main__":
  # When run as a script, hand control to the TensorFlow test runner
  # imported from tensorflow.python.platform.
  test.main()
    226