# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for fused_batch_norm related functionality in tensorflow.ops.nn."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_grad
from tensorflow.python.ops import nn_impl
from tensorflow.python.platform import test


class BatchNormalizationTest(test.TestCase):

  def _batch_norm(self, x, mean, var, offset, scale, epsilon):
    # We compute the batch norm manually in this function because
    # nn_impl.batch_normalization does not support float16 yet.
    # TODO(reedwm): Add float16 support to nn_impl.batch_normalization.
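    # Note: the standard inference formula
    #   y = (x - mean) / sqrt(var + epsilon) * scale + offset
    # is algebraically refactored below as y = x * inv + (offset - mean * inv),
    # with inv = rsqrt(var + epsilon) * scale.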
    inv = math_ops.rsqrt(var + epsilon) * scale
    y = math_ops.cast(x, scale.dtype) * inv + (offset - mean * inv)
    return math_ops.cast(y, x.dtype)

  def _inference_ref(self, x, scale, offset, mean, var, epsilon, data_format):
    if data_format not in ['NHWC', 'NCHW']:
      raise ValueError('data_format must be NCHW or NHWC, '
                       'got %s.' % data_format)
    if data_format == 'NCHW':
      x = array_ops.transpose(x, [0, 2, 3, 1])
    y = self._batch_norm(x, mean, var, offset, scale, epsilon)
    if data_format == 'NCHW':
      y = array_ops.transpose(y, [0, 3, 1, 2])
    return y.eval()

  def _test_inference(self,
                      x_shape,
                      x_dtype,
                      scale_shape,
                      scale_dtype,
                      use_gpu=True,
                      data_format='NHWC'):
    np.random.seed(1)
    x_val = np.random.random_sample(x_shape).astype(x_dtype)
    scale_val = np.random.random_sample(scale_shape).astype(scale_dtype)
    offset_val = np.random.random_sample(scale_shape).astype(scale_dtype)
    mean_val = np.random.random_sample(scale_shape).astype(scale_dtype)
    var_val = np.random.random_sample(scale_shape).astype(scale_dtype)

    with self.test_session(use_gpu=use_gpu) as sess:
      x = constant_op.constant(x_val, name='x')
      scale = constant_op.constant(scale_val, name='scale')
      offset = constant_op.constant(offset_val, name='offset')
      mean = constant_op.constant(mean_val, name='mean')
      var = constant_op.constant(var_val, name='variance')
      epsilon = 0.001
      y, _, _ = nn_impl.fused_batch_norm(
          x,
          scale,
          offset,
          mean=mean,
          variance=var,
          epsilon=epsilon,
          data_format=data_format,
          is_training=False)
      y_val = sess.run(y)
      y_ref = self._inference_ref(x, scale, offset, mean, var, epsilon,
                                  data_format)
    # An atol value of 1e-3 is too small for float16, because some adjacent
    # float16 values that y_val can take are more than 1e-3 apart, e.g.
    # 2.16602 and 2.16797.
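    # (Near 2.0 the spacing between representable float16 values is
    # 2**-9 ~= 0.00195, which is exactly the gap in the example above.)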
    atol = 2e-3 if x_dtype == np.float16 else 1e-3
    self.assertAllClose(y_ref, y_val, atol=atol)

  def _training_ref(self, x, scale, offset, epsilon, data_format):
    if data_format not in ['NHWC', 'NCHW']:
      raise ValueError('data_format must be NCHW or NHWC, '
                       'got %s.' % data_format)
    if data_format == 'NCHW':
      x = array_ops.transpose(x, [0, 2, 3, 1])
    mean, var = nn_impl.moments(
        math_ops.cast(x, scale.dtype), [0, 1, 2], keep_dims=False)
    y = self._batch_norm(x, mean, var, offset, scale, epsilon)
    if data_format == 'NCHW':
      y = array_ops.transpose(y, [0, 3, 1, 2])
    return y.eval(), mean.eval(), var.eval()

  def _test_training(self,
                     x_shape,
                     x_dtype,
                     scale_shape,
                     scale_dtype,
                     use_gpu=True,
                     data_format='NHWC'):
    np.random.seed(1)
    x_val = np.random.random_sample(x_shape).astype(x_dtype)
    scale_val = np.random.random_sample(scale_shape).astype(scale_dtype)
    offset_val = np.random.random_sample(scale_shape).astype(scale_dtype)
    with self.test_session(use_gpu=use_gpu) as sess:
      x = constant_op.constant(x_val, name='x')
      scale = constant_op.constant(scale_val, name='scale')
      offset = constant_op.constant(offset_val, name='offset')
      epsilon = 0.001
      y, mean, var = nn_impl.fused_batch_norm(
          x,
          scale,
          offset,
          epsilon=epsilon,
          data_format=data_format,
          is_training=True)
      y_val, mean_val, var_val = sess.run([y, mean, var])
      y_ref, mean_ref, var_ref = self._training_ref(x, scale, offset, epsilon,
                                                    data_format)
    y_atol = 2e-3 if x_dtype == np.float16 else 1e-3
    self.assertAllClose(y_ref, y_val, atol=y_atol)
    self.assertAllClose(mean_ref, mean_val, atol=1e-3)
    # This is for Bessel's correction. tf.nn.moments uses n, instead of n-1, as
    # the denominator in the formula to calculate variance, while
    # tf.nn.fused_batch_norm has Bessel's correction built in.
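    # That is, the biased variance from moments() is rescaled by n / (n - 1)
    # to obtain the unbiased estimate; the max() below guards against a
    # sample size of 1.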
    sample_size = x_val.size / scale_val.size
    var_ref = var_ref * sample_size / (max(sample_size - 1.0, 1.0))
    self.assertAllClose(var_ref, var_val, atol=1e-3)

  def _compute_gradient_error_float16(self, x, x32, x_shape, y, y32, y_shape):
    """Computes the gradient error for float16 inputs and/or outputs.

    This returns the same value as gradient_checker.compute_gradient_error. The
    difference is that gradient_checker.compute_gradient_error does not compute
    the numeric gradients in a numerically stable way for float16 tensors. To
    work around this, this function requires float32 versions of x and y, which
    are used to compute the numeric gradients that are compared against the
    symbolically computed float16 gradients.

    Args:
      x: The input tensor.
      x32: A float32 version of x.
      x_shape: The shape of x.
      y: The output tensor.
      y32: A float32 version of y. Must be calculated based on x32, not x.
      y_shape: The shape of y.

    Returns:
      The maximum error between the two Jacobians, as in
      gradient_checker.compute_gradient_error.
    """
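    # The float16 graph (x -> y) supplies the symbolic Jacobian; the float32
    # graph (x32 -> y32) supplies the numerically computed (finite-difference)
    # Jacobian against which it is compared.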
    x_init_val = np.random.random_sample(x_shape).astype(np.float16)
    x32_init_val = x_init_val.astype(np.float32)

    # TODO(reedwm): Do not perform the unnecessary computations in
    # compute_gradient, since they double the computation time of this function.
    theoretical_grad, _ = gradient_checker.compute_gradient(
        x, x_shape, y, y_shape, delta=1e-3, x_init_value=x_init_val)
    _, numerical_grad = gradient_checker.compute_gradient(
        x32, x_shape, y32, y_shape, delta=1e-3, x_init_value=x32_init_val)

    # If grad is empty, no error.
    if theoretical_grad.size == 0 and numerical_grad.size == 0:
      return 0
    return np.fabs(theoretical_grad - numerical_grad).max()

  def _test_gradient(self,
                     x_shape,
                     x_dtype,
                     scale_shape,
                     scale_dtype,
                     use_gpu=True,
                     data_format='NHWC',
                     is_training=True):
    np.random.seed(1)
    x_val = np.random.random_sample(x_shape).astype(x_dtype)
    scale_val = np.random.random_sample(scale_shape).astype(scale_dtype)
    offset_val = np.random.random_sample(scale_shape).astype(scale_dtype)

    with self.test_session(use_gpu=use_gpu):
      x = constant_op.constant(x_val, name='x')
      scale = constant_op.constant(scale_val, name='scale')
      offset = constant_op.constant(offset_val, name='offset')
      if is_training:
        pop_mean = None
        pop_var = None
      else:
        pop_mean = np.random.random_sample(scale_shape).astype(scale_dtype)
        pop_var = np.random.random_sample(scale_shape).astype(scale_dtype)
      y, _, _ = nn_impl.fused_batch_norm(
          x,
          scale,
          offset,
          mean=pop_mean,
          variance=pop_var,
          data_format=data_format,
          is_training=is_training)
      if x_dtype != np.float16:
        err_x = gradient_checker.compute_gradient_error(x, x_shape, y, x_shape)
        err_scale = gradient_checker.compute_gradient_error(
            scale, scale_shape, y, x_shape)
        err_offset = gradient_checker.compute_gradient_error(
            offset, scale_shape, y, x_shape)
      else:
        x32 = constant_op.constant(x_val, name='x32', dtype=dtypes.float32)
        y32, _, _ = nn_impl.fused_batch_norm(
            x32,
            scale,
            offset,
            mean=pop_mean,
            variance=pop_var,
            data_format=data_format,
            is_training=is_training)
        err_x = self._compute_gradient_error_float16(x, x32, x_shape, y, y32,
                                                     x_shape)
        err_scale = self._compute_gradient_error_float16(
            scale, scale, scale_shape, y, y32, x_shape)
        err_offset = self._compute_gradient_error_float16(
            offset, offset, scale_shape, y, y32, x_shape)

    x_err_tolerance = 2e-3 if x_dtype == np.float16 else 1e-3
    scale_err_tolerance = 1e-3
    self.assertLess(err_x, x_err_tolerance)
    self.assertLess(err_scale, scale_err_tolerance)
    self.assertLess(err_offset, scale_err_tolerance)

  def _test_grad_grad(self,
                      x_shape,
                      x_dtype,
                      scale_shape,
                      scale_dtype,
                      use_gpu=True,
                      data_format='NHWC',
                      is_training=True,
                      err_tolerance=1e-3):
    np.random.seed(1)
    x_val = np.random.random_sample(x_shape).astype(x_dtype)
    grad_y_val = np.random.random_sample(x_shape).astype(x_dtype)
    scale_val = np.random.random_sample(scale_shape).astype(scale_dtype)
    offset_val = np.random.random_sample(scale_shape).astype(scale_dtype)

    with self.test_session(use_gpu=use_gpu) as sess:
      x = constant_op.constant(x_val, name='x')
      grad_y = constant_op.constant(grad_y_val, name='grad_y')
      scale = constant_op.constant(scale_val, name='scale')
      offset = constant_op.constant(offset_val, name='offset')
      if is_training:
        pop_mean = None
        pop_var = None
      else:
        pop_mean = np.random.random_sample(scale_shape).astype(scale_dtype)
        pop_var = np.random.random_sample(scale_shape).astype(scale_dtype)
      y, _, _ = nn_impl.fused_batch_norm(
          x,
          scale,
          offset,
          mean=pop_mean,
          variance=pop_var,
          data_format=data_format,
          is_training=is_training)
      grad_x, grad_scale, grad_offset = gradients_impl.gradients(
          y, [x, scale, offset], grad_y)

      if is_training:
        epsilon = y.op.get_attr('epsilon')
        data_format = y.op.get_attr('data_format')
        grad_vals = sess.run([grad_x, grad_scale, grad_offset])
        grad_internal = nn_grad._BatchNormGrad(grad_y, x, scale, pop_mean,
                                               pop_var, epsilon, data_format)
        grad_internal_vals = sess.run(list(grad_internal))
        for grad_val, grad_internal_val in zip(grad_vals, grad_internal_vals):
          self.assertAllClose(grad_val, grad_internal_val, atol=err_tolerance)

      if x_dtype != np.float16:
        err_grad_grad_y_1 = gradient_checker.compute_gradient_error(
            grad_y, x_shape, grad_x, x_shape)
        err_grad_grad_y_2 = gradient_checker.compute_gradient_error(
            grad_y, x_shape, grad_scale, scale_shape)
        err_grad_grad_y_3 = gradient_checker.compute_gradient_error(
            grad_y, x_shape, grad_offset, scale_shape)
        # In freeze mode, grad_x is not a function of x.
        if is_training:
          err_grad_x_1 = gradient_checker.compute_gradient_error(
              x, x_shape, grad_x, x_shape)
        err_grad_x_2 = gradient_checker.compute_gradient_error(
            x, x_shape, grad_scale, scale_shape)

        err_grad_scale = gradient_checker.compute_gradient_error(
            scale, scale_shape, grad_x, x_shape)
      else:
        x32 = constant_op.constant(x_val, dtype=dtypes.float32, name='x32')
        grad_y32 = constant_op.constant(
            grad_y_val, dtype=dtypes.float32, name='grad_y32')
        y32, _, _ = nn_impl.fused_batch_norm(
            x32,
            scale,
            offset,
            mean=pop_mean,
            variance=pop_var,
            data_format=data_format,
            is_training=is_training)
        grad_x32, grad_scale32, grad_offset32 = gradients_impl.gradients(
            y32, [x32, scale, offset], grad_y32)
        err_grad_grad_y_1 = self._compute_gradient_error_float16(
            grad_y, grad_y32, x_shape, grad_x, grad_x32, x_shape)
        err_grad_grad_y_2 = self._compute_gradient_error_float16(
            grad_y, grad_y32, x_shape, grad_scale, grad_scale32, scale_shape)
        err_grad_grad_y_3 = self._compute_gradient_error_float16(
            grad_y, grad_y32, x_shape, grad_offset, grad_offset32, scale_shape)
        # In freeze mode, grad_x is not a function of x.
        if is_training:
          err_grad_x_1 = self._compute_gradient_error_float16(
              x, x32, x_shape, grad_x, grad_x32, x_shape)
        err_grad_x_2 = self._compute_gradient_error_float16(
            x, x32, x_shape, grad_scale, grad_scale32, scale_shape)

        err_grad_scale = self._compute_gradient_error_float16(
            scale, scale, scale_shape, grad_x, grad_x32, x_shape)

    self.assertLess(err_grad_grad_y_1, err_tolerance)
    self.assertLess(err_grad_grad_y_2, err_tolerance)
    self.assertLess(err_grad_grad_y_3, err_tolerance)
    if is_training:
      self.assertLess(err_grad_x_1, err_tolerance)
    self.assertLess(err_grad_x_2, err_tolerance)
    self.assertLess(err_grad_scale, err_tolerance)

  def testInferenceShape1(self):
    x_shape = [1, 1, 6, 1]
    for dtype in [np.float16, np.float32]:
      if test.is_gpu_available(cuda_only=True):
        self._test_inference(
            x_shape, dtype, [1], np.float32, use_gpu=True, data_format='NHWC')
        self._test_inference(
            x_shape, dtype, [1], np.float32, use_gpu=True, data_format='NCHW')
      self._test_inference(
          x_shape, dtype, [1], np.float32, use_gpu=False, data_format='NHWC')

  def testInferenceShape2(self):
    x_shape = [1, 1, 6, 2]
    if test.is_gpu_available(cuda_only=True):
      for dtype in [np.float16, np.float32]:
        self._test_inference(
            x_shape, dtype, [2], np.float32, use_gpu=True, data_format='NHWC')
        self._test_inference(
            x_shape, dtype, [2], np.float32, use_gpu=False, data_format='NHWC')

  def testInferenceShape3(self):
    x_shape = [1, 2, 1, 6]
    if test.is_gpu_available(cuda_only=True):
      for dtype in [np.float16, np.float32]:
        self._test_inference(
            x_shape, dtype, [2], np.float32, use_gpu=True, data_format='NCHW')

  def testInferenceShape4(self):
    x_shape = [27, 131, 127, 6]
    for dtype in [np.float16, np.float32]:
      if test.is_gpu_available(cuda_only=True):
        self._test_inference(
            x_shape, dtype, [131], np.float32, use_gpu=True, data_format='NCHW')
        self._test_inference(
            x_shape, dtype, [6], np.float32, use_gpu=True, data_format='NHWC')
      self._test_inference(
          x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')

  def testInferenceShape5(self):
    x_shape = [0, 131, 127, 6]
    for dtype in [np.float16, np.float32]:
      if test.is_gpu_available(cuda_only=True):
        self._test_inference(
            x_shape, dtype, [131], np.float32, use_gpu=True, data_format='NCHW')
        self._test_inference(
            x_shape, dtype, [6], np.float32, use_gpu=True, data_format='NHWC')
      self._test_inference(
          x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')

  def testTrainingShape1(self):
    x_shape = [1, 1, 6, 1]
    for dtype in [np.float16, np.float32]:
      if test.is_gpu_available(cuda_only=True):
        self._test_training(
            x_shape, dtype, [1], np.float32, use_gpu=True, data_format='NHWC')
        self._test_training(
            x_shape, dtype, [1], np.float32, use_gpu=True, data_format='NCHW')
      self._test_training(
          x_shape, dtype, [1], np.float32, use_gpu=False, data_format='NHWC')

  def testTrainingShape2(self):
    x_shape = [1, 1, 6, 2]
    for dtype in [np.float16, np.float32]:
      if test.is_gpu_available(cuda_only=True):
        self._test_training(
            x_shape, dtype, [2], np.float32, use_gpu=True, data_format='NHWC')
      self._test_training(
          x_shape, dtype, [2], np.float32, use_gpu=False, data_format='NHWC')

  def testTrainingShape3(self):
    x_shape = [1, 2, 1, 6]
    if test.is_gpu_available(cuda_only=True):
      for dtype in [np.float16, np.float32]:
        self._test_training(
            x_shape, dtype, [2], np.float32, use_gpu=True, data_format='NCHW')

  def testTrainingShape4(self):
    x_shape = [27, 131, 127, 6]
    for dtype in [np.float16, np.float32]:
      if test.is_gpu_available(cuda_only=True):
        self._test_training(
            x_shape, dtype, [131], np.float32, use_gpu=True, data_format='NCHW')
        self._test_training(
            x_shape, dtype, [6], np.float32, use_gpu=True, data_format='NHWC')
      self._test_training(
          x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')

  def testTrainingShape5(self):
    x_shape = [0, 131, 127, 6]
    for dtype in [np.float16, np.float32]:
      if test.is_gpu_available(cuda_only=True):
        self._test_training(
            x_shape, dtype, [131], np.float32, use_gpu=True, data_format='NCHW')
        self._test_training(
            x_shape, dtype, [6], np.float32, use_gpu=True, data_format='NHWC')
      self._test_training(
          x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')

  def testBatchNormGradShape1(self):
    for is_training in [True, False]:
      x_shape = [1, 1, 6, 1]
      for dtype in [np.float16, np.float32]:
        if test.is_gpu_available(cuda_only=True):
          self._test_gradient(
              x_shape,
              dtype, [1],
              np.float32,
              use_gpu=True,
              data_format='NHWC',
              is_training=is_training)
          self._test_gradient(
              x_shape,
              dtype, [1],
              np.float32,
              use_gpu=True,
              data_format='NCHW',
              is_training=is_training)
        self._test_gradient(
            x_shape,
            dtype, [1],
            np.float32,
            use_gpu=False,
            data_format='NHWC',
            is_training=is_training)

  def testBatchNormGradShape2(self):
    for is_training in [True, False]:
      x_shape = [1, 1, 6, 2]
      for dtype in [np.float16, np.float32]:
        if test.is_gpu_available(cuda_only=True):
          self._test_gradient(
              x_shape,
              dtype, [2],
              np.float32,
              use_gpu=True,
              data_format='NHWC',
              is_training=is_training)
        self._test_gradient(
            x_shape,
            dtype, [2],
            np.float32,
            use_gpu=False,
            data_format='NHWC',
            is_training=is_training)

  def testBatchNormGradShape3(self):
    for is_training in [True, False]:
      x_shape = [1, 2, 1, 6]
      if test.is_gpu_available(cuda_only=True):
        for dtype in [np.float16, np.float32]:
          self._test_gradient(
              x_shape,
              dtype, [2],
              np.float32,
              use_gpu=True,
              data_format='NCHW',
              is_training=is_training)

  def testBatchNormGradShape4(self):
    for is_training in [True, False]:
      x_shape = [5, 7, 11, 4]
      for dtype in [np.float16, np.float32]:
        if test.is_gpu_available(cuda_only=True):
          self._test_gradient(
              x_shape,
              dtype, [7],
              np.float32,
              use_gpu=True,
              data_format='NCHW',
              is_training=is_training)
          self._test_gradient(
              x_shape,
              dtype, [4],
              np.float32,
              use_gpu=True,
              data_format='NHWC',
              is_training=is_training)
        self._test_gradient(
            x_shape,
            dtype, [4],
            np.float32,
            use_gpu=False,
            data_format='NHWC',
            is_training=is_training)

  def testBatchNormGradShape5(self):
    for is_training in [True, False]:
      x_shape = [0, 7, 11, 4]
      for dtype in [np.float16, np.float32]:
        if test.is_gpu_available(cuda_only=True):
          self._test_gradient(
              x_shape,
              dtype, [7],
              np.float32,
              use_gpu=True,
              data_format='NCHW',
              is_training=is_training)
          self._test_gradient(
              x_shape,
              dtype, [4],
              np.float32,
              use_gpu=True,
              data_format='NHWC',
              is_training=is_training)
        self._test_gradient(
            x_shape,
            dtype, [4],
            np.float32,
            use_gpu=False,
            data_format='NHWC',
            is_training=is_training)

  def _testBatchNormGradGrad(self, config):
    shape = config['shape']
    err_tolerance = config['err_tolerance']
    dtype = config['dtype']
    for is_training in [True, False]:
      if test.is_gpu_available(cuda_only=True):
        self._test_grad_grad(
            shape,
            dtype, [shape[3]],
            np.float32,
            use_gpu=True,
            data_format='NHWC',
            is_training=is_training,
            err_tolerance=err_tolerance)
        self._test_grad_grad(
            shape,
            dtype, [shape[1]],
            np.float32,
            use_gpu=True,
            data_format='NCHW',
            is_training=is_training,
            err_tolerance=err_tolerance)
      self._test_grad_grad(
          shape,
          dtype, [shape[3]],
          np.float32,
          use_gpu=False,
          data_format='NHWC',
          is_training=is_training,
          err_tolerance=err_tolerance)

  def testBatchNormGradGradConfig1(self):
    config = {
        'shape': [2, 3, 4, 5],
        'err_tolerance': 1e-2,
        'dtype': np.float32,
    }
    self._testBatchNormGradGrad(config)

  def testBatchNormGradGradConfig2(self):
    config = {
        'shape': [2, 3, 2, 2],
        'err_tolerance': 1e-3,
        'dtype': np.float32,
    }
    self._testBatchNormGradGrad(config)

  def testBatchNormGradGradConfig3(self):
    config = {
        'shape': [2, 3, 4, 5],
        'err_tolerance': 1e-2,
        'dtype': np.float16,
    }
    self._testBatchNormGradGrad(config)

  def testBatchNormGradGradConfig4(self):
    config = {
        'shape': [2, 3, 2, 2],
        'err_tolerance': 2e-3,
        'dtype': np.float16,
    }
    self._testBatchNormGradGrad(config)


if __name__ == '__main__':
  test.main()