# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf.contrib.kfac.fisher_blocks."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb
from tensorflow.contrib.kfac.python.ops import layer_collection as lc
from tensorflow.contrib.kfac.python.ops import utils
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import test


def _make_psd(dim):
  """Constructs a PSD matrix of the given dimension."""
  mat = np.ones((dim, dim), dtype=np.float32)
  mat[np.arange(dim), np.arange(dim)] = 2. + np.arange(dim)
  return array_ops.constant(mat)


class UtilsTest(test.TestCase):

  def testComputePiTracenorm(self):
    with ops.Graph().as_default(), self.test_session() as sess:
      random_seed.set_random_seed(200)
      left_factor = array_ops.diag([1., 2., 0., 1.])
      right_factor = array_ops.ones([2., 2.])

      # pi is the sqrt of the left trace norm divided by the right trace norm.
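      # As a worked check (assuming each trace norm is normalized by its
      # dimension): trace(left) / 4 = 1. and trace(right) / 2 = 1., so the
      # assertion below expects pi = sqrt(1. / 1.) = 1.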
      pi = fb.compute_pi_tracenorm(left_factor, right_factor)

      pi_val = sess.run(pi)
      self.assertEqual(1., pi_val)


class FullFBTest(test.TestCase):

  def testFullFBInitSingleTensor(self):
    with ops.Graph().as_default():
      random_seed.set_random_seed(200)
      params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
      block = fb.FullFB(lc.LayerCollection(), params)
      block.register_additional_minibatch(32)

      self.assertAllEqual(params, block.tensors_to_compute_grads())

  def testFullFBInitTensorTuple(self):
    with ops.Graph().as_default():
      random_seed.set_random_seed(200)
      params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
      block = fb.FullFB(lc.LayerCollection(), params)
      block.register_additional_minibatch(32)

      self.assertAllEqual(params, block.tensors_to_compute_grads())

  def testInstantiateFactors(self):
    with ops.Graph().as_default():
      random_seed.set_random_seed(200)
      params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
      block = fb.FullFB(lc.LayerCollection(), params)
      block.register_additional_minibatch(32)

      grads = (params[0]**2, math_ops.sqrt(params[1]))
      block.instantiate_factors(grads, 0.5)

  def testMultiplyInverseTuple(self):
    with ops.Graph().as_default(), self.test_session() as sess:
      random_seed.set_random_seed(200)
      params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
      block = fb.FullFB(lc.LayerCollection(), params)
      block.register_additional_minibatch(32)
      grads = (params[0]**2, math_ops.sqrt(params[1]))
      block.instantiate_factors((grads,), 0.5)

      # Make sure our inverse is something other than the identity.
      sess.run(tf_variables.global_variables_initializer())
      sess.run(block._factor.make_inverse_update_ops())
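
      # With the covariance still at its freshly initialized value, the damped
      # inverse should just scale a vector by 1 / (1 + damping) = 2 / 3, which
      # is what the assertion below checks.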
      vector = array_ops.ones(3,) * 2
      output = block.multiply_inverse(vector)

      self.assertAllClose(sess.run(vector * 2 / 3.), sess.run(output))

  def testMultiplyInverseNotTuple(self):
    with ops.Graph().as_default(), self.test_session() as sess:
      random_seed.set_random_seed(200)
      params = array_ops.constant([[1.], [2.]])
      block = fb.FullFB(lc.LayerCollection(), params)
      block.register_additional_minibatch(32)
      grads = params**2
      block.instantiate_factors((grads,), 0.5)

      # Make sure our inverse is something other than the identity.
      sess.run(tf_variables.global_variables_initializer())
      sess.run(block._factor.make_inverse_update_ops())

      vector = array_ops.ones(2,) * 2
      output = block.multiply_inverse(vector)

      self.assertAllClose(sess.run(vector * 2 / 3.), sess.run(output))

  def testMultiplyInverseAgainstExplicit(self):
    with ops.Graph().as_default(), self.test_session() as sess:
      random_seed.set_random_seed(200)
      params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
      block = fb.FullFB(lc.LayerCollection(), params)
      block.register_additional_minibatch(32)
      grads = (array_ops.constant([2., 3.]), array_ops.constant(4.))
      damping = 0.5
      block.instantiate_factors((grads,), damping)

      # Make sure our inverse is something other than the identity.
      sess.run(state_ops.assign(block._factor._cov, _make_psd(3)))
      sess.run(block._factor.make_inverse_update_ops())

      v_flat = np.array([4., 5., 6.], dtype=np.float32)
      vector = utils.column_to_tensors(params, array_ops.constant(v_flat))
      output = block.multiply_inverse(vector)
      output_flat = sess.run(utils.tensors_to_column(output)).ravel()

      full = sess.run(block.full_fisher_block())
      explicit = np.dot(np.linalg.inv(full + damping * np.eye(3)), v_flat)

      self.assertAllClose(output_flat, explicit)


class NaiveDiagonalFBTest(test.TestCase):

  def testNaiveDiagonalFBInitSingleTensor(self):
    with ops.Graph().as_default():
      random_seed.set_random_seed(200)
      params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
      block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
      block.register_additional_minibatch(32)

      self.assertAllEqual(params, block.tensors_to_compute_grads())

  def testNaiveDiagonalFBInitTensorTuple(self):
    with ops.Graph().as_default():
      random_seed.set_random_seed(200)
      params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
      block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
      block.register_additional_minibatch(32)

      self.assertAllEqual(params, block.tensors_to_compute_grads())

  def testInstantiateFactors(self):
    with ops.Graph().as_default():
      random_seed.set_random_seed(200)
      params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
      block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
      block.register_additional_minibatch(32)

      grads = (params[0]**2, math_ops.sqrt(params[1]))
      block.instantiate_factors(grads, 0.5)

  def testMultiplyInverseTuple(self):
    with ops.Graph().as_default(), self.test_session() as sess:
      random_seed.set_random_seed(200)
      params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
      block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
      block.register_additional_minibatch(32)
      grads = (params[0]**2, math_ops.sqrt(params[1]))
      block.instantiate_factors((grads,), 0.5)

      # Make sure our inverse is something other than the identity.
      sess.run(tf_variables.global_variables_initializer())
      sess.run(block._factor.make_inverse_update_ops())

      vector = array_ops.ones(3,) * 2
      output = block.multiply_inverse(vector)

      self.assertAllClose(sess.run(vector * 2 / 3.), sess.run(output))

  def testMultiplyInverseNotTuple(self):
    with ops.Graph().as_default(), self.test_session() as sess:
      random_seed.set_random_seed(200)
      params = array_ops.constant([[1.], [2.]])
      block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
      block.register_additional_minibatch(32)
      grads = params**2
      block.instantiate_factors((grads,), 0.5)

      # Make sure our inverse is something other than the identity.
      sess.run(tf_variables.global_variables_initializer())
      sess.run(block._factor.make_inverse_update_ops())
      vector = array_ops.ones(2,) * 2
      output = block.multiply_inverse(vector)

      self.assertAllClose(sess.run(vector * 2 / 3.), sess.run(output))

  def testMultiplyInverseAgainstExplicit(self):
    with ops.Graph().as_default(), self.test_session() as sess:
      random_seed.set_random_seed(200)
      params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
      block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
      block.register_additional_minibatch(32)
      grads = (params[0]**2, math_ops.sqrt(params[1]))
      damping = 0.5
      block.instantiate_factors((grads,), damping)

      cov = array_ops.reshape(array_ops.constant([2., 3., 4.]), [-1, 1])
      sess.run(state_ops.assign(block._factor._cov, cov))
      sess.run(block._factor.make_inverse_update_ops())

      v_flat = np.array([4., 5., 6.], dtype=np.float32)
      vector = utils.column_to_tensors(params, array_ops.constant(v_flat))
      output = block.multiply_inverse(vector)
      output_flat = sess.run(utils.tensors_to_column(output)).ravel()

      full = sess.run(block.full_fisher_block())
      explicit = np.dot(np.linalg.inv(full + damping * np.eye(3)), v_flat)

      self.assertAllClose(output_flat, explicit)


class FullyConnectedDiagonalFBTest(test.TestCase):

  def setUp(self):
    super(FullyConnectedDiagonalFBTest, self).setUp()

    self.batch_size = 4
    self.input_size = 6
    self.output_size = 3

    self.inputs = np.random.randn(self.batch_size, self.input_size).astype(
        np.float32)
    self.outputs = np.zeros([self.batch_size, self.output_size]).astype(
        np.float32)
    self.output_grads = np.random.randn(self.batch_size,
                                        self.output_size).astype(np.float32)
    self.w = np.random.randn(self.input_size, self.output_size).astype(
        np.float32)
    self.b = np.random.randn(self.output_size).astype(np.float32)

  def fisherApprox(self, has_bias=False):
    """Fisher approximation using default inputs."""
    if has_bias:
      inputs = np.concatenate(
          [self.inputs, np.ones([self.batch_size, 1])], axis=1)
    else:
      inputs = self.inputs
    return self.buildDiagonalFisherApproximation(inputs, self.output_grads)

  def buildDiagonalFisherApproximation(self, inputs, output_grads):
    """Builds explicit diagonal Fisher approximation.

    Fisher's diagonal is the elementwise square of the per-example gradient

      d loss / d w = outer(input, output_grad),

    averaged over examples.

    Args:
      inputs: np.array of shape [batch_size, input_size].
      output_grads: np.array of shape [batch_size, output_size].

    Returns:
      Diagonal np.array of shape [num_params, num_params] for num_params =
      input_size * output_size.
    """
    batch_size = inputs.shape[0]
    assert output_grads.shape[0] == batch_size
    input_size = inputs.shape[1]
    output_size = output_grads.shape[1]
    fisher_diag = np.zeros((input_size, output_size))
    for i in range(batch_size):
      fisher_diag += np.square(np.outer(inputs[i], output_grads[i]))
    return np.diag(fisher_diag.flatten()) / batch_size
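
  # An equivalent vectorized form of the loop above, for reference only (the
  # tests use the explicit loop):
  #
  #   sq = np.square(np.einsum('bi,bo->bio', inputs, output_grads))
  #   return np.diag(sq.mean(axis=0).flatten())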

  def testMultiply(self):
    result, _ = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs],
                                       [self.output_grads])

    # Construct Fisher-vector product.
    expected_result = self.fisherApprox().dot(self.w.flatten())
    expected_result = expected_result.reshape(
        [self.input_size, self.output_size])

    self.assertAllClose(expected_result, result)

  def testMultiplyInverse(self):
    _, result = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs],
                                       [self.output_grads])

    # Construct inverse Fisher-vector product.
    expected_result = np.linalg.inv(self.fisherApprox()).dot(self.w.flatten())
    expected_result = expected_result.reshape(
        [self.input_size, self.output_size])

    self.assertAllClose(expected_result, result)

  def testRegisterAdditionalMinibatch(self):
    """Ensure 1 big minibatch and 2 small minibatches are equivalent."""
    multiply_result_big, multiply_inverse_result_big = self.runFisherBlockOps(
        self.w, [self.inputs], [self.outputs], [self.output_grads])
    multiply_result_small, multiply_inverse_result_small = (
        self.runFisherBlockOps(self.w, np.split(self.inputs, 2),
                               np.split(self.outputs, 2),
                               np.split(self.output_grads, 2)))

    self.assertAllClose(multiply_result_big, multiply_result_small)
    self.assertAllClose(multiply_inverse_result_big,
                        multiply_inverse_result_small)

  def testMultiplyHasBias(self):
    result, _ = self.runFisherBlockOps((self.w, self.b), [self.inputs],
                                       [self.outputs], [self.output_grads])
    expected_result = self.fisherApprox(True).dot(
        np.concatenate([self.w.flatten(), self.b.flatten()]))
    expected_result = expected_result.reshape(
        [self.input_size + 1, self.output_size])
    expected_result = (expected_result[:-1], expected_result[-1])

    self.assertEqual(len(result), 2)
    self.assertAllClose(expected_result[0], result[0])
    self.assertAllClose(expected_result[1], result[1])

  def runFisherBlockOps(self, params, inputs, outputs, output_grads):
    """Run Ops guaranteed by FisherBlock interface.

    Args:
      params: Tensor or 2-tuple of Tensors. Represents weights or weights and
        bias of this layer.
      inputs: list of Tensors of shape [batch_size, input_size]. Inputs to
        layer.
      outputs: list of Tensors of shape [batch_size, output_size].
        Preactivations produced by layer.
      output_grads: list of Tensors of shape [batch_size, output_size].
        Gradient of loss with respect to 'outputs'.

    Returns:
      multiply_result: Result of FisherBlock.multiply(params)
      multiply_inverse_result: Result of FisherBlock.multiply_inverse(params)
    """
    with ops.Graph().as_default(), self.test_session() as sess:
      inputs = as_tensors(inputs)
      outputs = as_tensors(outputs)
      output_grads = as_tensors(output_grads)
      params = as_tensors(params)

      block = fb.FullyConnectedDiagonalFB(
          lc.LayerCollection(), has_bias=isinstance(params, (tuple, list)))
      for (i, o) in zip(inputs, outputs):
        block.register_additional_minibatch(i, o)

      block.instantiate_factors((output_grads,), damping=0.0)

      sess.run(tf_variables.global_variables_initializer())
      sess.run(block._factor.make_covariance_update_op(0.0))
      multiply_result = sess.run(block.multiply(params))
      multiply_inverse_result = sess.run(block.multiply_inverse(params))

    return multiply_result, multiply_inverse_result


class EmbeddingKFACFBTest(test.TestCase):

  def testInstantiateFactors(self):
    with ops.Graph().as_default():
      random_seed.set_random_seed(200)

      # Create a Fisher Block.
      vocab_size = 5
      block = fb.EmbeddingKFACFB(lc.LayerCollection(), vocab_size)

      # Add some examples.
      inputs = array_ops.constant([[0, 1], [1, 2], [2, 3]])
      outputs = array_ops.constant([[0.], [1.], [2.]])
      block.register_additional_minibatch(inputs, outputs)

      # Instantiate factor's variables. Ensure it doesn't fail.
      grads = outputs**2.
      damping = array_ops.constant(0.)
      block.instantiate_factors(([grads],), damping)

  def testMultiplyInverse(self):
    with ops.Graph().as_default(), self.test_session() as sess:
      random_seed.set_random_seed(200)

      # Create a Fisher Block.
      vocab_size = 5
      block = fb.EmbeddingKFACFB(lc.LayerCollection(), vocab_size)

      # Add some examples.
      inputs = array_ops.constant([[0, 1], [1, 2], [2, 3]])
      outputs = array_ops.constant([[0.], [1.], [2.]])
      block.register_additional_minibatch(inputs, outputs)

      # Instantiate factor's variables. Ensure it doesn't fail.
      grads = outputs**2.
      damping = array_ops.constant(0.)
      block.instantiate_factors(([grads],), damping)

      # Create a sparse update.
      indices = array_ops.constant([1, 3, 4])
      values = array_ops.constant([[1.], [1.], [1.]])
      sparse_vector = ops.IndexedSlices(
          values, indices, dense_shape=[vocab_size, 1])
      dense_vector = array_ops.reshape([0., 1., 0., 1., 1.], [vocab_size, 1])

      # Compare Fisher-vector product against explicit result.
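      # multiply_inverse() preserves the IndexedSlices structure, so it only
      # produces rows for the indices in the sparse update (1, 3, 4); each row
      # is compared against the matching row of the dense solve below.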
      result = block.multiply_inverse(sparse_vector)
      expected_result = linalg_ops.matrix_solve(block.full_fisher_block(),
                                                dense_vector)

      sess.run(tf_variables.global_variables_initializer())
      self.assertAlmostEqual(
          sess.run(expected_result[1]), sess.run(result.values[0]))
      self.assertAlmostEqual(
          sess.run(expected_result[3]), sess.run(result.values[1]))
      self.assertAlmostEqual(
          sess.run(expected_result[4]), sess.run(result.values[2]))


class FullyConnectedKFACBasicFBTest(test.TestCase):

  def testFullyConnectedKFACBasicFBInit(self):
    with ops.Graph().as_default():
      random_seed.set_random_seed(200)
      inputs = array_ops.constant([1., 2.])
      outputs = array_ops.constant([3., 4.])
      block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection())
      block.register_additional_minibatch(inputs, outputs)

      self.assertAllEqual([outputs], block.tensors_to_compute_grads())

  def testInstantiateFactorsHasBias(self):
    with ops.Graph().as_default():
      random_seed.set_random_seed(200)
      inputs = array_ops.constant([[1., 2.], [3., 4.]])
      outputs = array_ops.constant([[3., 4.], [5., 6.]])
      block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=True)
      block.register_additional_minibatch(inputs, outputs)

      grads = outputs**2
      block.instantiate_factors(([grads],), 0.5)

  def testInstantiateFactorsNoBias(self):
    with ops.Graph().as_default():
      random_seed.set_random_seed(200)
      inputs = array_ops.constant([[1., 2.], [3., 4.]])
      outputs = array_ops.constant([[3., 4.], [5., 6.]])
      block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
      block.register_additional_minibatch(inputs, outputs)

      grads = outputs**2
      block.instantiate_factors(([grads],), 0.5)

  def testMultiplyInverseTuple(self):
    with ops.Graph().as_default(), self.test_session() as sess:
      random_seed.set_random_seed(200)
      inputs = array_ops.constant([[1., 2., 3.], [3., 4., 5.], [5., 6., 7.]])
      outputs = array_ops.constant([[3., 4.], [5., 6.]])
      block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
      block.register_additional_minibatch(inputs, outputs)
      grads = outputs**2
      block.instantiate_factors(([grads],), 0.5)

      # Make sure our inverse is something other than the identity.
      sess.run(tf_variables.global_variables_initializer())
      sess.run(block._input_factor.make_inverse_update_ops())
      sess.run(block._output_factor.make_inverse_update_ops())

      vector = (
          np.arange(2, 6).reshape(2, 2).astype(np.float32),  #
          np.arange(1, 3).reshape(2, 1).astype(np.float32))
      output = block.multiply_inverse((array_ops.constant(vector[0]),
                                       array_ops.constant(vector[1])))

      output = sess.run(output)
      self.assertAllClose([[0.686291, 1.029437], [1.372583, 1.715729]],
                          output[0])
      self.assertAllClose([0.343146, 0.686291], output[1])

  def testMultiplyInverseNotTuple(self):
    with ops.Graph().as_default(), self.test_session() as sess:
      random_seed.set_random_seed(200)
      inputs = array_ops.constant([[1., 2.], [3., 4.]])
      outputs = array_ops.constant([[3., 4.], [5., 6.]])
      block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
      block.register_additional_minibatch(inputs, outputs)
      grads = outputs**2
      block.instantiate_factors(([grads],), 0.5)

      # Make sure our inverse is something other than the identity.
      sess.run(tf_variables.global_variables_initializer())
      sess.run(block._input_factor.make_inverse_update_ops())
      sess.run(block._output_factor.make_inverse_update_ops())

      vector = np.arange(2, 6).reshape(2, 2).astype(np.float32)
      output = block.multiply_inverse(array_ops.constant(vector))

      self.assertAllClose([[0.686291, 1.029437], [1.372583, 1.715729]],
                          sess.run(output))

  def testMultiplyInverseAgainstExplicit(self):
    with ops.Graph().as_default(), self.test_session() as sess:
      random_seed.set_random_seed(200)
      input_dim, output_dim = 3, 2
      inputs = array_ops.zeros([32, input_dim])
      outputs = array_ops.zeros([32, output_dim])
      params = array_ops.zeros([input_dim, output_dim])
      block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
      block.register_additional_minibatch(inputs, outputs)
      grads = outputs**2
      damping = 0.  # This test is only valid without damping.
      block.instantiate_factors(([grads],), damping)

      sess.run(state_ops.assign(block._input_factor._cov, _make_psd(3)))
      sess.run(state_ops.assign(block._output_factor._cov, _make_psd(2)))
      sess.run(block._input_factor.make_inverse_update_ops())
      sess.run(block._output_factor.make_inverse_update_ops())

      v_flat = np.arange(6, dtype=np.float32)
      vector = utils.column_to_tensors(params, array_ops.constant(v_flat))
      output = block.multiply_inverse(vector)
      output_flat = sess.run(utils.tensors_to_column(output)).ravel()

      full = sess.run(block.full_fisher_block())
      explicit = np.dot(np.linalg.inv(full + damping * np.eye(6)), v_flat)

      self.assertAllClose(output_flat, explicit)


class ConvDiagonalFBTest(test.TestCase):

  def setUp(self):
    super(ConvDiagonalFBTest, self).setUp()

    self.batch_size = 2
    self.height = 8
    self.width = 4
    self.input_channels = 6
    self.output_channels = 3
    self.kernel_size = 1

    self.inputs = np.random.randn(self.batch_size, self.height, self.width,
                                  self.input_channels).astype(np.float32)
    self.outputs = np.zeros(
        [self.batch_size, self.height, self.width,
         self.output_channels]).astype(np.float32)
    self.output_grads = np.random.randn(
        self.batch_size, self.height, self.width, self.output_channels).astype(
            np.float32)
    self.w = np.random.randn(self.kernel_size, self.kernel_size,
                             self.input_channels, self.output_channels).astype(
                                 np.float32)
    self.b = np.random.randn(self.output_channels).astype(np.float32)

  def fisherApprox(self, has_bias=False):
    """Fisher approximation using default inputs."""
    if has_bias:
      inputs = np.concatenate(
          [self.inputs,
           np.ones([self.batch_size, self.height, self.width, 1])],
          axis=-1)
    else:
      inputs = self.inputs
    return self.buildDiagonalFisherApproximation(inputs, self.output_grads,
                                                 self.kernel_size)

  def buildDiagonalFisherApproximation(self, inputs, output_grads,
                                       kernel_size):
    r"""Builds explicit diagonal Fisher approximation.

    Fisher's diagonal is the elementwise square of the per-example gradient

      d loss / d w = \sum_{loc} outer(input_{loc}, output_grad_{loc}),

    averaged over examples, with the sum taken over the (x, y) locations upon
    which the convolution is applied.

    Args:
      inputs: np.array of shape [batch_size, height, width, input_channels].
      output_grads: np.array of shape [batch_size, height, width,
        output_channels].
      kernel_size: int. height and width of kernel.

    Returns:
      Diagonal np.array of shape [num_params, num_params] for num_params =
      kernel_size^2 * input_channels * output_channels.
    """
    batch_size, height, width, input_channels = inputs.shape
    assert output_grads.shape[0] == batch_size
    assert output_grads.shape[1] == height
    assert output_grads.shape[2] == width
    output_channels = output_grads.shape[3]

    # If kernel_size == 1, then we don't need to worry about capturing context
    # around the pixel upon which a convolution is applied. This makes testing
    # easier.
    assert kernel_size == 1, "kernel_size != 1 isn't supported."
    num_locations = height * width
    inputs = np.reshape(inputs, [batch_size, num_locations, input_channels])
    output_grads = np.reshape(output_grads,
                              [batch_size, num_locations, output_channels])

    fisher_diag = np.zeros((input_channels, output_channels))
    for i in range(batch_size):
      # Each example's approximation is a square(sum-of-outer-products).
      example_fisher_diag = np.zeros((input_channels, output_channels))
      for j in range(num_locations):
        example_fisher_diag += np.outer(inputs[i, j], output_grads[i, j])
      fisher_diag += np.square(example_fisher_diag)

    # Normalize by batch_size (not num_locations).
    return np.diag(fisher_diag.flatten()) / batch_size
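
  # Note that, unlike the fully connected case above, the sum over locations
  # happens before squaring, so each example's squared gradient includes
  # cross-location terms.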

  def testMultiply(self):
    result, _ = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs],
                                       [self.output_grads])

    # Construct Fisher-vector product.
    expected_result = self.fisherApprox().dot(self.w.flatten())
    expected_result = expected_result.reshape([
        self.kernel_size, self.kernel_size, self.input_channels,
        self.output_channels
    ])

    self.assertAllClose(expected_result, result)

  def testMultiplyInverse(self):
    _, result = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs],
                                       [self.output_grads])

    # Construct inverse Fisher-vector product.
    expected_result = np.linalg.inv(self.fisherApprox()).dot(self.w.flatten())
    expected_result = expected_result.reshape([
        self.kernel_size, self.kernel_size, self.input_channels,
        self.output_channels
    ])

    self.assertAllClose(expected_result, result, atol=1e-3)

  def testRegisterAdditionalMinibatch(self):
    """Ensure 1 big minibatch and 2 small minibatches are equivalent."""
    multiply_result_big, multiply_inverse_result_big = self.runFisherBlockOps(
        self.w, [self.inputs], [self.outputs], [self.output_grads])
    multiply_result_small, multiply_inverse_result_small = (
        self.runFisherBlockOps(self.w, np.split(self.inputs, 2),
                               np.split(self.outputs, 2),
                               np.split(self.output_grads, 2)))

    self.assertAllClose(multiply_result_big, multiply_result_small)
    self.assertAllClose(multiply_inverse_result_big,
                        multiply_inverse_result_small)

  def testMultiplyHasBias(self):
    result, _ = self.runFisherBlockOps((self.w, self.b), [self.inputs],
                                       [self.outputs], [self.output_grads])
    # Clone 'b' along 'input_channels' dimension.
    b_filter = np.tile(
        np.reshape(self.b, [1, 1, 1, self.output_channels]),
        [self.kernel_size, self.kernel_size, 1, 1])
    params = np.concatenate([self.w, b_filter], axis=2)
    expected_result = self.fisherApprox(True).dot(params.flatten())

    # Extract 'b' from concatenated parameters.
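    # After the reshape below, the parameters have input_channels + 1 input
    # slices; [:, :, 0:-1, :] recovers the weight part and [:, :, -1, :] the
    # bias part.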
    expected_result = expected_result.reshape([
        self.kernel_size, self.kernel_size, self.input_channels + 1,
        self.output_channels
    ])
    expected_result = (expected_result[:, :, 0:-1, :],
                       np.reshape(expected_result[:, :, -1, :],
                                  [self.output_channels]))

    self.assertEqual(len(result), 2)
    self.assertAllClose(expected_result[0], result[0])
    self.assertAllClose(expected_result[1], result[1])

  def runFisherBlockOps(self, params, inputs, outputs, output_grads):
    """Run Ops guaranteed by FisherBlock interface.

    Args:
      params: Tensor or 2-tuple of Tensors. Represents weights or weights and
        bias of this layer.
      inputs: list of Tensors of shape [batch_size, input_size]. Inputs to
        layer.
      outputs: list of Tensors of shape [batch_size, output_size].
        Preactivations produced by layer.
      output_grads: list of Tensors of shape [batch_size, output_size].
        Gradient of loss with respect to 'outputs'.

    Returns:
      multiply_result: Result of FisherBlock.multiply(params)
      multiply_inverse_result: Result of FisherBlock.multiply_inverse(params)
    """
    with ops.Graph().as_default(), self.test_session() as sess:
      inputs = as_tensors(inputs)
      outputs = as_tensors(outputs)
      output_grads = as_tensors(output_grads)
      params = as_tensors(params)

      block = fb.ConvDiagonalFB(
          lc.LayerCollection(), params, strides=[1, 1, 1, 1], padding='SAME')
      for (i, o) in zip(inputs, outputs):
        block.register_additional_minibatch(i, o)

      block.instantiate_factors((output_grads,), damping=0.0)

      sess.run(tf_variables.global_variables_initializer())
      sess.run(block._factor.make_covariance_update_op(0.0))
      multiply_result = sess.run(block.multiply(params))
      multiply_inverse_result = sess.run(block.multiply_inverse(params))

    return multiply_result, multiply_inverse_result


class ConvKFCBasicFBTest(test.TestCase):

  def _testConvKFCBasicFBInitParams(self, params):
    with ops.Graph().as_default():
      random_seed.set_random_seed(200)
      if isinstance(params, (list, tuple)):
        params = [array_ops.constant(param) for param in params]
      else:
        params = array_ops.constant(params)
      inputs = random_ops.random_normal((2, 2, 2))
      outputs = random_ops.random_normal((2, 2, 2))
      block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, [1, 1, 1], 'SAME')
      block.register_additional_minibatch(inputs, outputs)

      self.assertAllEqual([outputs], block.tensors_to_compute_grads())

  def testConvKFCBasicFBInitParamsParamsTuple(self):
    self._testConvKFCBasicFBInitParams([np.array([1., 2.]), np.array(3.)])

  def testConvKFCBasicFBInitParamsParamsSingle(self):
    self._testConvKFCBasicFBInitParams([np.array([1., 2.])])

  def testMultiplyInverseTuple(self):
    with ops.Graph().as_default(), self.test_session() as sess:
      random_seed.set_random_seed(200)
      params = random_ops.random_normal((2, 2, 2, 2))
      inputs = random_ops.random_normal((2, 2, 2, 2))
      outputs = random_ops.random_normal((2, 2, 2, 2))
      block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1),
                                'SAME')
      block.register_additional_minibatch(inputs, outputs)
      grads = outputs**2
      block.instantiate_factors(([grads],), 0.5)

      # Make sure our inverse is something other than the identity.
      sess.run(tf_variables.global_variables_initializer())
      sess.run(block._input_factor.make_inverse_update_ops())
      sess.run(block._output_factor.make_inverse_update_ops())

      vector = (np.arange(1, 15).reshape(7, 2).astype(np.float32),
                np.arange(2, 4).reshape(2, 1).astype(np.float32))
      output = block.multiply_inverse((array_ops.constant(vector[0]),
                                       array_ops.constant(vector[1])))

      output = sess.run(output)
      self.assertAllClose([0.136455, 0.27291], output[0][0])
      self.assertAllClose([0.27291, 0.409365], output[1])

  def testMultiplyInverseNotTuple(self):
    with ops.Graph().as_default(), self.test_session() as sess:
      random_seed.set_random_seed(200)
      params = random_ops.random_normal((2, 2, 2, 2))
      inputs = random_ops.random_normal((2, 2, 2, 2))
      outputs = random_ops.random_normal((2, 2, 2, 2))
      block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1),
                                'SAME')
      block.register_additional_minibatch(inputs, outputs)
      self.assertFalse(block._has_bias)
      grads = outputs**2
      block.instantiate_factors(([grads],), 0.5)

      # Make sure our inverse is something other than the identity.
      sess.run(tf_variables.global_variables_initializer())
      sess.run(block._input_factor.make_inverse_update_ops())
      sess.run(block._output_factor.make_inverse_update_ops())

      vector = np.arange(1, 17).reshape(8, 2).astype(np.float32)
      output = block.multiply_inverse(array_ops.constant(vector))

      self.assertAllClose([0.136455, 0.27291], sess.run(output)[0])

  def testMultiplyInverseNotTupleWithBias(self):
    with ops.Graph().as_default(), self.test_session() as sess:
      random_seed.set_random_seed(200)
      params = [random_ops.random_normal((2, 2, 2, 2))]
      inputs = random_ops.random_normal((2, 2, 2, 2))
      outputs = random_ops.random_normal((2, 2, 2, 2))
      block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1),
                                'SAME')
      block.register_additional_minibatch(inputs, outputs)
      self.assertTrue(block._has_bias)
      grads = outputs**2
      block.instantiate_factors(([grads],), 0.5)

      # Make sure our inverse is something other than the identity.
      sess.run(tf_variables.global_variables_initializer())
      sess.run(block._input_factor.make_inverse_update_ops())
      sess.run(block._output_factor.make_inverse_update_ops())

      vector = np.arange(1, 19).reshape(9, 2).astype(np.float32)
      output = block.multiply_inverse(array_ops.constant(vector))

      self.assertAllClose([0.136455, 0.27291], sess.run(output)[0])

  def testMultiplyInverseAgainstExplicit(self):
    with ops.Graph().as_default(), self.test_session() as sess:
      random_seed.set_random_seed(200)
      params = array_ops.zeros((2, 2, 2, 2))
      inputs = array_ops.zeros((2, 2, 2, 2))
      outputs = array_ops.zeros((2, 2, 2, 2))
      block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1),
                                'SAME')
      block.register_additional_minibatch(inputs, outputs)
      grads = outputs**2
      damping = 0.  # This test is only valid without damping.
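      # (With nonzero damping, the block damps each Kronecker factor
      # separately, which only approximates adding damping * identity to the
      # full Fisher, so the explicit comparison below would not hold exactly.)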
      block.instantiate_factors(([grads],), damping)

      sess.run(state_ops.assign(block._input_factor._cov, _make_psd(8)))
      sess.run(state_ops.assign(block._output_factor._cov, _make_psd(2)))
      sess.run(block._input_factor.make_inverse_update_ops())
      sess.run(block._output_factor.make_inverse_update_ops())

      v_flat = np.arange(16, dtype=np.float32)
      vector = utils.column_to_tensors(params, array_ops.constant(v_flat))
      output = block.multiply_inverse(vector)
      output_flat = sess.run(utils.tensors_to_column(output)).ravel()

      full = sess.run(block.full_fisher_block())
      explicit = np.dot(np.linalg.inv(full + damping * np.eye(16)), v_flat)

      self.assertAllClose(output_flat, explicit)


class FullyConnectedSeriesFBTest(test.TestCase):

  def testFullyConnectedSeriesFBInit(self):
    with ops.Graph().as_default():
      random_seed.set_random_seed(200)
      inputs = array_ops.constant([1., 2.])
      outputs = array_ops.constant([3., 4.])
      block = fb.FullyConnectedSeriesFB(
          lc.LayerCollection(), inputs=[inputs], outputs=[outputs])
      self.assertAllEqual([outputs], block.tensors_to_compute_grads())

  def testInstantiateFactorsHasBias(self):
    with ops.Graph().as_default():
      random_seed.set_random_seed(200)
      inputs = array_ops.constant([[1., 2.], [3., 4.]])
      outputs = array_ops.constant([[3., 4.], [5., 6.]])
      block = fb.FullyConnectedSeriesFB(
          lc.LayerCollection(),
          inputs=[inputs],
          outputs=[outputs],
          has_bias=True)
      grads = outputs**2
      block.instantiate_factors(((grads,),), 0.5)

  def testInstantiateFactorsNoBias(self):
    with ops.Graph().as_default():
      random_seed.set_random_seed(200)
      inputs = array_ops.constant([[1., 2.], [3., 4.]])
      outputs = array_ops.constant([[3., 4.], [5., 6.]])
      block = fb.FullyConnectedSeriesFB(
          lc.LayerCollection(),
          inputs=[inputs],
          outputs=[outputs],
          has_bias=False)
      grads = outputs**2
      block.instantiate_factors(((grads,),), 0.5)


def as_tensors(tensor_or_tuple):
  """Converts a potentially nested tuple of np.array to Tensors."""
  if isinstance(tensor_or_tuple, (tuple, list)):
    return tuple(as_tensors(t) for t in tensor_or_tuple)
  return ops.convert_to_tensor(tensor_or_tuple)


if __name__ == '__main__':
  test.main()