# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for batch_norm related functionality in tensorflow.ops.nn."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_impl
import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
from tensorflow.python.platform import test


@test_util.with_c_api
class BatchNormalizationTest(test.TestCase):

  def _npBatchNorm(self, x, m, v, beta, gamma, epsilon,
                   scale_after_normalization, shift_after_normalization):
    y = (x - m) / np.sqrt(v + epsilon)
    y = y * gamma if scale_after_normalization else y
    return y + beta if shift_after_normalization else y

  def _opsBatchNorm(self, x, m, v, beta, gamma, epsilon,
                    scale_after_normalization, shift_after_normalization):
    y = (x - m) * math_ops.rsqrt(v + epsilon)
    if scale_after_normalization:
      y = gamma * y
    return y + beta if shift_after_normalization else y

  def _tfBatchNormV1(self, x, m, v, beta, gamma, epsilon,
                     scale_after_normalization):
    """Original implementation."""
    test_util.set_producer_version(ops.get_default_graph(), 8)
    # pylint: disable=protected-access
    return gen_nn_ops._batch_norm_with_global_normalization(
        x, m, v, beta, gamma, epsilon, scale_after_normalization)
    # pylint: enable=protected-access

  def _tfBatchNormV1BW(self, x, m, v, beta, gamma, epsilon,
                       scale_after_normalization):
    """Re-implementation of the original kernel for backward compatibility."""
    return nn_impl.batch_norm_with_global_normalization(
        x, m, v, beta, gamma, epsilon, scale_after_normalization)

  def _tfBatchNormV2(self, x, m, v, beta, gamma, epsilon,
                     scale_after_normalization, shift_after_normalization):
    """New implementation."""
    return nn_impl.batch_normalization(
        x, m, v, beta if shift_after_normalization else None,
        gamma if scale_after_normalization else None, epsilon)
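
  # Note on what is compared below: every implementation above computes
  #   y = (x - m) / sqrt(v + epsilon),
  # optionally followed by `y *= gamma` (scale_after_normalization) and
  # `y += beta` (shift_after_normalization). As an illustrative spot check,
  # x = 1.0, m = 0.0, v = 1.0, gamma = 2.0, beta = 0.5 and epsilon = 0.001
  # should give y = 2.0 / sqrt(1.001) + 0.5 ~= 2.4990 in every variant.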

  def testBatchNorm(self):
    x_shape = [3, 5, 4, 2]
    param_shape = [2]
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    m_val = np.random.random_sample(param_shape).astype(np.float32)
    v_val = np.random.random_sample(param_shape).astype(np.float32)
    beta_val = np.random.random_sample(param_shape).astype(np.float32)
    gamma_val = np.random.random_sample(param_shape).astype(np.float32)
    for use_gpu in [True, False]:
      with self.test_session(use_gpu=use_gpu) as sess:
        x = constant_op.constant(x_val, name="x")
        m = constant_op.constant(m_val, name="m")
        v = constant_op.constant(v_val, name="v")
        beta = constant_op.constant(beta_val, name="beta")
        gamma = constant_op.constant(gamma_val, name="gamma")
        epsilon = 0.001
        for scale_after_normalization in [True, False]:
          for shift_after_normalization in [True, False]:
            bn2 = self._tfBatchNormV2(x, m, v, beta, gamma, epsilon,
                                      scale_after_normalization,
                                      shift_after_normalization)
            bn1bw = self._tfBatchNormV1BW(x, m, v, beta, gamma, epsilon,
                                          scale_after_normalization)
            bn1 = self._tfBatchNormV1(x, m, v, beta, gamma, epsilon,
                                      scale_after_normalization)
            on = self._opsBatchNorm(x, m, v, beta, gamma, epsilon,
                                    scale_after_normalization,
                                    shift_after_normalization)
            np_bn = self._npBatchNorm(x_val, m_val, v_val, beta_val, gamma_val,
                                      epsilon, scale_after_normalization,
                                      shift_after_normalization)
            tf_bn_v2, tf_bn_v1bw, tf_bn_v1, ops_bn = sess.run(
                [bn2, bn1bw, bn1, on])
            self.assertAllClose(np_bn, ops_bn, atol=0.00001)
            self.assertAllClose(np_bn, tf_bn_v2, atol=0.00001)
            self.assertAllClose(tf_bn_v2, ops_bn, atol=0.00001)
            # shift_after_normalization=False is not supported in v1.
            if shift_after_normalization:
              self.assertAllClose(np_bn, tf_bn_v1bw, atol=0.00001)
              self.assertAllClose(np_bn, tf_bn_v1, atol=0.00001)
              self.assertAllClose(tf_bn_v1, ops_bn, atol=0.00001)
              self.assertAllClose(tf_bn_v1bw, ops_bn, atol=0.00001)
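
  # The gradient checks below rely on gradient_checker.compute_gradient_error,
  # which compares the registered analytic gradient against a numerically
  # computed (finite-difference) Jacobian and returns the maximum elementwise
  # error; float64 inputs keep that error near machine precision, hence the
  # tight default tolerance of 1e-11.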

  def _testBatchNormGradient(self,
                             param_index,
                             tag,
                             scale_after_normalization,
                             shift_after_normalization,
                             version,
                             err_tolerance=1e-11):
    x_shape = [3, 5, 4, 5]
    param_shape = [5]
    np.random.seed(1)  # Make it reproducible.
    x_val = np.random.random_sample(x_shape).astype(np.float64)
    m_val = np.random.random_sample(param_shape).astype(np.float64)
    v_val = np.random.random_sample(param_shape).astype(np.float64)
    beta_val = np.random.random_sample(param_shape).astype(np.float64)
    gamma_val = np.random.random_sample(param_shape).astype(np.float64)
    with self.test_session():
      x = constant_op.constant(x_val, name="x")
      m = constant_op.constant(m_val, name="m")
      v = constant_op.constant(v_val, name="v")
      beta = constant_op.constant(beta_val, name="beta")
      gamma = constant_op.constant(gamma_val, name="gamma")
      epsilon = 0.001
      if version == 1:
        output = self._tfBatchNormV1(x, m, v, beta, gamma, epsilon,
                                     scale_after_normalization)
      elif version == 2:
        output = self._tfBatchNormV2(x, m, v, beta, gamma, epsilon,
                                     scale_after_normalization,
                                     shift_after_normalization)
      else:
        raise ValueError("Invalid version: %s" % version)
      all_params = [x, m, v, beta, gamma]
      all_shapes = [x_shape, param_shape, param_shape, param_shape,
                    param_shape]
      err = gradient_checker.compute_gradient_error(all_params[param_index],
                                                    all_shapes[param_index],
                                                    output, x_shape)
      print("Batch normalization v%d %s gradient %s scale and %s shift err = "
            % (version, tag,
               "with" if scale_after_normalization else "without",
               "with" if shift_after_normalization else "without"), err)
      self.assertLess(err, err_tolerance)

  def _testBatchNormGradientInAllNeedConfigs(self,
                                             param_index,
                                             tag,
                                             err_tolerance=1e-11):
    for scale_after_normalization in [True, False]:
      for shift_after_normalization in [True, False]:
        # shift_after_normalization=False is not supported in version 1.
        for v in ([1, 2] if shift_after_normalization else [2]):
          self._testBatchNormGradient(param_index, tag,
                                      scale_after_normalization,
                                      shift_after_normalization, v,
                                      err_tolerance)

  def testBatchNormInputGradient(self):
    self._testBatchNormGradientInAllNeedConfigs(0, "x")

  def testBatchNormMeanGradient(self):
    self._testBatchNormGradientInAllNeedConfigs(1, "mean")

  def testBatchNormVarianceGradient(self):
    self._testBatchNormGradientInAllNeedConfigs(
        2, "variance", err_tolerance=1e-03)

  def testBatchNormBetaGradient(self):
    # Since beta does not exist when shift_after_normalization=False, we only
    # test for shift_after_normalization=True.
    for scale_after_normalization in [True, False]:
      for v in [1, 2]:
        self._testBatchNormGradient(3, "beta", scale_after_normalization, True,
                                    v)

  def testBatchNormGammaGradient(self):
    # If scale_after_normalization is False, backprop for gamma in v1
    # will be 0. In version 2 of the API, if scale_after_normalization is
    # False, gamma is not used at all, and the gradient is None, which
    # displeases the gradient checker.
    for scale_after_normalization in [True, False]:
      self._testBatchNormGradient(4, "gamma", scale_after_normalization, True,
                                  1)
    for shift_after_normalization in [True, False]:
      self._testBatchNormGradient(4, "gamma", True, shift_after_normalization,
                                  2)
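
  # testBatchNormGradImpl below exercises the fused v1 gradient kernel
  # directly and checks each of its outputs against the gradients obtained by
  # differentiating the composed-ops implementation with
  # gradients_impl.gradients.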

  def testBatchNormGradImpl(self):
    x_shape = [7, 5, 4, 6]
    param_shape = [6]
    np.random.seed(1)  # Make it reproducible.
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    m_val = np.random.random_sample(param_shape).astype(np.float32)
    v_val = np.random.random_sample(param_shape).astype(np.float32)
    beta_val = np.random.random_sample(param_shape).astype(np.float32)
    gamma_val = np.random.random_sample(param_shape).astype(np.float32)
    backprop_val = np.random.random_sample(x_shape).astype(np.float32)
    for use_gpu in [False, True]:
      with self.test_session(use_gpu=use_gpu) as sess:
        x = constant_op.constant(x_val, name="x")
        m = constant_op.constant(m_val, name="m")
        v = constant_op.constant(v_val, name="v")
        beta = constant_op.constant(beta_val, name="beta")
        gamma = constant_op.constant(gamma_val, name="gamma")
        backprop = constant_op.constant(backprop_val, name="backprop")
        epsilon = 0.001
        for scale_after_normalization in [True, False]:
          # _batch_norm_with_global_normalization_grad is deprecated in v9.
          test_util.set_producer_version(ops.get_default_graph(), 8)
          grad = gen_nn_ops._batch_norm_with_global_normalization_grad(
              x, m, v, gamma, backprop, epsilon, scale_after_normalization)
          dx, dm, dv, db, dg = grad
          self.assertEqual(grad.dx, dx)
          self.assertEqual(grad.dm, dm)
          self.assertEqual(grad.dv, dv)
          self.assertEqual(grad.db, db)
          self.assertEqual(grad.dg, dg)

          on = self._opsBatchNorm(x, m, v, beta, gamma, epsilon,
                                  scale_after_normalization, True)
          odx, odm, odv, odb, odg = gradients_impl.gradients(
              [on], [x, m, v, beta, gamma], [backprop])
          if scale_after_normalization:
            all_grads = sess.run([dx, dm, dv, db, dg, odx, odm, odv, odb, odg])
            to_check = ["dx", "dm", "dv", "db", "dg"]
          else:
            all_grads = sess.run([dx, dm, dv, db, odx, odm, odv, odb])
            to_check = ["dx", "dm", "dv", "db"]
          for i, _ in enumerate(to_check):
            self.assertAllClose(
                all_grads[i + len(to_check)], all_grads[i], atol=0.000001)
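
  # nn_impl.batch_normalization is composed of broadcasting ops, so the
  # (1, 1, 1, depth)-shaped parameters produced by
  # tf.nn.moments(..., keep_dims=True) must behave exactly like rank-1
  # (depth,)-shaped parameters; the next test checks that equivalence.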

  def testBatchNormKeepDims(self):
    """Test for tf.nn.moments(..., keep_dims=True / False).

    Make sure that parameters with shape (1, 1, 1, depth) yield the same
    result as parameters with shape (depth,).
    """
    x_shape = (3, 5, 4, 2)
    param_shape = (2,)
    keep_dims_param_shape = (1, 1, 1, 2)
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    m_val = np.random.random_sample(param_shape).astype(np.float32)
    v_val = np.random.random_sample(param_shape).astype(np.float32)
    beta_val = np.random.random_sample(param_shape).astype(np.float32)
    gamma_val = np.random.random_sample(param_shape).astype(np.float32)
    for use_gpu in [True, False]:
      with self.test_session(use_gpu=use_gpu) as sess:
        x = constant_op.constant(x_val, name="x")
        m = constant_op.constant(m_val, name="m")
        v = constant_op.constant(v_val, name="v")
        beta = constant_op.constant(beta_val, name="beta")
        gamma = constant_op.constant(gamma_val, name="gamma")
        keep_dims_m = array_ops.reshape(
            m, keep_dims_param_shape, name="keep_dims_m")
        keep_dims_v = array_ops.reshape(
            v, keep_dims_param_shape, name="keep_dims_v")
        keep_dims_beta = array_ops.reshape(
            beta, keep_dims_param_shape, name="keep_dims_beta")
        keep_dims_gamma = array_ops.reshape(
            gamma, keep_dims_param_shape, name="keep_dims_gamma")
        epsilon = 0.001
        for scale_after_normalization in [True, False]:
          for shift_after_normalization in [True, False]:
            bn = self._tfBatchNormV2(x, m, v, beta, gamma, epsilon,
                                     scale_after_normalization,
                                     shift_after_normalization)
            keep_dims_bn = self._tfBatchNormV2(x, keep_dims_m, keep_dims_v,
                                               keep_dims_beta, keep_dims_gamma,
                                               epsilon,
                                               scale_after_normalization,
                                               shift_after_normalization)
            tf_batch_norm, keep_dims_tf_batch_norm = sess.run(
                [bn, keep_dims_bn])
            self.assertEquals(x_shape, tf_batch_norm.shape)
            self.assertEquals(x_shape, keep_dims_tf_batch_norm.shape)
            self.assertAllClose(
                tf_batch_norm, keep_dims_tf_batch_norm, atol=0.000001)

  def _testBatchNormArbitraryShapes(self, x_shape, param_shape, atol=0.0001):
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    m_val = np.random.random_sample(param_shape).astype(np.float32)
    v_val = np.random.random_sample(param_shape).astype(np.float32)
    beta_val = np.random.random_sample(param_shape).astype(np.float32)
    gamma_val = np.random.random_sample(param_shape).astype(np.float32)
    for use_gpu in [True, False]:
      with self.test_session(use_gpu=use_gpu) as sess:
        x = constant_op.constant(x_val, name="x")
        m = constant_op.constant(m_val, name="m")
        v = constant_op.constant(v_val, name="v")
        beta = constant_op.constant(beta_val, name="beta")
        gamma = constant_op.constant(gamma_val, name="gamma")
        epsilon = 0.001
        for scale_after_normalization in [True, False]:
          for shift_after_normalization in [True, False]:
            bn = self._tfBatchNormV2(x, m, v, beta, gamma, epsilon,
                                     scale_after_normalization,
                                     shift_after_normalization)
            np_batch_norm = self._npBatchNorm(x_val, m_val, v_val, beta_val,
                                              gamma_val, epsilon,
                                              scale_after_normalization,
                                              shift_after_normalization)
            [tf_batch_norm] = sess.run([bn])
            self.assertEquals(x_shape, np_batch_norm.shape)
            self.assertEquals(x_shape, tf_batch_norm.shape)
            self.assertAllClose(np_batch_norm, tf_batch_norm, atol=atol)

  def testBatchNormArbitraryShapes(self):
    """Test for a variety of shapes and moments.

    Batch normalization is expected to work regardless of the position and
    dimensionality of the 'depth' axis/axes.
    """
    self._testBatchNormArbitraryShapes((3, 3), (1, 3))
    self._testBatchNormArbitraryShapes((3, 3), (3, 1))
    self._testBatchNormArbitraryShapes((3, 2, 4, 5), (1, 2, 1, 1))
    self._testBatchNormArbitraryShapes(
        (2, 3, 2, 4, 5), (1, 1, 1, 4, 5), atol=0.005)


@test_util.with_c_api
class SufficientStatisticsTest(test.TestCase):

  def _npSuffStats(self, x, axes, shift, keep_dims):
    axis = tuple(axes)
    if shift is not None:
      m_ss = np.sum(x - shift, axis=axis, keepdims=keep_dims)
      v_ss = np.sum((x - shift) * (x - shift), axis=axis, keepdims=keep_dims)
    else:
      m_ss = np.sum(x, axis=axis, keepdims=keep_dims)
      v_ss = np.sum(x * x, axis=axis, keepdims=keep_dims)
    count = 1.0
    for d in xrange(x.ndim):
      if d in set(axes):
        count *= x.shape[d]
    if not keep_dims:
      shift = np.squeeze(shift, axis=axis)
    return count, m_ss, v_ss, shift

  def _opSuffStats(self, x, axes, shift, keep_dims):
    return nn_impl.sufficient_statistics(x, axes, shift, keep_dims)

  def _testSuffStats(self, x_shape, axes, shift, keep_dims, has_shape):
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    np_c, np_m, np_v, np_s = self._npSuffStats(x_val, axes, shift, keep_dims)
    for use_gpu in [True, False]:
      with self.test_session(use_gpu=use_gpu) as sess:
        if has_shape:
          x = constant_op.constant(x_val, name="x")
          x.set_shape(x_shape)
          op_c, op_m, op_v, op_s = self._opSuffStats(x, axes, shift, keep_dims)
          if shift:
            tf_c, tf_m, tf_v, tf_s = sess.run([op_c, op_m, op_v, op_s])
          else:
            tf_c, tf_m, tf_v = sess.run([op_c, op_m, op_v])
        else:
          x = array_ops.placeholder(
              dtype=dtypes.float32, shape=[None] * len(x_shape), name="x")
          op_c, op_m, op_v, op_s = self._opSuffStats(x, axes, shift, keep_dims)
          if shift:
            tf_c, tf_m, tf_v, tf_s = sess.run([op_c, op_m, op_v, op_s],
                                              feed_dict={x: x_val})
          else:
            tf_c, tf_m, tf_v = sess.run([op_c, op_m, op_v],
                                        feed_dict={x: x_val})
        self.assertAllClose(np_c, tf_c, atol=0.000001)
        self.assertAllClose(np_m, tf_m, atol=0.000001)
        self.assertAllClose(np_v, tf_v, atol=0.000001)
        if shift:
          self.assertAllClose(np_s, tf_s, atol=0.000001)

  def testSuffStats(self):
    for has_shape in [True, False]:
      for keep_dims in [True, False]:
        for shift in [None, 1.0]:
          self._testSuffStats([2, 3], [1], shift, keep_dims, has_shape)
          self._testSuffStats([2, 3], [0], shift, keep_dims, has_shape)
          self._testSuffStats([1, 2, 3], [0, 2], shift, keep_dims, has_shape)


@test_util.with_c_api
class NormalizeMomentsTest(test.TestCase):

  def _npNormalizeMoments(self, counts, mean_ss, variance_ss, shift):
    mean = mean_ss / counts
    variance = variance_ss / counts - mean * mean
    if shift is not None:
      mean += shift
    return mean, variance
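
  # The helper above mirrors nn_impl.normalize_moments: with sufficient
  # statistics accumulated relative to a shift s (mean_ss = sum(x - s) and
  # variance_ss = sum((x - s)**2)), the moments are recovered as
  #   mean = s + mean_ss / count
  #   variance = variance_ss / count - (mean_ss / count)**2,
  # so the shift cancels out of the variance entirely.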

  def _opNormalizeMoments(self, counts, mean_ss, variance_ss, shift):
    return nn_impl.normalize_moments(counts, mean_ss, variance_ss, shift)

  def _testNormalizeMoments(self, shape, shift):
    counts = np.ones([1]).astype(np.float32)
    mean_ss = np.random.random_sample(shape).astype(np.float32)
    variance_ss = np.random.random_sample(shape).astype(np.float32)
    variance_ss *= variance_ss
    if shift:
      shift_v = np.random.random_sample(shape).astype(np.float32)
    else:
      shift_v = None
    npm, npv = self._npNormalizeMoments(counts, mean_ss, variance_ss, shift_v)
    for use_gpu in [True, False]:
      with self.test_session(use_gpu=use_gpu) as sess:
        tf_counts = constant_op.constant(counts, name="counts")
        tf_mean_ss = constant_op.constant(mean_ss, name="mean_ss")
        tf_variance_ss = constant_op.constant(variance_ss, name="variance_ss")
        if shift:
          tf_shift_v = constant_op.constant(shift_v, name="shift")
        else:
          tf_shift_v = None
        opm, opv = self._opNormalizeMoments(tf_counts, tf_mean_ss,
                                            tf_variance_ss, tf_shift_v)
        tfm, tfv = sess.run([opm, opv])
        self.assertAllClose(npm, tfm, atol=0.000001)
        self.assertAllClose(npv, tfv, atol=0.000001)

  def testNormalizeMoments(self):
    for shift in [None, 4.0]:
      self._testNormalizeMoments([3], shift)
      self._testNormalizeMoments([2, 3], shift)


@test_util.with_c_api
class MomentsTest(test.TestCase):

  def _unweighted_moments(self, x, axes, keep_dims=False, extra_out_grads=None):
    # Method to compute moments of `x` wrt `axes`.
    #
    # This is exposed so WeightedMomentsTest can inherit the tests and
    # assertions from MomentsTest; the extra_out_grads argument allows
    # its inherited gradient tests to assert gradients against the
    # weights as well as the input values.

    return nn_impl.moments(x, axes, keep_dims=keep_dims)

  def RunMomentTestWithDynamicShape(self, shape, axes, keep_dims, dtype):
    with self.test_session():
      # shape = [batch, width, height, depth]
      assert len(shape) == 4

      x_numpy = np.random.normal(size=shape).astype(np.float32)
      x = array_ops.placeholder(dtype, shape=[None] * len(shape))

      mean, var = self._unweighted_moments(x, axes, keep_dims=keep_dims)

      num_elements = np.prod([shape[i] for i in axes])

      ax = tuple(axes)
      expected_mean = np.sum(x_numpy, axis=ax,
                             keepdims=keep_dims) / num_elements
      expected_mean_squared = np.multiply(expected_mean, expected_mean)
      expected_x_squared = np.sum(np.multiply(x_numpy, x_numpy),
                                  axis=ax,
                                  keepdims=keep_dims) / num_elements
      expected_variance = expected_x_squared - expected_mean_squared

      # Check that the moments are correct.
      self.assertAllCloseAccordingToType(
          expected_mean, mean.eval(feed_dict={x: x_numpy}))
      self.assertAllCloseAccordingToType(
          expected_variance, var.eval(feed_dict={x: x_numpy}))
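
  # Like RunMomentTestWithDynamicShape above, RunMomentTest below derives the
  # expected variance from the identity Var[x] = E[x**2] - (E[x])**2. The
  # identity is exact, but the subtraction cancels catastrophically in low
  # precision, which is why RunMomentTest computes the expected values in
  # np.float128.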

  def RunMomentTest(self, shape, axes, keep_dims, dtype):
    with self.test_session():
      # shape = [batch, width, height, depth]
      assert len(shape) == 4

      x_numpy = np.random.normal(size=shape).astype(np.float32)
      x = math_ops.cast(constant_op.constant(x_numpy), dtype=dtype)

      # Compute the expected values at high precision since the method
      # is prone to catastrophic cancellation:
      x_numpy = x_numpy.astype(np.float128)

      mean, var = self._unweighted_moments(x, axes, keep_dims=keep_dims)

      num_elements = np.prod([shape[i] for i in axes])

      ax = tuple(axes)
      expected_mean = np.sum(x_numpy, axis=ax,
                             keepdims=keep_dims) / num_elements
      expected_mean_squared = np.multiply(expected_mean, expected_mean)
      expected_x_squared = np.sum(np.multiply(x_numpy, x_numpy),
                                  axis=ax,
                                  keepdims=keep_dims) / num_elements
      expected_variance = expected_x_squared - expected_mean_squared

      # Check that the moments are correct.
      self.assertAllCloseAccordingToType(expected_mean, mean.eval())
      self.assertAllCloseAccordingToType(expected_variance, var.eval())

  def testBasic(self):
    for keep_dims in [False, True]:
      for dtype in [dtypes.float32, dtypes.float16]:
        self.RunMomentTest(
            shape=[2, 3, 5, 4], axes=[0], keep_dims=keep_dims, dtype=dtype)
        self.RunMomentTestWithDynamicShape(
            shape=[2, 3, 5, 4], axes=[0], keep_dims=keep_dims, dtype=dtype)

  def testGlobalNormalization(self):
    for keep_dims in [False, True]:
      for dtype in [dtypes.float32, dtypes.float16]:
        self.RunMomentTest(
            shape=[2, 3, 5, 4],
            axes=[0, 1, 2],
            keep_dims=keep_dims,
            dtype=dtype)
        self.RunMomentTestWithDynamicShape(
            shape=[2, 3, 5, 4],
            axes=[0, 1, 2],
            keep_dims=keep_dims,
            dtype=dtype)

  def testAxes(self):
    for keep_dims in [False, True]:
      for dtype in [dtypes.float32, dtypes.float16]:
        self.RunMomentTest(
            shape=[2, 3, 5, 4],
            axes=[1, 2, 3],
            keep_dims=keep_dims,
            dtype=dtype)
        self.RunMomentTestWithDynamicShape(
            shape=[2, 3, 5, 4],
            axes=[1, 2, 3],
            keep_dims=keep_dims,
            dtype=dtype)

  def _testGlobalGradient(self, from_y="mean"):
    with self.test_session():
      x_shape = [3, 5, 4, 2]
      x_val = np.random.random_sample(x_shape).astype(np.float64)
      x = constant_op.constant(x_val)
      x.set_shape(x_shape)

      axes = [0, 1, 2]
      y_shape = [2]  # Depth of x.

      inputs_to_compute_gradients_for = [x]

      out_mean, out_var = self._unweighted_moments(
          x, axes, extra_out_grads=inputs_to_compute_gradients_for)
      if from_y == "mean":
        y = out_mean
      elif from_y == "var":
        y = out_var

      for (i, v) in enumerate(inputs_to_compute_gradients_for):
        err = gradient_checker.compute_gradient_error(v,
                                                      v.get_shape().as_list(),
                                                      y, y_shape)
        print("Moments %s gradient err vs input %d = %g" % (from_y, i, err))
        self.assertLess(err, 1e-11)

  def testMeanGlobalGradient(self):
    self._testGlobalGradient(from_y="mean")

  def testVarGlobalGradient(self):
    self._testGlobalGradient(from_y="var")


@test_util.with_c_api
class WeightedMomentsTest(MomentsTest):
  """Tests for nn.weighted_moments.

  Note that this test inherits from MomentsTest, inheriting all its
  test methods!

  It modifies MomentsTest in two ways:

  a) By overriding _unweighted_moments, all the codepaths in
     MomentsTest are executed, but with calls to tf.nn.moments()
     replaced by calls to tf.nn.weighted_moments() with a constant
     weight of 1.

  b) By overriding RunMomentTest and RunMomentTestWithDynamicShape,
     this test adds multiple additional calls to
     RunWeightedMomentTest() to exercise correctness with
     non-constant weights and varying broadcasting situations. (It
     also continues to call MomentsTest.Run(Weighted)?MomentTest as
     well.)
  """
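
  # For reference, the quantities checked by RunWeightedMomentTest are
  #   mean = sum(w * x) / sum(w)
  #   variance = sum(w * x**2) / sum(w) - mean**2,
  # which reduce to the unweighted moments when every weight is 1.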

  def _unweighted_moments(self, x, axes, keep_dims=False, extra_out_grads=None):
    weights = constant_op.constant(1, dtype=x.dtype)
    if extra_out_grads is not None:
      # We want to assert gradients WRT weights as well as X!
      extra_out_grads.append(weights)
    return nn_impl.weighted_moments(x, axes, weights, keep_dims=keep_dims)

  def RunMomentTest(self, shape, axes, keep_dims, dtype, dynshapes=False):
    if not dynshapes:
      super(WeightedMomentsTest, self).RunMomentTest(shape, axes, keep_dims,
                                                     dtype)
    else:
      super(WeightedMomentsTest, self).RunMomentTestWithDynamicShape(
          shape, axes, keep_dims, dtype)

    # 1:1 weights and inputs.
    self.RunWeightedMomentTest(shape, shape, axes, keep_dims, dtype)

    # Various broadcasting combinations.
    for idx in range(len(shape)):
      # Try broadcasting weights in all positions.
      weight_shape = [1] * len(shape)
      weight_shape[idx] = shape[idx]

      self.RunWeightedMomentTest(shape, weight_shape, axes, keep_dims, dtype)

      # Also try broadcasting with a suffix of length idx + 1.
      weight_shape = shape[-(idx + 1):]
      self.RunWeightedMomentTest(
          shape, weight_shape, axes, keep_dims, dtype, dynshapes=dynshapes)

  def RunMomentTestWithDynamicShape(self, shape, axes, keep_dims, dtype):
    self.RunMomentTest(shape, axes, keep_dims, dtype, dynshapes=True)

  def RunWeightedMomentTest(self,
                            shape,
                            weights_shape,
                            axes,
                            keep_dims,
                            dtype,
                            dynshapes=False):
    with self.test_session() as s:
      x_numpy = np.random.normal(size=shape).astype(np.float32)
      weights_numpy = np.absolute(  # weights must be positive
          np.random.normal(size=weights_shape, loc=1.0).astype(np.float32))

      # Expand the numpy version to higher precision.
      x_numpy = x_numpy.astype(np.float128)
      weights_numpy = weights_numpy.astype(np.float128)

      x_shape = [None] * len(shape) if dynshapes else shape
      weights_shape = ([None] * len(weights_shape) if dynshapes else
                       weights_shape)

      x = array_ops.placeholder(dtype, shape=x_shape)
      weights = array_ops.placeholder(dtype, shape=weights_shape)

      mean, var = nn_impl.weighted_moments(
          x, axes, weights, keep_dims=keep_dims)

      ax = tuple(axes)

      def _np_weighted_sum(v):
        return np.sum(weights_numpy * v, axis=ax, keepdims=keep_dims)

      weight_sum = _np_weighted_sum(np.ones_like(x_numpy))
      expected_mean = _np_weighted_sum(x_numpy) / weight_sum
      expected_mean_squared = np.multiply(expected_mean, expected_mean)
      expected_x_squared = (_np_weighted_sum(np.multiply(x_numpy, x_numpy)) /
                            weight_sum)
      expected_variance = expected_x_squared - expected_mean_squared

      mean_v, var_v = s.run([mean, var],
                            feed_dict={x: x_numpy,
                                       weights: weights_numpy})

      self.assertAllCloseAccordingToType(expected_mean, mean_v)
      self.assertAllCloseAccordingToType(expected_variance, var_v)


if __name__ == "__main__":
  test.main()