1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for tf.contrib.kfac.loss_functions.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import numpy as np 22 23 from tensorflow.contrib.kfac.python.ops import loss_functions 24 from tensorflow.python.framework import constant_op 25 from tensorflow.python.framework import ops 26 from tensorflow.python.ops import array_ops 27 from tensorflow.python.ops import random_ops 28 from tensorflow.python.platform import test 29 30 31 class InsertSliceInZerosTest(test.TestCase): 32 33 def testBadShape(self): 34 bad_shaped_ones = array_ops.ones(shape=[1, 3]) # n.b. shape[1] != 1 35 with self.assertRaises(ValueError): 36 loss_functions.insert_slice_in_zeros(bad_shaped_ones, 1, 42, 17) 37 38 def test3d(self): 39 input_tensor = constant_op.constant([[[1, 2]], [[3, 4]]]) 40 expected_output_array = [[[1, 2], [0, 0]], [[3, 4], [0, 0]]] 41 op = loss_functions.insert_slice_in_zeros(input_tensor, 1, 2, 0) 42 with self.test_session() as sess: 43 actual_output_array = sess.run(op) 44 self.assertAllEqual(expected_output_array, actual_output_array) 45 46 47 class CategoricalLogitsNegativeLogProbLossTest(test.TestCase): 48 49 def testSample(self): 50 """Ensure samples can be drawn.""" 51 with ops.Graph().as_default(), self.test_session() as sess: 52 logits = np.asarray([ 53 [0., 0., 0.], # 54 [1., -1., 0.] 55 ]).astype(np.float32) 56 loss = loss_functions.CategoricalLogitsNegativeLogProbLoss( 57 array_ops.constant(logits)) 58 sample = loss.sample(42) 59 sample = sess.run(sample) 60 self.assertEqual(sample.shape, (2,)) 61 62 def testEvaluateOnTargets(self): 63 """Ensure log probability can be evaluated correctly.""" 64 with ops.Graph().as_default(), self.test_session() as sess: 65 logits = np.asarray([ 66 [0., 0., 0.], # 67 [1., -1., 0.] 68 ]).astype(np.float32) 69 targets = np.asarray([2, 1]).astype(np.int32) 70 loss = loss_functions.CategoricalLogitsNegativeLogProbLoss( 71 array_ops.constant(logits), targets=array_ops.constant(targets)) 72 neg_log_prob = loss.evaluate() 73 neg_log_prob = sess.run(neg_log_prob) 74 75 # Calculate explicit log probability of targets. 76 probs = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True) 77 log_probs = np.log([ 78 probs[0, targets[0]], # 79 probs[1, targets[1]] 80 ]) 81 expected_log_prob = np.sum(log_probs) 82 83 self.assertAllClose(neg_log_prob, -expected_log_prob) 84 85 def testEvaluateOnSample(self): 86 """Ensure log probability of a sample can be drawn.""" 87 with ops.Graph().as_default(), self.test_session() as sess: 88 logits = np.asarray([ 89 [0., 0., 0.], # 90 [1., -1., 0.] 91 ]).astype(np.float32) 92 loss = loss_functions.CategoricalLogitsNegativeLogProbLoss( 93 array_ops.constant(logits)) 94 neg_log_prob = loss.evaluate_on_sample(42) 95 96 # Simply ensure this doesn't crash. As the output is random, it's 97 # difficult to say if the output is correct or not... 98 neg_log_prob = sess.run(neg_log_prob) 99 100 def testMultiMinibatchRegistration(self): 101 """Ensure this loss function supports registering multiple minibatches.""" 102 with ops.Graph().as_default(): 103 tower_logits = [] 104 loss = None 105 num_towers = 5 106 for _ in range(num_towers): 107 logits = random_ops.random_uniform(shape=[2, 3]) 108 tower_logits.append(logits) 109 if loss is None: 110 loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits) 111 else: 112 loss.register_additional_minibatch(logits) 113 self.assertListEqual(loss.input_minibatches, tower_logits) 114 self.assertEqual(loss.num_registered_minibatches, num_towers) 115 116 def testMultiplyFisherSingleVector(self): 117 with ops.Graph().as_default(), self.test_session() as sess: 118 logits = np.array([1., 2., 3.]) 119 loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits) 120 121 # the LossFunction.multiply_fisher docstring only says it supports the 122 # case where the vector is the same shape as the input natural parameters 123 # (i.e. the logits here), but here we also test leading dimensions 124 vector = np.array([1., 2., 3.]) 125 vectors = [vector, vector.reshape(1, -1), np.stack([vector] * 4)] 126 127 probs = np.exp(logits - np.logaddexp.reduce(logits)) 128 fisher = np.diag(probs) - np.outer(probs, probs) 129 130 for vector in vectors: 131 result = loss.multiply_fisher(vector) 132 expected_result = np.dot(vector, fisher) 133 self.assertAllClose(expected_result, sess.run(result)) 134 135 def testMultiplyFisherBatch(self): 136 with ops.Graph().as_default(), self.test_session() as sess: 137 logits = np.array([[1., 2., 3.], [4., 6., 8.]]) 138 loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits) 139 140 vector = np.array([[1., 2., 3.], [5., 3., 1.]]) 141 142 na = np.newaxis 143 probs = np.exp(logits - np.logaddexp.reduce(logits, axis=-1, 144 keepdims=True)) 145 fishers = probs[..., na] * np.eye(3) - probs[..., na] * probs[..., na, :] 146 147 result = loss.multiply_fisher(vector) 148 expected_result = np.matmul(vector[..., na, :], fishers)[..., 0, :] 149 self.assertEqual(sess.run(result).shape, logits.shape) 150 self.assertAllClose(expected_result, sess.run(result)) 151 152 153 class OnehotCategoricalLogitsNegativeLogProbLossTest(test.TestCase): 154 155 def testSample(self): 156 """Ensure samples can be drawn.""" 157 with ops.Graph().as_default(), self.test_session() as sess: 158 logits = np.asarray([ 159 [0., 0., 0.], # 160 [1., -1., 0.] 161 ]).astype(np.float32) 162 loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss( 163 array_ops.constant(logits)) 164 sample = loss.sample(42) 165 sample = sess.run(sample) 166 self.assertEqual(sample.shape, (2, 3)) 167 168 def testEvaluateOnTargets(self): 169 """Ensure log probability can be evaluated correctly.""" 170 with ops.Graph().as_default(), self.test_session() as sess: 171 logits = np.asarray([ 172 [0., 0., 0.], # 173 [1., -1., 0.] 174 ]).astype(np.float32) 175 targets = np.asarray([2, 1]).astype(np.int32) 176 loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss( 177 array_ops.constant(logits), targets=array_ops.one_hot(targets, 3)) 178 neg_log_prob = loss.evaluate() 179 neg_log_prob = sess.run(neg_log_prob) 180 181 # Calculate explicit log probability of targets. 182 probs = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True) 183 log_probs = np.log([ 184 probs[0, targets[0]], # 185 probs[1, targets[1]] 186 ]) 187 expected_log_prob = np.sum(log_probs) 188 189 self.assertAllClose(neg_log_prob, -expected_log_prob) 190 191 def testEvaluateOnSample(self): 192 """Ensure log probability of a sample can be drawn.""" 193 with ops.Graph().as_default(), self.test_session() as sess: 194 logits = np.asarray([ 195 [0., 0., 0.], # 196 [1., -1., 0.] 197 ]).astype(np.float32) 198 loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss( 199 array_ops.constant(logits)) 200 neg_log_prob = loss.evaluate_on_sample(42) 201 202 # Simply ensure this doesn't crash. As the output is random, it's 203 # difficult to say if the output is correct or not... 204 neg_log_prob = sess.run(neg_log_prob) 205 206 def testMultiMinibatchRegistration(self): 207 """Ensure this loss function supports registering multiple minibatches.""" 208 with ops.Graph().as_default(): 209 tower_logits = [] 210 loss = None 211 num_towers = 5 212 for _ in range(num_towers): 213 logits = random_ops.random_uniform(shape=[2, 3]) 214 tower_logits.append(logits) 215 if loss is None: 216 loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss( 217 logits) 218 else: 219 loss.register_additional_minibatch(logits) 220 self.assertListEqual(loss.input_minibatches, tower_logits) 221 self.assertEqual(loss.num_registered_minibatches, num_towers) 222 223 224 if __name__ == "__main__": 225 test.main() 226