# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functional tests for fused batch norm operations."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import nn
from tensorflow.python.platform import test


class FusedBatchNormTest(XLATestCase):

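  # NumPy reference implementations for the FusedBatchNorm forward and
  # backward passes. Per channel, over the N * H * W elements, the forward
  # pass below computes:
  #   mean = sum(x) / count
  #   var = sum(x**2) / count - mean**2   (the biased estimator)
  #   y = (x - mean) / sqrt(var + epsilon) * scale + offset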
  def _reference_training(self, x, scale, offset, epsilon, data_format):
    if data_format != "NHWC":
      raise ValueError("data_format must be NHWC, got %s." % data_format)
    x_square = x * x
    x_square_sum = np.sum(x_square, axis=(0, 1, 2))
    x_sum = np.sum(x, axis=(0, 1, 2))
    element_count = np.size(x) / int(np.shape(x)[-1])
    mean = x_sum / element_count
    var = x_square_sum / element_count - mean * mean
    normalized = (x - mean) / np.sqrt(var + epsilon)
    return (normalized * scale + offset), mean, var

  def _reference_grad(self, x, grad_y, scale, mean, var, epsilon, data_format):
    # Use the following formulas to calculate gradients:
    # grad_scale =
    #   sum(grad_y * (x - mean)) * rsqrt(var + epsilon)
    #
    # grad_offset = sum(grad_y)
    #
    # grad_x =
    #   1/N * scale * rsqrt(var + epsilon) * (N * grad_y - sum(grad_y) -
    #   (x - mean) * sum(grad_y * (x - mean)) / (var + epsilon))
    if data_format != "NHWC":
      raise ValueError("data_format must be NHWC, got %s." % data_format)
    grad_x = scale * (grad_y - np.mean(grad_y, axis=(0, 1, 2)) -
                      (x - mean) * np.mean(grad_y *
                                           (x - mean), axis=(0, 1, 2)) /
                      (var + epsilon)) / np.sqrt(var + epsilon)
    grad_scale = np.sum(
        grad_y * (x - mean) / np.sqrt(var + epsilon), axis=(0, 1, 2))
    grad_offset = np.sum(grad_y, axis=(0, 1, 2))
    return grad_x, grad_scale, grad_offset

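  # A small NumPy-only consistency check added for illustration (not part of
  # the original suite; the method name and locals are new): the closed-form
  # grad_x above should agree with central finite differences of
  # _reference_training, which recomputes the batch statistics and therefore
  # exercises the full training-mode gradient, including the dependence of
  # mean and var on x.
  def testReferenceGradMatchesFiniteDifferences(self):
    np.random.seed(0)
    x = np.random.random_sample([1, 2, 3, 2])
    scale = np.random.random_sample([2])
    offset = np.random.random_sample([2])
    grad_y = np.random.random_sample(x.shape)
    epsilon = 0.001
    _, mean, var = self._reference_training(x, scale, offset, epsilon, "NHWC")
    grad_x, _, _ = self._reference_grad(x, grad_y, scale, mean, var, epsilon,
                                        "NHWC")
    numeric = np.zeros_like(x)
    delta = 1e-6
    for index in np.ndindex(*x.shape):
      x_pos, x_neg = x.copy(), x.copy()
      x_pos[index] += delta
      x_neg[index] -= delta
      y_pos, _, _ = self._reference_training(x_pos, scale, offset, epsilon,
                                             "NHWC")
      y_neg, _, _ = self._reference_training(x_neg, scale, offset, epsilon,
                                             "NHWC")
      # d/dx[index] of sum(grad_y * y), estimated numerically.
      numeric[index] = np.sum(grad_y * (y_pos - y_neg)) / (2 * delta)
    self.assertAllClose(grad_x, numeric, atol=1e-4)
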
  def testInference(self):
    channel = 3
    x_shape = [2, 2, 6, channel]
    scale_shape = [channel]
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    scale_val = np.random.random_sample(scale_shape).astype(np.float32)
    offset_val = np.random.random_sample(scale_shape).astype(np.float32)
    data_format = "NHWC"
    with self.test_session() as sess, self.test_scope():
      # Use placeholders rather than constants to avoid constant folding.
      t_val = array_ops.placeholder(np.float32, shape=x_shape, name="x")
      scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale")
      offset = array_ops.placeholder(
          np.float32, shape=scale_shape, name="offset")
      epsilon = 0.001
      y_ref, mean_ref, var_ref = self._reference_training(
          x_val, scale_val, offset_val, epsilon, data_format)
      # With is_training=False, the op normalizes using the supplied mean and
      # variance rather than batch statistics.
      y, mean, variance = nn.fused_batch_norm(
          t_val,
          scale,
          offset,
          mean=mean_ref,
          variance=var_ref,
          epsilon=epsilon,
          data_format=data_format,
          is_training=False)

      y_val, _, _ = sess.run(
          [y, mean, variance],
          {t_val: x_val,
           scale: scale_val,
           offset: offset_val})
      self.assertAllClose(y_val, y_ref, atol=1e-3)

  def _testLearning(self, use_gradient_checker):
    channel = 3
    x_shape = [2, 2, 6, channel]
    scale_shape = [channel]
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    scale_val = np.random.random_sample(scale_shape).astype(np.float32)
    offset_val = np.random.random_sample(scale_shape).astype(np.float32)
    data_format = "NHWC"
    with self.test_session() as sess, self.test_scope():
      # Use placeholders rather than constants to avoid constant folding.
      t_val = array_ops.placeholder(np.float32, shape=x_shape, name="x")
      scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale")
      offset = array_ops.placeholder(
          np.float32, shape=scale_shape, name="offset")
      epsilon = 0.001
      y, mean, var = nn.fused_batch_norm(
          t_val,
          scale,
          offset,
          mean=None,
          variance=None,
          epsilon=epsilon,
          data_format=data_format,
          is_training=True)
      # Optionally check the gradient numerically.
      if use_gradient_checker:
        err = gradient_checker.compute_gradient_error(
            t_val,
            x_shape,
            y,
            x_shape,
            extra_feed_dict={
                t_val: x_val,
                scale: scale_val,
                offset: offset_val
            })
        self.assertLess(err, 1e-3)

      y_val, mean_val, var_val = sess.run(
          [y, mean, var],
          {t_val: x_val,
           scale: scale_val,
           offset: offset_val})
      y_ref, mean_ref, var_ref = self._reference_training(
          x_val, scale_val, offset_val, epsilon, data_format)
      self.assertAllClose(mean_val, mean_ref, atol=1e-3)
      self.assertAllClose(y_val, y_ref, atol=1e-3)
      self.assertAllClose(var_val, var_ref, atol=1e-3)

  def testLearning(self):
    self._testLearning(False)

  def testLearningWithGradientChecker(self):
    self._testLearning(True)

  def testGradientTraining(self):
    # TODO(b/64270657): Use gradient_checker here in addition to comparing with
    # this reference implementation.
    channel = 3
    x_shape = [2, 2, 6, channel]
    scale_shape = [channel]
    grad_val = np.random.random_sample(x_shape).astype(np.float32)
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    scale_val = np.random.random_sample(scale_shape).astype(np.float32)
    mean_val = np.random.random_sample(scale_shape).astype(np.float32)
    var_val = np.random.random_sample(scale_shape).astype(np.float32)
    epsilon = 0.001

    with self.test_session() as sess, self.test_scope():
      grad = array_ops.placeholder(np.float32, shape=x_shape, name="grad")
      x = array_ops.placeholder(np.float32, shape=x_shape, name="x")
      mean = array_ops.placeholder(np.float32, shape=scale_shape, name="mean")
      var = array_ops.placeholder(np.float32, shape=scale_shape, name="var")
      scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale")
      grad_x, grad_scale, grad_offset, _, _ = gen_nn_ops.fused_batch_norm_grad(
          grad, x, scale, mean, var, data_format="NHWC", is_training=True)

      grad_x_val, grad_scale_val, grad_offset_val = sess.run(
          [grad_x, grad_scale, grad_offset], {
              grad: grad_val,
              x: x_val,
              mean: mean_val,
              var: var_val,
              scale: scale_val
          })

      grad_x_ref, grad_scale_ref, grad_offset_ref = self._reference_grad(
          x_val, grad_val, scale_val, mean_val, var_val, epsilon, "NHWC")

      self.assertAllClose(grad_x_val, grad_x_ref, atol=1e-2)
      self.assertAllClose(grad_scale_val, grad_scale_ref, atol=1e-2)
      self.assertAllClose(grad_offset_val, grad_offset_ref, atol=1e-3)

  def testGradientInference(self):
    # TODO(b/64270657): Use gradient_checker here in addition to comparing with
    # this reference implementation.
    channel = 3
    x_shape = [2, 2, 6, channel]
    scale_shape = [channel]
    grad_val = np.random.random_sample(x_shape).astype(np.float32)
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    scale_val = np.random.random_sample(scale_shape).astype(np.float32)
    mean_val = np.random.random_sample(scale_shape).astype(np.float32)
    var_val = np.random.random_sample(scale_shape).astype(np.float32)

    with self.test_session() as sess:
      grad = array_ops.placeholder(np.float32, shape=x_shape, name="grad")
      x = array_ops.placeholder(np.float32, shape=x_shape, name="x")
      mean = array_ops.placeholder(np.float32, shape=scale_shape, name="mean")
      var = array_ops.placeholder(np.float32, shape=scale_shape, name="var")
      scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale")
      # Build the op under test inside the XLA test scope, and an identical
      # reference graph outside it so the reference runs on the default
      # device.
      with self.test_scope():
        out = gen_nn_ops.fused_batch_norm_grad(
            grad, x, scale, mean, var, data_format="NHWC", is_training=False)
        grad_x, grad_scale, grad_offset, _, _ = out

      ref_x, ref_scale, ref_offset, _, _ = gen_nn_ops.fused_batch_norm_grad(
          grad, x, scale, mean, var, data_format="NHWC", is_training=False)

      feed_dict = {
          grad: grad_val,
          x: x_val,
          mean: mean_val,
          var: var_val,
          scale: scale_val
      }
      grad_x_val, grad_scale_val, grad_offset_val = sess.run(
          [grad_x, grad_scale, grad_offset], feed_dict)
      grad_x_ref, grad_scale_ref, grad_offset_ref = sess.run(
          [ref_x, ref_scale, ref_offset], feed_dict)

      self.assertAllClose(grad_x_val, grad_x_ref, atol=1e-2)
      self.assertAllClose(grad_scale_val, grad_scale_ref, atol=1e-2)
      self.assertAllClose(grad_offset_val, grad_offset_ref, atol=1e-3)


if __name__ == "__main__":
  test.main()