# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for batch_norm related functionality in tensorflow.ops.nn."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_impl
import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
from tensorflow.python.platform import test


@test_util.with_c_api
class BatchNormalizationTest(test.TestCase):

  def _npBatchNorm(self, x, m, v, beta, gamma, epsilon,
                   scale_after_normalization, shift_after_normalization):
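    # NumPy reference: y = (x - m) / sqrt(v + epsilon), optionally scaled by
    # gamma and/or shifted by beta, mirroring nn.batch_normalization.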
    y = (x - m) / np.sqrt(v + epsilon)
    y = y * gamma if scale_after_normalization else y
    return y + beta if shift_after_normalization else y

  def _opsBatchNorm(self, x, m, v, beta, gamma, epsilon,
                    scale_after_normalization, shift_after_normalization):
    y = (x - m) * math_ops.rsqrt(v + epsilon)
    if scale_after_normalization:
      y = gamma * y
    return y + beta if shift_after_normalization else y

  def _tfBatchNormV1(self, x, m, v, beta, gamma, epsilon,
                     scale_after_normalization):
    """Original implementation."""
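    # The fused V1 op was deprecated at GraphDef version 9 (see the note in
    # testBatchNormGradImpl below), so pin the graph's producer version to 8
    # so the op can still be constructed here.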
    test_util.set_producer_version(ops.get_default_graph(), 8)
    # pylint: disable=protected-access
    return gen_nn_ops._batch_norm_with_global_normalization(
        x, m, v, beta, gamma, epsilon, scale_after_normalization)
    # pylint: enable=protected-access

  def _tfBatchNormV1BW(self, x, m, v, beta, gamma, epsilon,
                       scale_after_normalization):
    """Re-implementation of the original kernel for backward compatibility."""
    return nn_impl.batch_norm_with_global_normalization(
        x, m, v, beta, gamma, epsilon, scale_after_normalization)

  def _tfBatchNormV2(self, x, m, v, beta, gamma, epsilon,
                     scale_after_normalization, shift_after_normalization):
    """New implementation."""
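    # Passing None for beta/gamma makes batch_normalization skip the shift /
    # scale step; this is how the V1 boolean flags map onto the V2 API.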
    return nn_impl.batch_normalization(
        x, m, v,
        beta if shift_after_normalization else None,
        gamma if scale_after_normalization else None,
        epsilon)

  def testBatchNorm(self):
    x_shape = [3, 5, 4, 2]
    param_shape = [2]
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    m_val = np.random.random_sample(param_shape).astype(np.float32)
    v_val = np.random.random_sample(param_shape).astype(np.float32)
    beta_val = np.random.random_sample(param_shape).astype(np.float32)
    gamma_val = np.random.random_sample(param_shape).astype(np.float32)
    for use_gpu in [True, False]:
      with self.test_session(use_gpu=use_gpu) as sess:
        x = constant_op.constant(x_val, name="x")
        m = constant_op.constant(m_val, name="m")
        v = constant_op.constant(v_val, name="v")
        beta = constant_op.constant(beta_val, name="beta")
        gamma = constant_op.constant(gamma_val, name="gamma")
        epsilon = 0.001
        for scale_after_normalization in [True, False]:
          for shift_after_normalization in [True, False]:
            bn2 = self._tfBatchNormV2(x, m, v, beta, gamma, epsilon,
                                      scale_after_normalization,
                                      shift_after_normalization)
            bn1bw = self._tfBatchNormV1BW(x, m, v, beta, gamma, epsilon,
                                          scale_after_normalization)
            bn1 = self._tfBatchNormV1(x, m, v, beta, gamma, epsilon,
                                      scale_after_normalization)
            on = self._opsBatchNorm(x, m, v, beta, gamma, epsilon,
                                    scale_after_normalization,
                                    shift_after_normalization)
            np_bn = self._npBatchNorm(x_val, m_val, v_val, beta_val, gamma_val,
                                      epsilon, scale_after_normalization,
                                      shift_after_normalization)
            tf_bn_v2, tf_bn_v1bw, tf_bn_v1, ops_bn = sess.run(
                [bn2, bn1bw, bn1, on])
            self.assertAllClose(np_bn, ops_bn, atol=0.00001)
            self.assertAllClose(np_bn, tf_bn_v2, atol=0.00001)
            self.assertAllClose(tf_bn_v2, ops_bn, atol=0.00001)
            # shift_after_normalization=False is not supported in v1.
            if shift_after_normalization:
              self.assertAllClose(np_bn, tf_bn_v1bw, atol=0.00001)
              self.assertAllClose(np_bn, tf_bn_v1, atol=0.00001)
              self.assertAllClose(tf_bn_v1, ops_bn, atol=0.00001)
              self.assertAllClose(tf_bn_v1bw, ops_bn, atol=0.00001)

  def _testBatchNormGradient(self,
                             param_index,
                             tag,
                             scale_after_normalization,
                             shift_after_normalization,
                             version,
                             err_tolerance=1e-11):
    x_shape = [3, 5, 4, 5]
    param_shape = [5]
    np.random.seed(1)  # Make it reproducible.
    x_val = np.random.random_sample(x_shape).astype(np.float64)
    m_val = np.random.random_sample(param_shape).astype(np.float64)
    v_val = np.random.random_sample(param_shape).astype(np.float64)
    beta_val = np.random.random_sample(param_shape).astype(np.float64)
    gamma_val = np.random.random_sample(param_shape).astype(np.float64)
    with self.test_session():
      x = constant_op.constant(x_val, name="x")
      m = constant_op.constant(m_val, name="m")
      v = constant_op.constant(v_val, name="v")
      beta = constant_op.constant(beta_val, name="beta")
      gamma = constant_op.constant(gamma_val, name="gamma")
      epsilon = 0.001
      if version == 1:
        output = self._tfBatchNormV1(x, m, v, beta, gamma, epsilon,
                                     scale_after_normalization)
      elif version == 2:
        output = self._tfBatchNormV2(x, m, v, beta, gamma, epsilon,
                                     scale_after_normalization,
                                     shift_after_normalization)
      else:
        raise ValueError("Invalid version: %s" % version)
      all_params = [x, m, v, beta, gamma]
      all_shapes = [x_shape, param_shape, param_shape, param_shape, param_shape]
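      # compute_gradient_error compares the analytically derived Jacobian of
      # `output` w.r.t. the chosen parameter against a numerically estimated
      # one and returns the maximum absolute difference.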
      err = gradient_checker.compute_gradient_error(all_params[param_index],
                                                    all_shapes[param_index],
                                                    output, x_shape)
    print("Batch normalization v%d %s gradient %s scale and %s shift err = " %
          (version, tag, "with" if scale_after_normalization else "without",
           "with" if shift_after_normalization else "without"), err)
    self.assertLess(err, err_tolerance)

  def _testBatchNormGradientInAllNeedConfigs(self,
                                             param_index,
                                             tag,
                                             err_tolerance=1e-11):
    for scale_after_normalization in [True, False]:
      for shift_after_normalization in [True, False]:
        # shift_after_normalization=False is not supported in version 1.
        for v in ([1, 2] if shift_after_normalization else [2]):
          self._testBatchNormGradient(param_index, tag,
                                      scale_after_normalization,
                                      shift_after_normalization, v,
                                      err_tolerance)

  def testBatchNormInputGradient(self):
    self._testBatchNormGradientInAllNeedConfigs(0, "x")

  def testBatchNormMeanGradient(self):
    self._testBatchNormGradientInAllNeedConfigs(1, "mean")

  def testBatchNormVarianceGradient(self):
    self._testBatchNormGradientInAllNeedConfigs(
        2, "variance", err_tolerance=1e-03)

  def testBatchNormBetaGradient(self):
    # Beta is only applied when shift_after_normalization=True, so fix the
    # shift flag to True and exercise both scale settings.
    for scale_after_normalization in [True, False]:
      for v in [1, 2]:
        self._testBatchNormGradient(3, "beta", scale_after_normalization, True,
                                    v)

  def testBatchNormGammaGradient(self):
    # If scale_after_normalization is False, backprop for gamma in v1
    # will be 0. In version 2 of the API, if scale_after_normalization is False,
    # gamma is not used at all, and the gradient is None, which displeases the
    # gradient checker.
    for scale_after_normalization in [True, False]:
      self._testBatchNormGradient(4, "gamma", scale_after_normalization, True,
                                  1)
    for shift_after_normalization in [True, False]:
      self._testBatchNormGradient(4, "gamma", True, shift_after_normalization,
                                  2)

  def testBatchNormGradImpl(self):
    x_shape = [7, 5, 4, 6]
    param_shape = [6]
    np.random.seed(1)  # Make it reproducible.
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    m_val = np.random.random_sample(param_shape).astype(np.float32)
    v_val = np.random.random_sample(param_shape).astype(np.float32)
    beta_val = np.random.random_sample(param_shape).astype(np.float32)
    gamma_val = np.random.random_sample(param_shape).astype(np.float32)
    backprop_val = np.random.random_sample(x_shape).astype(np.float32)
    for use_gpu in [False, True]:
      with self.test_session(use_gpu=use_gpu) as sess:
        x = constant_op.constant(x_val, name="x")
        m = constant_op.constant(m_val, name="m")
        v = constant_op.constant(v_val, name="v")
        beta = constant_op.constant(beta_val, name="beta")
        gamma = constant_op.constant(gamma_val, name="gamma")
        backprop = constant_op.constant(backprop_val, name="backprop")
        epsilon = 0.001
        for scale_after_normalization in [True, False]:
          # _batch_norm_with_global_normalization_grad is deprecated in v9
          test_util.set_producer_version(ops.get_default_graph(), 8)
          grad = gen_nn_ops._batch_norm_with_global_normalization_grad(
              x, m, v, gamma, backprop, epsilon, scale_after_normalization)
          dx, dm, dv, db, dg = grad
          self.assertEqual(grad.dx, dx)
          self.assertEqual(grad.dm, dm)
          self.assertEqual(grad.dv, dv)
          self.assertEqual(grad.db, db)
          self.assertEqual(grad.dg, dg)

          on = self._opsBatchNorm(x, m, v, beta, gamma, epsilon,
                                  scale_after_normalization, True)
          odx, odm, odv, odb, odg = gradients_impl.gradients(
              [on], [x, m, v, beta, gamma], [backprop])
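          # The gradients produced by the fused gradient op should match the
          # gradients TensorFlow derives from the composed ops above.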
          if scale_after_normalization:
            all_grads = sess.run([dx, dm, dv, db, dg, odx, odm, odv, odb, odg])
            to_check = ["dx", "dm", "dv", "db", "dg"]
          else:
            all_grads = sess.run([dx, dm, dv, db, odx, odm, odv, odb])
            to_check = ["dx", "dm", "dv", "db"]
          for i, _ in enumerate(to_check):
            self.assertAllClose(
                all_grads[i + len(to_check)], all_grads[i], atol=0.000001)

  def testBatchNormKeepDims(self):
    """Test batch normalization with moments of either `keep_dims` shape.

    Parameters shaped like the output of tf.nn.moments(..., keep_dims=True),
    i.e. (1, 1, 1, depth), should yield the same result as parameters with
    shape (depth,).
    """
    x_shape = (3, 5, 4, 2)
    param_shape = (2,)
    keep_dims_param_shape = (1, 1, 1, 2)
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    m_val = np.random.random_sample(param_shape).astype(np.float32)
    v_val = np.random.random_sample(param_shape).astype(np.float32)
    beta_val = np.random.random_sample(param_shape).astype(np.float32)
    gamma_val = np.random.random_sample(param_shape).astype(np.float32)
    for use_gpu in [True, False]:
      with self.test_session(use_gpu=use_gpu) as sess:
        x = constant_op.constant(x_val, name="x")
        m = constant_op.constant(m_val, name="m")
        v = constant_op.constant(v_val, name="v")
        beta = constant_op.constant(beta_val, name="beta")
        gamma = constant_op.constant(gamma_val, name="gamma")
        keep_dims_m = array_ops.reshape(
            m, keep_dims_param_shape, name="keep_dims_m")
        keep_dims_v = array_ops.reshape(
            v, keep_dims_param_shape, name="keep_dims_v")
        keep_dims_beta = array_ops.reshape(
            beta, keep_dims_param_shape, name="keep_dims_beta")
        keep_dims_gamma = array_ops.reshape(
            gamma, keep_dims_param_shape, name="keep_dims_gamma")
        epsilon = 0.001
        for scale_after_normalization in [True, False]:
          for shift_after_normalization in [True, False]:
            bn = self._tfBatchNormV2(x, m, v, beta, gamma, epsilon,
                                     scale_after_normalization,
                                     shift_after_normalization)
            keep_dims_bn = self._tfBatchNormV2(x, keep_dims_m, keep_dims_v,
                                               keep_dims_beta, keep_dims_gamma,
                                               epsilon,
                                               scale_after_normalization,
                                               shift_after_normalization)
            tf_batch_norm, keep_dims_tf_batch_norm = sess.run(
                [bn, keep_dims_bn])
            self.assertEqual(x_shape, tf_batch_norm.shape)
            self.assertEqual(x_shape, keep_dims_tf_batch_norm.shape)
            self.assertAllClose(
                tf_batch_norm, keep_dims_tf_batch_norm, atol=0.000001)

  def _testBatchNormArbitraryShapes(self, x_shape, param_shape, atol=0.0001):
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    m_val = np.random.random_sample(param_shape).astype(np.float32)
    v_val = np.random.random_sample(param_shape).astype(np.float32)
    beta_val = np.random.random_sample(param_shape).astype(np.float32)
    gamma_val = np.random.random_sample(param_shape).astype(np.float32)
    for use_gpu in [True, False]:
      with self.test_session(use_gpu=use_gpu) as sess:
        x = constant_op.constant(x_val, name="x")
        m = constant_op.constant(m_val, name="m")
        v = constant_op.constant(v_val, name="v")
        beta = constant_op.constant(beta_val, name="beta")
        gamma = constant_op.constant(gamma_val, name="gamma")
        epsilon = 0.001
        for scale_after_normalization in [True, False]:
          for shift_after_normalization in [True, False]:
            bn = self._tfBatchNormV2(x, m, v, beta, gamma, epsilon,
                                     scale_after_normalization,
                                     shift_after_normalization)
            np_batch_norm = self._npBatchNorm(x_val, m_val, v_val, beta_val,
                                              gamma_val, epsilon,
                                              scale_after_normalization,
                                              shift_after_normalization)
            [tf_batch_norm] = sess.run([bn])
            self.assertEqual(x_shape, np_batch_norm.shape)
            self.assertEqual(x_shape, tf_batch_norm.shape)
            self.assertAllClose(np_batch_norm, tf_batch_norm, atol=atol)

  def testBatchNormArbitraryShapes(self):
    """Test for a variety of shapes and moments.

    Batch normalization is expected to work regardless of the position and
    dimensionality of the 'depth' axis/axes.
    """
    self._testBatchNormArbitraryShapes((3, 3), (1, 3))
    self._testBatchNormArbitraryShapes((3, 3), (3, 1))
    self._testBatchNormArbitraryShapes((3, 2, 4, 5), (1, 2, 1, 1))
    self._testBatchNormArbitraryShapes(
        (2, 3, 2, 4, 5), (1, 1, 1, 4, 5), atol=0.005)


@test_util.with_c_api
class SufficientStatisticsTest(test.TestCase):

  def _npSuffStats(self, x, axes, shift, keep_dims):
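    # NumPy reference for sufficient statistics: the element count over the
    # reduced axes, the (optionally shifted) sum, and the sum of squares.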
    axis = tuple(axes)
    if shift is not None:
      m_ss = np.sum(x - shift, axis=axis, keepdims=keep_dims)
      v_ss = np.sum((x - shift) * (x - shift), axis=axis, keepdims=keep_dims)
    else:
      m_ss = np.sum(x, axis=axis, keepdims=keep_dims)
      v_ss = np.sum(x * x, axis=axis, keepdims=keep_dims)
    count = 1.0
    for d in xrange(x.ndim):
      if d in set(axes):
        count *= x.shape[d]
    # Only squeeze array-valued shifts; scalar or None shifts pass through.
    if shift is not None and not keep_dims and np.ndim(shift) > 0:
      shift = np.squeeze(shift, axis=axis)
    return count, m_ss, v_ss, shift

  def _opSuffStats(self, x, axes, shift, keep_dims):
    return nn_impl.sufficient_statistics(x, axes, shift, keep_dims)

  def _testSuffStats(self, x_shape, axes, shift, keep_dims, has_shape):
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    np_c, np_m, np_v, np_s = self._npSuffStats(x_val, axes, shift, keep_dims)
    for use_gpu in [True, False]:
      with self.test_session(use_gpu=use_gpu) as sess:
        if has_shape:
          x = constant_op.constant(x_val, name="x")
          x.set_shape(x_shape)
          op_c, op_m, op_v, op_s = self._opSuffStats(x, axes, shift, keep_dims)
          if shift:
            tf_c, tf_m, tf_v, tf_s = sess.run([op_c, op_m, op_v, op_s])
          else:
            tf_c, tf_m, tf_v = sess.run([op_c, op_m, op_v])
        else:
          x = array_ops.placeholder(
              dtype=dtypes.float32, shape=[None] * len(x_shape), name="x")
          op_c, op_m, op_v, op_s = self._opSuffStats(x, axes, shift, keep_dims)
          if shift:
            tf_c, tf_m, tf_v, tf_s = sess.run([op_c, op_m, op_v, op_s],
                                              feed_dict={x: x_val})
          else:
            tf_c, tf_m, tf_v = sess.run([op_c, op_m, op_v],
                                        feed_dict={x: x_val})
        self.assertAllClose(np_c, tf_c, atol=0.000001)
        self.assertAllClose(np_m, tf_m, atol=0.000001)
        self.assertAllClose(np_v, tf_v, atol=0.000001)
        if shift:
          self.assertAllClose(np_s, tf_s, atol=0.000001)

  def testSuffStats(self):
    for has_shape in [True, False]:
      for keep_dims in [True, False]:
        for shift in [None, 1.0]:
          self._testSuffStats([2, 3], [1], shift, keep_dims, has_shape)
          self._testSuffStats([2, 3], [0], shift, keep_dims, has_shape)
          self._testSuffStats([1, 2, 3], [0, 2], shift, keep_dims, has_shape)


@test_util.with_c_api
class NormalizeMomentsTest(test.TestCase):

  def _npNormalizeMoments(self, counts, mean_ss, variance_ss, shift):
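    # Recover moments from sufficient statistics: mean = mean_ss / counts and
    # variance = variance_ss / counts - mean**2; the shift, if any, is added
    # back into the mean.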
    mean = mean_ss / counts
    variance = variance_ss / counts - mean * mean
    if shift is not None:
      mean += shift
    return mean, variance

  def _opNormalizeMoments(self, counts, mean_ss, variance_ss, shift):
    return nn_impl.normalize_moments(counts, mean_ss, variance_ss, shift)

  def _testNormalizeMoments(self, shape, shift):
    counts = np.ones([1]).astype(np.float32)
    mean_ss = np.random.random_sample(shape).astype(np.float32)
    variance_ss = np.random.random_sample(shape).astype(np.float32)
    variance_ss *= variance_ss
    if shift:
      shift_v = np.random.random_sample(shape).astype(np.float32)
    else:
      shift_v = None
    npm, npv = self._npNormalizeMoments(counts, mean_ss, variance_ss, shift_v)
    for use_gpu in [True, False]:
      with self.test_session(use_gpu=use_gpu) as sess:
        tf_counts = constant_op.constant(counts, name="counts")
        tf_mean_ss = constant_op.constant(mean_ss, name="mean_ss")
        tf_variance_ss = constant_op.constant(variance_ss, name="variance_ss")
        if shift:
          tf_shift_v = constant_op.constant(shift_v, name="shift")
        else:
          tf_shift_v = None
        opm, opv = self._opNormalizeMoments(tf_counts, tf_mean_ss,
                                            tf_variance_ss, tf_shift_v)
        tfm, tfv = sess.run([opm, opv])
        self.assertAllClose(npm, tfm, atol=0.000001)
        self.assertAllClose(npv, tfv, atol=0.000001)

  def testNormalizeMoments(self):
    for shift in [None, 4.0]:
      self._testNormalizeMoments([3], shift)
      self._testNormalizeMoments([2, 3], shift)


@test_util.with_c_api
class MomentsTest(test.TestCase):

  def _unweighted_moments(self, x, axes, keep_dims=False, extra_out_grads=None):
    # Method to compute moments of `x` wrt `axes`.
    #
    # This is exposed so WeightedMomentsTest can inherit the tests and
    # assertions from MomentsTest; the extra_out_grads argument allows
    # its inherited gradient tests to assert gradients against the
    # weights as well as the input values.

    return nn_impl.moments(x, axes, keep_dims=keep_dims)

  def RunMomentTestWithDynamicShape(self, shape, axes, keep_dims, dtype):
    with self.test_session():
      # shape = [batch, width, height, depth]
      assert len(shape) == 4

      x_numpy = np.random.normal(size=shape).astype(np.float32)
      x = array_ops.placeholder(dtype, shape=[None] * len(shape))

      mean, var = self._unweighted_moments(x, axes, keep_dims=keep_dims)

      num_elements = np.prod([shape[i] for i in axes])

      ax = tuple(axes)
      expected_mean = np.sum(x_numpy, axis=ax,
                             keepdims=keep_dims) / num_elements
      expected_mean_squared = np.multiply(expected_mean, expected_mean)
      expected_x_squared = np.sum(np.multiply(x_numpy, x_numpy),
                                  axis=ax,
                                  keepdims=keep_dims) / num_elements
      expected_variance = expected_x_squared - expected_mean_squared

      # Check that the moments are correct.
      self.assertAllCloseAccordingToType(
          expected_mean, mean.eval(feed_dict={x: x_numpy}))
      self.assertAllCloseAccordingToType(
          expected_variance, var.eval(feed_dict={x: x_numpy}))

  def RunMomentTest(self, shape, axes, keep_dims, dtype):
    with self.test_session():
      # shape = [batch, width, height, depth]
      assert len(shape) == 4

      x_numpy = np.random.normal(size=shape).astype(np.float32)
      x = math_ops.cast(constant_op.constant(x_numpy), dtype=dtype)

      # Compute the expected values at high precision since the method
      # is prone to catastrophic cancellation:
      x_numpy = x_numpy.astype(np.float128)
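      # Reference values: mean = E[x] and variance = E[x^2] - E[x]^2; the
      # latter form cancels catastrophically in low precision, hence float128.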

      mean, var = self._unweighted_moments(x, axes, keep_dims=keep_dims)

      num_elements = np.prod([shape[i] for i in axes])

      ax = tuple(axes)
      expected_mean = np.sum(x_numpy, axis=ax,
                             keepdims=keep_dims) / num_elements
      expected_mean_squared = np.multiply(expected_mean, expected_mean)
      expected_x_squared = np.sum(np.multiply(x_numpy, x_numpy),
                                  axis=ax,
                                  keepdims=keep_dims) / num_elements
      expected_variance = expected_x_squared - expected_mean_squared

      # Check that the moments are correct.
      self.assertAllCloseAccordingToType(expected_mean, mean.eval())
      self.assertAllCloseAccordingToType(expected_variance, var.eval())

  def testBasic(self):
    for keep_dims in [False, True]:
      for dtype in [dtypes.float32, dtypes.float16]:
        self.RunMomentTest(
            shape=[2, 3, 5, 4], axes=[0], keep_dims=keep_dims, dtype=dtype)
        self.RunMomentTestWithDynamicShape(
            shape=[2, 3, 5, 4], axes=[0], keep_dims=keep_dims, dtype=dtype)

  def testGlobalNormalization(self):
    for keep_dims in [False, True]:
      for dtype in [dtypes.float32, dtypes.float16]:
        self.RunMomentTest(
            shape=[2, 3, 5, 4],
            axes=[0, 1, 2],
            keep_dims=keep_dims,
            dtype=dtype)
        self.RunMomentTestWithDynamicShape(
            shape=[2, 3, 5, 4],
            axes=[0, 1, 2],
            keep_dims=keep_dims,
            dtype=dtype)

  def testAxes(self):
    for keep_dims in [False, True]:
      for dtype in [dtypes.float32, dtypes.float16]:
        self.RunMomentTest(
            shape=[2, 3, 5, 4],
            axes=[1, 2, 3],
            keep_dims=keep_dims,
            dtype=dtype)
        self.RunMomentTestWithDynamicShape(
            shape=[2, 3, 5, 4],
            axes=[1, 2, 3],
            keep_dims=keep_dims,
            dtype=dtype)

  def _testGlobalGradient(self, from_y="mean"):
    with self.test_session():
      x_shape = [3, 5, 4, 2]
      x_val = np.random.random_sample(x_shape).astype(np.float64)
      x = constant_op.constant(x_val)
      x.set_shape(x_shape)

      axes = [0, 1, 2]
      y_shape = [2]  # Depth of x

      inputs_to_compute_gradients_for = [x]

      out_mean, out_var = self._unweighted_moments(
          x, axes, extra_out_grads=inputs_to_compute_gradients_for)
      if from_y == "mean":
        y = out_mean
      elif from_y == "var":
        y = out_var

      for (i, v) in enumerate(inputs_to_compute_gradients_for):
        err = gradient_checker.compute_gradient_error(v,
                                                      v.get_shape().as_list(),
                                                      y, y_shape)
        print("Moments %s gradient err vs input %d = %g" % (from_y, i, err))
        self.assertLess(err, 1e-11)

  def testMeanGlobalGradient(self):
    self._testGlobalGradient(from_y="mean")

  def testVarGlobalGradient(self):
    self._testGlobalGradient(from_y="var")


@test_util.with_c_api
class WeightedMomentsTest(MomentsTest):
  """Tests for nn.weighted_moments.

  Note that this test inherits from MomentsTest, inheriting all its
  test methods!

  It modifies MomentsTest in two ways:

  a) By overriding _unweighted_moments, all the codepaths in
     MomentsTest are executed, but with calls to tf.nn.moments()
     replaced by calls to tf.nn.weighted_moments() with a constant
     weight of 1.

  b) By overriding RunMomentTest and RunMomentTestWithDynamicShape,
     this test adds multiple additional calls to
     RunWeightedMomentTest() to exercise correctness with
     non-constant weights and varying broadcasting situations. (It
     also still calls the inherited MomentsTest.RunMomentTest and
     RunMomentTestWithDynamicShape code paths.)
  """

  def _unweighted_moments(self, x, axes, keep_dims=False, extra_out_grads=None):
    weights = constant_op.constant(1, dtype=x.dtype)
    if extra_out_grads is not None:
      # We want to assert gradients WRT weights as well as X!
      extra_out_grads.append(weights)
    return nn_impl.weighted_moments(x, axes, weights, keep_dims=keep_dims)

  def RunMomentTest(self, shape, axes, keep_dims, dtype, dynshapes=False):
    if not dynshapes:
      super(WeightedMomentsTest, self).RunMomentTest(shape, axes, keep_dims,
                                                     dtype)
    else:
      super(WeightedMomentsTest, self).RunMomentTestWithDynamicShape(shape,
                                                                     axes,
                                                                     keep_dims,
                                                                     dtype)

    # 1:1 weights and inputs
    self.RunWeightedMomentTest(shape, shape, axes, keep_dims, dtype)

    # Various broadcasting combinations
    for idx in range(len(shape)):
      # try broadcasting weights in all positions
      weight_shape = [1] * len(shape)
      weight_shape[idx] = shape[idx]

      self.RunWeightedMomentTest(shape, weight_shape, axes, keep_dims, dtype)

      # Also try broadcasting with a suffix of length n
      weight_shape = shape[-(idx + 1):]
      self.RunWeightedMomentTest(
          shape, weight_shape, axes, keep_dims, dtype, dynshapes=dynshapes)

  def RunMomentTestWithDynamicShape(self, shape, axes, keep_dims, dtype):
    self.RunMomentTest(shape, axes, keep_dims, dtype, dynshapes=True)

  def RunWeightedMomentTest(self,
                            shape,
                            weights_shape,
                            axes,
                            keep_dims,
                            dtype,
                            dynshapes=False):
    with self.test_session() as s:
      x_numpy = np.random.normal(size=shape).astype(np.float32)
      weights_numpy = np.absolute(  # weights must be positive
          np.random.normal(
              size=weights_shape, loc=1.0).astype(np.float32))

      # Expand the numpy version to higher precision
      x_numpy = x_numpy.astype(np.float128)
      weights_numpy = weights_numpy.astype(np.float128)

      x_shape = [None] * len(shape) if dynshapes else shape
      weights_shape = ([None] * len(weights_shape) if dynshapes else
                       weights_shape)

      x = array_ops.placeholder(dtype, shape=x_shape)
      weights = array_ops.placeholder(dtype, shape=weights_shape)

      mean, var = nn_impl.weighted_moments(
          x, axes, weights, keep_dims=keep_dims)

      ax = tuple(axes)

      def _np_weighted_sum(v):
        return np.sum(weights_numpy * v, axis=ax, keepdims=keep_dims)

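      # Weighted reference moments: mean = sum(w * x) / sum(w) and
      # variance = sum(w * x^2) / sum(w) - mean^2.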
      weight_sum = _np_weighted_sum(np.ones_like(x_numpy))
      expected_mean = _np_weighted_sum(x_numpy) / weight_sum
      expected_mean_squared = np.multiply(expected_mean, expected_mean)
      expected_x_squared = (_np_weighted_sum(np.multiply(x_numpy, x_numpy)) /
                            weight_sum)
      expected_variance = expected_x_squared - expected_mean_squared

      mean_v, var_v = s.run([mean, var],
                            feed_dict={x: x_numpy,
                                       weights: weights_numpy})

      self.assertAllCloseAccordingToType(expected_mean, mean_v)
      self.assertAllCloseAccordingToType(expected_variance, var_v)


if __name__ == "__main__":
  test.main()