# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf.contrib.kfac.fisher_blocks."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb
from tensorflow.contrib.kfac.python.ops import layer_collection as lc
from tensorflow.contrib.kfac.python.ops import utils
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import test


def _make_psd(dim):
  """Constructs a PSD matrix of the given dimension."""
  mat = np.ones((dim, dim), dtype=np.float32)
  mat[np.arange(dim), np.arange(dim)] = 2. + np.arange(dim)
  return array_ops.constant(mat)


class UtilsTest(test.TestCase):

  def testComputePiTracenorm(self):
    with ops.Graph().as_default(), self.test_session() as sess:
      random_seed.set_random_seed(200)
      left_factor = array_ops.diag([1., 2., 0., 1.])
      right_factor = array_ops.ones([2, 2])

      # pi is the square root of the ratio of the factors' normalized trace
      # norms (trace / dimension): here ((1 + 2 + 0 + 1) / 4) / ((1 + 1) / 2)
      # equals 1, so pi should be 1.
      pi = fb.compute_pi_tracenorm(left_factor, right_factor)

      pi_val = sess.run(pi)
      self.assertEqual(1., pi_val)


class FullFBTest(test.TestCase):

  def testFullFBInitSingleTensor(self):
    with ops.Graph().as_default():
      random_seed.set_random_seed(200)
      params = array_ops.constant([1., 2.])
      block = fb.FullFB(lc.LayerCollection(), params)
      block.register_additional_minibatch(32)

      self.assertAllEqual(params, block.tensors_to_compute_grads())

  def testFullFBInitTensorTuple(self):
    with ops.Graph().as_default():
      random_seed.set_random_seed(200)
      params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
      block = fb.FullFB(lc.LayerCollection(), params)
      block.register_additional_minibatch(32)

      self.assertAllEqual(params, block.tensors_to_compute_grads())

  def testInstantiateFactors(self):
    with ops.Graph().as_default():
      random_seed.set_random_seed(200)
      params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
      block = fb.FullFB(lc.LayerCollection(), params)
      block.register_additional_minibatch(32)

      grads = (params[0]**2, math_ops.sqrt(params[1]))
      block.instantiate_factors((grads,), 0.5)

  def testMultiplyInverseTuple(self):
    with ops.Graph().as_default(), self.test_session() as sess:
      random_seed.set_random_seed(200)
      params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
      block = fb.FullFB(lc.LayerCollection(), params)
      block.register_additional_minibatch(32)
      grads = (params[0]**2, math_ops.sqrt(params[1]))
      block.instantiate_factors((grads,), 0.5)

      # Make sure our inverse is something other than the identity.
      sess.run(tf_variables.global_variables_initializer())
      sess.run(block._factor.make_inverse_update_ops())

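      # The factor's covariance starts out as the identity, so with damping
      # 0.5 the damped inverse is (I + 0.5 * I)^-1 = (2 / 3) * I; multiplying
      # by it should scale the vector by 2/3.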
      vector = array_ops.ones([3]) * 2
      output = block.multiply_inverse(vector)

      self.assertAllClose(sess.run(vector * 2 / 3.), sess.run(output))

  def testMultiplyInverseNotTuple(self):
    with ops.Graph().as_default(), self.test_session() as sess:
      random_seed.set_random_seed(200)
      params = array_ops.constant([[1.], [2.]])
      block = fb.FullFB(lc.LayerCollection(), params)
      block.register_additional_minibatch(32)
      grads = params**2
      block.instantiate_factors((grads,), 0.5)

      # Make sure our inverse is something other than the identity.
      sess.run(tf_variables.global_variables_initializer())
      sess.run(block._factor.make_inverse_update_ops())

      vector = array_ops.ones([2]) * 2
      output = block.multiply_inverse(vector)

      self.assertAllClose(sess.run(vector * 2 / 3.), sess.run(output))

  def testMultiplyInverseAgainstExplicit(self):
    with ops.Graph().as_default(), self.test_session() as sess:
      random_seed.set_random_seed(200)
      params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
      block = fb.FullFB(lc.LayerCollection(), params)
      block.register_additional_minibatch(32)
      grads = (array_ops.constant([2., 3.]), array_ops.constant(4.))
      damping = 0.5
      block.instantiate_factors((grads,), damping)

      # Make sure our inverse is something other than the identity.
      sess.run(state_ops.assign(block._factor._cov, _make_psd(3)))
      sess.run(block._factor.make_inverse_update_ops())

      v_flat = np.array([4., 5., 6.], dtype=np.float32)
      vector = utils.column_to_tensors(params, array_ops.constant(v_flat))
      output = block.multiply_inverse(vector)
      output_flat = sess.run(utils.tensors_to_column(output)).ravel()

      full = sess.run(block.full_fisher_block())
      explicit = np.dot(np.linalg.inv(full + damping * np.eye(3)), v_flat)

      self.assertAllClose(output_flat, explicit)


class NaiveDiagonalFBTest(test.TestCase):

  def testNaiveDiagonalFBInitSingleTensor(self):
    with ops.Graph().as_default():
      random_seed.set_random_seed(200)
      params = array_ops.constant([1., 2.])
      block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
      block.register_additional_minibatch(32)

      self.assertAllEqual(params, block.tensors_to_compute_grads())

  def testNaiveDiagonalFBInitTensorTuple(self):
    with ops.Graph().as_default():
      random_seed.set_random_seed(200)
      params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
      block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
      block.register_additional_minibatch(32)

      self.assertAllEqual(params, block.tensors_to_compute_grads())

  def testInstantiateFactors(self):
    with ops.Graph().as_default():
      random_seed.set_random_seed(200)
      params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
      block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
      block.register_additional_minibatch(32)

      grads = (params[0]**2, math_ops.sqrt(params[1]))
      block.instantiate_factors((grads,), 0.5)

  def testMultiplyInverseTuple(self):
    with ops.Graph().as_default(), self.test_session() as sess:
      random_seed.set_random_seed(200)
      params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
      block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
      block.register_additional_minibatch(32)
      grads = (params[0]**2, math_ops.sqrt(params[1]))
      block.instantiate_factors((grads,), 0.5)

      # Make sure our inverse is something other than the identity.
      sess.run(tf_variables.global_variables_initializer())
      sess.run(block._factor.make_inverse_update_ops())

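      # As above, the identity covariance plus damping 0.5 gives a damped
      # inverse of (2 / 3) * I, so the output is the vector scaled by 2/3.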
      vector = array_ops.ones([3]) * 2
      output = block.multiply_inverse(vector)

      self.assertAllClose(sess.run(vector * 2 / 3.), sess.run(output))

  def testMultiplyInverseNotTuple(self):
    with ops.Graph().as_default(), self.test_session() as sess:
      random_seed.set_random_seed(200)
      params = array_ops.constant([[1.], [2.]])
      block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
      block.register_additional_minibatch(32)
      grads = params**2
      block.instantiate_factors((grads,), 0.5)

      # Make sure our inverse is something other than the identity.
      sess.run(tf_variables.global_variables_initializer())
      sess.run(block._factor.make_inverse_update_ops())
      vector = array_ops.ones([2]) * 2
      output = block.multiply_inverse(vector)

      self.assertAllClose(sess.run(vector * 2 / 3.), sess.run(output))

  def testMultiplyInverseAgainstExplicit(self):
    with ops.Graph().as_default(), self.test_session() as sess:
      random_seed.set_random_seed(200)
      params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
      block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
      block.register_additional_minibatch(32)
      grads = (params[0]**2, math_ops.sqrt(params[1]))
      damping = 0.5
      block.instantiate_factors((grads,), damping)

      cov = array_ops.reshape(array_ops.constant([2., 3., 4.]), [-1, 1])
      sess.run(state_ops.assign(block._factor._cov, cov))
      sess.run(block._factor.make_inverse_update_ops())

      v_flat = np.array([4., 5., 6.], dtype=np.float32)
      vector = utils.column_to_tensors(params, array_ops.constant(v_flat))
      output = block.multiply_inverse(vector)
      output_flat = sess.run(utils.tensors_to_column(output)).ravel()

      full = sess.run(block.full_fisher_block())
      explicit = np.dot(np.linalg.inv(full + damping * np.eye(3)), v_flat)

      self.assertAllClose(output_flat, explicit)


class FullyConnectedDiagonalFBTest(test.TestCase):

  def setUp(self):
    super(FullyConnectedDiagonalFBTest, self).setUp()

    self.batch_size = 4
    self.input_size = 6
    self.output_size = 3

    self.inputs = np.random.randn(self.batch_size, self.input_size).astype(
        np.float32)
    self.outputs = np.zeros([self.batch_size, self.output_size]).astype(
        np.float32)
    self.output_grads = np.random.randn(self.batch_size,
                                        self.output_size).astype(np.float32)
    self.w = np.random.randn(self.input_size, self.output_size).astype(
        np.float32)
    self.b = np.random.randn(self.output_size).astype(np.float32)

  def fisherApprox(self, has_bias=False):
    """Fisher approximation using default inputs."""
    if has_bias:
      inputs = np.concatenate(
          [self.inputs, np.ones([self.batch_size, 1])], axis=1)
    else:
      inputs = self.inputs
    return self.buildDiagonalFisherApproximation(inputs, self.output_grads)

  def buildDiagonalFisherApproximation(self, inputs, output_grads):
    """Builds explicit diagonal Fisher approximation.

    Fisher's diagonal is E[square(d loss / d w)] for the per-example gradient
      d loss / d w = outer(input, output_grad)

    where the expectation of the elementwise square is taken over examples.

    Args:
      inputs: np.array of shape [batch_size, input_size].
      output_grads: np.array of shape [batch_size, output_size].

    Returns:
      Diagonal np.array of shape [num_params, num_params] for num_params =
      input_size * output_size.
    """
    batch_size = inputs.shape[0]
    assert output_grads.shape[0] == batch_size
    input_size = inputs.shape[1]
    output_size = output_grads.shape[1]
    fisher_diag = np.zeros((input_size, output_size))
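    # Average the elementwise square of each example's gradient,
    # outer(inputs[i], output_grads[i]). Equivalently (vectorized):
    #   np.diag(np.mean(np.square(
    #       np.einsum('bi,bo->bio', inputs, output_grads)), axis=0).flatten())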
    for i in range(batch_size):
      fisher_diag += np.square(np.outer(inputs[i], output_grads[i]))
    return np.diag(fisher_diag.flatten()) / batch_size

  def testMultiply(self):
    result, _ = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs],
                                       [self.output_grads])

    # Construct Fisher-vector product.
    expected_result = self.fisherApprox().dot(self.w.flatten())
    expected_result = expected_result.reshape(
        [self.input_size, self.output_size])

    self.assertAllClose(expected_result, result)

  def testMultiplyInverse(self):
    _, result = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs],
                                       [self.output_grads])

    # Construct inverse Fisher-vector product.
    expected_result = np.linalg.inv(self.fisherApprox()).dot(self.w.flatten())
    expected_result = expected_result.reshape(
        [self.input_size, self.output_size])

    self.assertAllClose(expected_result, result)

  def testRegisterAdditionalMinibatch(self):
    """Ensure 1 big minibatch and 2 small minibatches are equivalent."""
    multiply_result_big, multiply_inverse_result_big = self.runFisherBlockOps(
        self.w, [self.inputs], [self.outputs], [self.output_grads])
    multiply_result_small, multiply_inverse_result_small = (
        self.runFisherBlockOps(self.w, np.split(self.inputs, 2),
                               np.split(self.outputs, 2),
                               np.split(self.output_grads, 2)))

    self.assertAllClose(multiply_result_big, multiply_result_small)
    self.assertAllClose(multiply_inverse_result_big,
                        multiply_inverse_result_small)

  def testMultiplyHasBias(self):
    result, _ = self.runFisherBlockOps((self.w, self.b), [self.inputs],
                                       [self.outputs], [self.output_grads])
    expected_result = self.fisherApprox(True).dot(
        np.concatenate([self.w.flatten(), self.b.flatten()]))
    expected_result = expected_result.reshape(
        [self.input_size + 1, self.output_size])
    expected_result = (expected_result[:-1], expected_result[-1])

    self.assertEqual(len(result), 2)
    self.assertAllClose(expected_result[0], result[0])
    self.assertAllClose(expected_result[1], result[1])

  def runFisherBlockOps(self, params, inputs, outputs, output_grads):
    """Runs the Ops required by the FisherBlock interface.

    Args:
      params: Tensor or 2-tuple of Tensors. Represents weights or weights and
        bias of this layer.
      inputs: list of Tensors of shape [batch_size, input_size]. Inputs to
        layer.
      outputs: list of Tensors of shape [batch_size, output_size].
        Preactivations produced by layer.
      output_grads: list of Tensors of shape [batch_size, output_size].
        Gradient of loss with respect to 'outputs'.

    Returns:
      multiply_result: Result of FisherBlock.multiply(params)
      multiply_inverse_result: Result of FisherBlock.multiply_inverse(params)
    """
    with ops.Graph().as_default(), self.test_session() as sess:
      inputs = as_tensors(inputs)
      outputs = as_tensors(outputs)
      output_grads = as_tensors(output_grads)
      params = as_tensors(params)

      block = fb.FullyConnectedDiagonalFB(
          lc.LayerCollection(), has_bias=isinstance(params, (tuple, list)))
      for (i, o) in zip(inputs, outputs):
        block.register_additional_minibatch(i, o)

      block.instantiate_factors((output_grads,), damping=0.0)

      sess.run(tf_variables.global_variables_initializer())
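      # An ema decay of 0.0 makes the moving-average update adopt the
      # minibatch covariance estimate outright.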
      sess.run(block._factor.make_covariance_update_op(0.0))
      multiply_result = sess.run(block.multiply(params))
      multiply_inverse_result = sess.run(block.multiply_inverse(params))

    return multiply_result, multiply_inverse_result


class EmbeddingKFACFBTest(test.TestCase):

  def testInstantiateFactors(self):
    with ops.Graph().as_default():
      random_seed.set_random_seed(200)

      # Create a Fisher Block.
      vocab_size = 5
      block = fb.EmbeddingKFACFB(lc.LayerCollection(), vocab_size)

      # Add some examples.
      inputs = array_ops.constant([[0, 1], [1, 2], [2, 3]])
      outputs = array_ops.constant([[0.], [1.], [2.]])
      block.register_additional_minibatch(inputs, outputs)

      # Instantiate factor's variables. Ensure it doesn't fail.
      grads = outputs**2.
      damping = array_ops.constant(0.)
      block.instantiate_factors(([grads],), damping)

  def testMultiplyInverse(self):
    with ops.Graph().as_default(), self.test_session() as sess:
      random_seed.set_random_seed(200)

      # Create a Fisher Block.
      vocab_size = 5
      block = fb.EmbeddingKFACFB(lc.LayerCollection(), vocab_size)

      # Add some examples.
      inputs = array_ops.constant([[0, 1], [1, 2], [2, 3]])
      outputs = array_ops.constant([[0.], [1.], [2.]])
      block.register_additional_minibatch(inputs, outputs)

      # Instantiate factor's variables. Ensure it doesn't fail.
      grads = outputs**2.
      damping = array_ops.constant(0.)
      block.instantiate_factors(([grads],), damping)

      # Create a sparse update.
      indices = array_ops.constant([1, 3, 4])
      values = array_ops.constant([[1.], [1.], [1.]])
      sparse_vector = ops.IndexedSlices(
          values, indices, dense_shape=[vocab_size, 1])
      dense_vector = array_ops.reshape([0., 1., 0., 1., 1.], [vocab_size, 1])
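      # Note: dense_vector is the dense counterpart of sparse_vector, with
      # ones at rows 1, 3, and 4 of a [vocab_size, 1] vector.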

      # Compare the inverse-Fisher-vector product against an explicit solve.
      result = block.multiply_inverse(sparse_vector)
      expected_result = linalg_ops.matrix_solve(block.full_fisher_block(),
                                                dense_vector)

      sess.run(tf_variables.global_variables_initializer())
      self.assertAlmostEqual(
          sess.run(expected_result[1]), sess.run(result.values[0]))
      self.assertAlmostEqual(
          sess.run(expected_result[3]), sess.run(result.values[1]))
      self.assertAlmostEqual(
          sess.run(expected_result[4]), sess.run(result.values[2]))


class FullyConnectedKFACBasicFBTest(test.TestCase):

  def testFullyConnectedKFACBasicFBInit(self):
    with ops.Graph().as_default():
      random_seed.set_random_seed(200)
      inputs = array_ops.constant([1., 2.])
      outputs = array_ops.constant([3., 4.])
      block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection())
      block.register_additional_minibatch(inputs, outputs)

      self.assertAllEqual([outputs], block.tensors_to_compute_grads())

  def testInstantiateFactorsHasBias(self):
    with ops.Graph().as_default():
      random_seed.set_random_seed(200)
      inputs = array_ops.constant([[1., 2.], [3., 4.]])
      outputs = array_ops.constant([[3., 4.], [5., 6.]])
      block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=True)
      block.register_additional_minibatch(inputs, outputs)

      grads = outputs**2
      block.instantiate_factors(([grads],), 0.5)

  def testInstantiateFactorsNoBias(self):
    with ops.Graph().as_default():
      random_seed.set_random_seed(200)
      inputs = array_ops.constant([[1., 2.], [3., 4.]])
      outputs = array_ops.constant([[3., 4.], [5., 6.]])
      block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
      block.register_additional_minibatch(inputs, outputs)

      grads = outputs**2
      block.instantiate_factors(([grads],), 0.5)

  def testMultiplyInverseTuple(self):
    with ops.Graph().as_default(), self.test_session() as sess:
      random_seed.set_random_seed(200)
      inputs = array_ops.constant([[1., 2., 3.], [3., 4., 5.], [5., 6., 7.]])
      outputs = array_ops.constant([[3., 4.], [5., 6.]])
      block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
      block.register_additional_minibatch(inputs, outputs)
      grads = outputs**2
      block.instantiate_factors(([grads],), 0.5)

      # Make sure our inverse is something other than the identity.
      sess.run(tf_variables.global_variables_initializer())
      sess.run(block._input_factor.make_inverse_update_ops())
      sess.run(block._output_factor.make_inverse_update_ops())

      vector = (
          np.arange(2, 6).reshape(2, 2).astype(np.float32),  #
          np.arange(1, 3).reshape(2, 1).astype(np.float32))
      output = block.multiply_inverse((array_ops.constant(vector[0]),
                                       array_ops.constant(vector[1])))

      output = sess.run(output)
      self.assertAllClose([[0.686291, 1.029437], [1.372583, 1.715729]],
                          output[0])
      self.assertAllClose([0.343146, 0.686291], output[1])

  def testMultiplyInverseNotTuple(self):
    with ops.Graph().as_default(), self.test_session() as sess:
      random_seed.set_random_seed(200)
      inputs = array_ops.constant([[1., 2.], [3., 4.]])
      outputs = array_ops.constant([[3., 4.], [5., 6.]])
      block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
      block.register_additional_minibatch(inputs, outputs)
      grads = outputs**2
      block.instantiate_factors(([grads],), 0.5)

      # Make sure our inverse is something other than the identity.
      sess.run(tf_variables.global_variables_initializer())
      sess.run(block._input_factor.make_inverse_update_ops())
      sess.run(block._output_factor.make_inverse_update_ops())

      vector = np.arange(2, 6).reshape(2, 2).astype(np.float32)
      output = block.multiply_inverse(array_ops.constant(vector))

      self.assertAllClose([[0.686291, 1.029437], [1.372583, 1.715729]],
                          sess.run(output))

  def testMultiplyInverseAgainstExplicit(self):
    with ops.Graph().as_default(), self.test_session() as sess:
      random_seed.set_random_seed(200)
      input_dim, output_dim = 3, 2
      inputs = array_ops.zeros([32, input_dim])
      outputs = array_ops.zeros([32, output_dim])
      params = array_ops.zeros([input_dim, output_dim])
      block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
      block.register_additional_minibatch(inputs, outputs)
      grads = outputs**2
      damping = 0.  # This test is only valid without damping.
      block.instantiate_factors(([grads],), damping)

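      # Replace the factors' covariances with fixed PSD matrices so that the
      # Kronecker-factored block, and hence its inverse, is nontrivial.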
      sess.run(state_ops.assign(block._input_factor._cov, _make_psd(3)))
      sess.run(state_ops.assign(block._output_factor._cov, _make_psd(2)))
      sess.run(block._input_factor.make_inverse_update_ops())
      sess.run(block._output_factor.make_inverse_update_ops())

      v_flat = np.arange(6, dtype=np.float32)
      vector = utils.column_to_tensors(params, array_ops.constant(v_flat))
      output = block.multiply_inverse(vector)
      output_flat = sess.run(utils.tensors_to_column(output)).ravel()

      full = sess.run(block.full_fisher_block())
      explicit = np.dot(np.linalg.inv(full + damping * np.eye(6)), v_flat)

      self.assertAllClose(output_flat, explicit)


class ConvDiagonalFBTest(test.TestCase):

  def setUp(self):
    super(ConvDiagonalFBTest, self).setUp()

    self.batch_size = 2
    self.height = 8
    self.width = 4
    self.input_channels = 6
    self.output_channels = 3
    self.kernel_size = 1

    self.inputs = np.random.randn(self.batch_size, self.height, self.width,
                                  self.input_channels).astype(np.float32)
    self.outputs = np.zeros(
        [self.batch_size, self.height, self.width,
         self.output_channels]).astype(np.float32)
    self.output_grads = np.random.randn(
        self.batch_size, self.height, self.width, self.output_channels).astype(
            np.float32)
    self.w = np.random.randn(self.kernel_size, self.kernel_size,
                             self.input_channels, self.output_channels).astype(
                                 np.float32)
    self.b = np.random.randn(self.output_channels).astype(np.float32)

  def fisherApprox(self, has_bias=False):
    """Fisher approximation using default inputs."""
    if has_bias:
      inputs = np.concatenate(
          [self.inputs,
           np.ones([self.batch_size, self.height, self.width, 1])],
          axis=-1)
    else:
      inputs = self.inputs
    return self.buildDiagonalFisherApproximation(inputs, self.output_grads,
                                                 self.kernel_size)

  def buildDiagonalFisherApproximation(self, inputs, output_grads, kernel_size):
    r"""Builds explicit diagonal Fisher approximation.

    Fisher's diagonal is E[square(d loss / d w)] for the per-example gradient
      d loss / d w = \sum_{loc} outer(input_{loc}, output_grad_{loc})

    where the sum runs over the (x, y) locations upon which the convolution
    is applied and the expectation of the elementwise square is taken over
    examples.

    Args:
      inputs: np.array of shape [batch_size, height, width, input_channels].
      output_grads: np.array of shape [batch_size, height, width,
        output_channels].
      kernel_size: int. height and width of kernel.

    Returns:
      Diagonal np.array of shape [num_params, num_params] for num_params =
      kernel_size^2 * input_channels * output_channels.
    """
    batch_size, height, width, input_channels = inputs.shape
    assert output_grads.shape[0] == batch_size
    assert output_grads.shape[1] == height
    assert output_grads.shape[2] == width
    output_channels = output_grads.shape[3]

    # If kernel_size == 1, then we don't need to worry about capturing context
    # around the pixel upon which a convolution is applied. This makes testing
    # easier.
    assert kernel_size == 1, "kernel_size != 1 isn't supported."
    num_locations = height * width
    inputs = np.reshape(inputs, [batch_size, num_locations, input_channels])
    output_grads = np.reshape(output_grads,
                              [batch_size, num_locations, output_channels])

    fisher_diag = np.zeros((input_channels, output_channels))
    for i in range(batch_size):
      # Each example's approximation is a square(sum-of-outer-products).
      example_fisher_diag = np.zeros((input_channels, output_channels))
      for j in range(num_locations):
        example_fisher_diag += np.outer(inputs[i, j], output_grads[i, j])
      fisher_diag += np.square(example_fisher_diag)

    # Normalize by batch_size (not num_locations).
    return np.diag(fisher_diag.flatten()) / batch_size

  def testMultiply(self):
    result, _ = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs],
                                       [self.output_grads])

    # Construct Fisher-vector product.
    expected_result = self.fisherApprox().dot(self.w.flatten())
    expected_result = expected_result.reshape([
        self.kernel_size, self.kernel_size, self.input_channels,
        self.output_channels
    ])

    self.assertAllClose(expected_result, result)

  def testMultiplyInverse(self):
    _, result = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs],
                                       [self.output_grads])

    # Construct inverse Fisher-vector product.
    expected_result = np.linalg.inv(self.fisherApprox()).dot(self.w.flatten())
    expected_result = expected_result.reshape([
        self.kernel_size, self.kernel_size, self.input_channels,
        self.output_channels
    ])

    self.assertAllClose(expected_result, result, atol=1e-3)

  def testRegisterAdditionalMinibatch(self):
    """Ensure 1 big minibatch and 2 small minibatches are equivalent."""
    multiply_result_big, multiply_inverse_result_big = self.runFisherBlockOps(
        self.w, [self.inputs], [self.outputs], [self.output_grads])
    multiply_result_small, multiply_inverse_result_small = (
        self.runFisherBlockOps(self.w, np.split(self.inputs, 2),
                               np.split(self.outputs, 2),
                               np.split(self.output_grads, 2)))

    self.assertAllClose(multiply_result_big, multiply_result_small)
    self.assertAllClose(multiply_inverse_result_big,
                        multiply_inverse_result_small)

  def testMultiplyHasBias(self):
    result, _ = self.runFisherBlockOps((self.w, self.b), [self.inputs],
                                       [self.outputs], [self.output_grads])
    # Clone 'b' along 'input_channels' dimension.
    b_filter = np.tile(
        np.reshape(self.b, [1, 1, 1, self.output_channels]),
        [self.kernel_size, self.kernel_size, 1, 1])
    params = np.concatenate([self.w, b_filter], axis=2)
    expected_result = self.fisherApprox(True).dot(params.flatten())

    # Extract 'b' from concatenated parameters.
    expected_result = expected_result.reshape([
        self.kernel_size, self.kernel_size, self.input_channels + 1,
        self.output_channels
    ])
    expected_result = (expected_result[:, :, 0:-1, :],
                       np.reshape(expected_result[:, :, -1, :],
                                  [self.output_channels]))

    self.assertEqual(len(result), 2)
    self.assertAllClose(expected_result[0], result[0])
    self.assertAllClose(expected_result[1], result[1])

  def runFisherBlockOps(self, params, inputs, outputs, output_grads):
    """Runs the Ops required by the FisherBlock interface.

    Args:
      params: Tensor or 2-tuple of Tensors. Represents weights or weights and
        bias of this layer.
      inputs: list of Tensors of shape [batch_size, input_size]. Inputs to
        layer.
      outputs: list of Tensors of shape [batch_size, output_size].
        Preactivations produced by layer.
      output_grads: list of Tensors of shape [batch_size, output_size].
        Gradient of loss with respect to 'outputs'.

    Returns:
      multiply_result: Result of FisherBlock.multiply(params)
      multiply_inverse_result: Result of FisherBlock.multiply_inverse(params)
    """
    with ops.Graph().as_default(), self.test_session() as sess:
      inputs = as_tensors(inputs)
      outputs = as_tensors(outputs)
      output_grads = as_tensors(output_grads)
      params = as_tensors(params)

      block = fb.ConvDiagonalFB(
          lc.LayerCollection(), params, strides=[1, 1, 1, 1], padding='SAME')
      for (i, o) in zip(inputs, outputs):
        block.register_additional_minibatch(i, o)

      block.instantiate_factors((output_grads,), damping=0.0)

      sess.run(tf_variables.global_variables_initializer())
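      # As in the fully connected case, an ema decay of 0.0 adopts the
      # minibatch covariance estimate outright.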
      sess.run(block._factor.make_covariance_update_op(0.0))
      multiply_result = sess.run(block.multiply(params))
      multiply_inverse_result = sess.run(block.multiply_inverse(params))

    return multiply_result, multiply_inverse_result


class ConvKFCBasicFBTest(test.TestCase):

  def _testConvKFCBasicFBInitParams(self, params):
    with ops.Graph().as_default():
      random_seed.set_random_seed(200)
      if isinstance(params, (list, tuple)):
        params = [array_ops.constant(param) for param in params]
      else:
        params = array_ops.constant(params)
      inputs = random_ops.random_normal((2, 2, 2))
      outputs = random_ops.random_normal((2, 2, 2))
      block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, [1, 1, 1], 'SAME')
      block.register_additional_minibatch(inputs, outputs)

      self.assertAllEqual([outputs], block.tensors_to_compute_grads())

  def testConvKFCBasicFBInitParamsParamsTuple(self):
    self._testConvKFCBasicFBInitParams([np.array([1., 2.]), np.array(3.)])

  def testConvKFCBasicFBInitParamsParamsSingle(self):
    self._testConvKFCBasicFBInitParams([np.array([1., 2.])])

  def testMultiplyInverseTuple(self):
    with ops.Graph().as_default(), self.test_session() as sess:
      random_seed.set_random_seed(200)
      params = random_ops.random_normal((2, 2, 2, 2))
      inputs = random_ops.random_normal((2, 2, 2, 2))
      outputs = random_ops.random_normal((2, 2, 2, 2))
      block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1),
                                'SAME')
      block.register_additional_minibatch(inputs, outputs)
      grads = outputs**2
      block.instantiate_factors(([grads],), 0.5)

      # Make sure our inverse is something other than the identity.
      sess.run(tf_variables.global_variables_initializer())
      sess.run(block._input_factor.make_inverse_update_ops())
      sess.run(block._output_factor.make_inverse_update_ops())

      vector = (np.arange(1, 15).reshape(7, 2).astype(np.float32),
                np.arange(2, 4).reshape(2, 1).astype(np.float32))
      output = block.multiply_inverse((array_ops.constant(vector[0]),
                                       array_ops.constant(vector[1])))

      output = sess.run(output)
      self.assertAllClose([0.136455, 0.27291], output[0][0])
      self.assertAllClose([0.27291, 0.409365], output[1])

  def testMultiplyInverseNotTuple(self):
    with ops.Graph().as_default(), self.test_session() as sess:
      random_seed.set_random_seed(200)
      params = random_ops.random_normal((2, 2, 2, 2))
      inputs = random_ops.random_normal((2, 2, 2, 2))
      outputs = random_ops.random_normal((2, 2, 2, 2))
      block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1),
                                'SAME')
      block.register_additional_minibatch(inputs, outputs)
      self.assertFalse(block._has_bias)
      grads = outputs**2
      block.instantiate_factors(([grads],), 0.5)

      # Make sure our inverse is something other than the identity.
      sess.run(tf_variables.global_variables_initializer())
      sess.run(block._input_factor.make_inverse_update_ops())
      sess.run(block._output_factor.make_inverse_update_ops())

      vector = np.arange(1, 17).reshape(8, 2).astype(np.float32)
      output = block.multiply_inverse(array_ops.constant(vector))

      self.assertAllClose([0.136455, 0.27291], sess.run(output)[0])

  def testMultiplyInverseNotTupleWithBias(self):
    with ops.Graph().as_default(), self.test_session() as sess:
      random_seed.set_random_seed(200)
      params = [random_ops.random_normal((2, 2, 2, 2))]
      inputs = random_ops.random_normal((2, 2, 2, 2))
      outputs = random_ops.random_normal((2, 2, 2, 2))
      block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1),
                                'SAME')
      block.register_additional_minibatch(inputs, outputs)
      self.assertTrue(block._has_bias)
      grads = outputs**2
      block.instantiate_factors(([grads],), 0.5)

      # Make sure our inverse is something other than the identity.
      sess.run(tf_variables.global_variables_initializer())
      sess.run(block._input_factor.make_inverse_update_ops())
      sess.run(block._output_factor.make_inverse_update_ops())

      vector = np.arange(1, 19).reshape(9, 2).astype(np.float32)
      output = block.multiply_inverse(array_ops.constant(vector))

      self.assertAllClose([0.136455, 0.27291], sess.run(output)[0])

  def testMultiplyInverseAgainstExplicit(self):
    with ops.Graph().as_default(), self.test_session() as sess:
      random_seed.set_random_seed(200)
      params = array_ops.zeros((2, 2, 2, 2))
      inputs = array_ops.zeros((2, 2, 2, 2))
      outputs = array_ops.zeros((2, 2, 2, 2))
      block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1),
                                'SAME')
      block.register_additional_minibatch(inputs, outputs)
      grads = outputs**2
      damping = 0.  # This test is only valid without damping.
      block.instantiate_factors(([grads],), damping)

      sess.run(state_ops.assign(block._input_factor._cov, _make_psd(8)))
      sess.run(state_ops.assign(block._output_factor._cov, _make_psd(2)))
      sess.run(block._input_factor.make_inverse_update_ops())
      sess.run(block._output_factor.make_inverse_update_ops())

      v_flat = np.arange(16, dtype=np.float32)
      vector = utils.column_to_tensors(params, array_ops.constant(v_flat))
      output = block.multiply_inverse(vector)
      output_flat = sess.run(utils.tensors_to_column(output)).ravel()

      full = sess.run(block.full_fisher_block())
      explicit = np.dot(np.linalg.inv(full + damping * np.eye(16)), v_flat)

      self.assertAllClose(output_flat, explicit)


class FullyConnectedSeriesFBTest(test.TestCase):

  def testFullyConnectedSeriesFBInit(self):
    with ops.Graph().as_default():
      random_seed.set_random_seed(200)
      inputs = array_ops.constant([1., 2.])
      outputs = array_ops.constant([3., 4.])
      block = fb.FullyConnectedSeriesFB(
          lc.LayerCollection(), inputs=[inputs], outputs=[outputs])
      self.assertAllEqual([outputs], block.tensors_to_compute_grads())

  def testInstantiateFactorsHasBias(self):
    with ops.Graph().as_default():
      random_seed.set_random_seed(200)
      inputs = array_ops.constant([[1., 2.], [3., 4.]])
      outputs = array_ops.constant([[3., 4.], [5., 6.]])
      block = fb.FullyConnectedSeriesFB(
          lc.LayerCollection(),
          inputs=[inputs],
          outputs=[outputs],
          has_bias=True)
      grads = outputs**2
      block.instantiate_factors(((grads,),), 0.5)

  def testInstantiateFactorsNoBias(self):
    with ops.Graph().as_default():
      random_seed.set_random_seed(200)
      inputs = array_ops.constant([[1., 2.], [3., 4.]])
      outputs = array_ops.constant([[3., 4.], [5., 6.]])
      block = fb.FullyConnectedSeriesFB(
          lc.LayerCollection(),
          inputs=[inputs],
          outputs=[outputs],
          has_bias=False)
      grads = outputs**2
      block.instantiate_factors(((grads,),), 0.5)


def as_tensors(tensor_or_tuple):
  """Converts a potentially nested tuple of np.array to Tensors."""
  if isinstance(tensor_or_tuple, (tuple, list)):
    return tuple(as_tensors(t) for t in tensor_or_tuple)
  return ops.convert_to_tensor(tensor_or_tuple)


if __name__ == '__main__':
  test.main()