# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functional tests for fused batch norm operations."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import nn
from tensorflow.python.platform import test


class FusedBatchNormTest(XLATestCase):

  def _reference_training(self, x, scale, offset, epsilon, data_format):
    if data_format != "NHWC":
      raise ValueError("data_format must be NHWC, got %s." % data_format)
    x_square = x * x
    x_square_sum = np.sum(x_square, axis=(0, 1, 2))
    x_sum = np.sum(x, axis=(0, 1, 2))
    element_count = np.size(x) / int(np.shape(x)[-1])
    mean = x_sum / element_count
    var = x_square_sum / element_count - mean * mean
    normalized = (x - mean) / np.sqrt(var + epsilon)
    return (normalized * scale + offset), mean, var

  def _reference_grad(self, x, grad_y, scale, mean, var, epsilon, data_format):
    # Use the following formulas to calculate gradients:
    # grad_scale =
    #   sum(grad_y * (x - mean)) * rsqrt(var + epsilon)
    #
    # grad_offset = sum(grad_y)
    #
    # grad_x =
    #   1/N * scale * rsqrt(var + epsilon) * (N * grad_y - sum(grad_y) -
    #   (x - mean) * sum(grad_y * (x - mean)) / (var + epsilon))
    if data_format != "NHWC":
      raise ValueError("data_format must be NHWC, got %s." % data_format)
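    # Sketch of where grad_x comes from: with N the per-channel element count
    # and inv = 1 / sqrt(var + epsilon), differentiating
    # y = (x - mean) * inv * scale + offset through mean and var (both
    # functions of x) gives
    #   grad_x = scale * inv * (grad_y - mean(grad_y)
    #            - (x - mean) * mean(grad_y * (x - mean)) * inv^2),
    # which is the expression computed below.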
    grad_x = scale * (grad_y - np.mean(grad_y, axis=(0, 1, 2)) -
                      (x - mean) * np.mean(grad_y *
                                           (x - mean), axis=(0, 1, 2)) /
                      (var + epsilon)) / np.sqrt(var + epsilon)
    grad_scale = np.sum(
        grad_y * (x - mean) / np.sqrt(var + epsilon), axis=(0, 1, 2))
    grad_offset = np.sum(grad_y, axis=(0, 1, 2))
    return grad_x, grad_scale, grad_offset

  def testInference(self):
    channel = 3
    x_shape = [2, 2, 6, channel]
    scale_shape = [channel]
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    scale_val = np.random.random_sample(scale_shape).astype(np.float32)
    offset_val = np.random.random_sample(scale_shape).astype(np.float32)
    data_format = "NHWC"
    with self.test_session() as sess, self.test_scope():
      # Use placeholders instead of constants so the inputs cannot be
      # constant folded away.
      t_val = array_ops.placeholder(np.float32, shape=x_shape, name="x")
      scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale")
      offset = array_ops.placeholder(
          np.float32, shape=scale_shape, name="offset")
      epsilon = 0.001
      y_ref, mean_ref, var_ref = self._reference_training(
          x_val, scale_val, offset_val, epsilon, data_format)
      y, mean, variance = nn.fused_batch_norm(
          t_val,
          scale,
          offset,
          mean=mean_ref,
          variance=var_ref,
          epsilon=epsilon,
          data_format=data_format,
          is_training=False)

      y_val, _, _ = sess.run(
          [y, mean, variance],
          {t_val: x_val,
           scale: scale_val,
           offset: offset_val})
      self.assertAllClose(y_val, y_ref, atol=1e-3)

  def _testLearning(self, use_gradient_checker):
    channel = 3
    x_shape = [2, 2, 6, channel]
    scale_shape = [channel]
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    scale_val = np.random.random_sample(scale_shape).astype(np.float32)
    offset_val = np.random.random_sample(scale_shape).astype(np.float32)
    data_format = "NHWC"
    with self.test_session() as sess, self.test_scope():
      # Use placeholders instead of constants so the inputs cannot be
      # constant folded away.
      t_val = array_ops.placeholder(np.float32, shape=x_shape, name="x")
      scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale")
      offset = array_ops.placeholder(
          np.float32, shape=scale_shape, name="offset")
      epsilon = 0.001
      y, mean, var = nn.fused_batch_norm(
          t_val,
          scale,
          offset,
          mean=None,
          variance=None,
          epsilon=epsilon,
          data_format=data_format,
          is_training=True)
      # Check gradient.
      if use_gradient_checker:
        err = gradient_checker.compute_gradient_error(
            t_val,
            x_shape,
            y,
            x_shape,
            extra_feed_dict={
                t_val: x_val,
                scale: scale_val,
                offset: offset_val
            })
        self.assertLess(err, 1e-3)

      y_val, mean_val, var_val = sess.run(
          [y, mean, var],
          {t_val: x_val,
           scale: scale_val,
           offset: offset_val})
      y_ref, mean_ref, var_ref = self._reference_training(
          x_val, scale_val, offset_val, epsilon, data_format)
      self.assertAllClose(mean_val, mean_ref, atol=1e-3)
      self.assertAllClose(y_val, y_ref, atol=1e-3)
      self.assertAllClose(var_val, var_ref, atol=1e-3)

  def testLearning(self):
    self._testLearning(False)

  def testLearningWithGradientChecker(self):
    self._testLearning(True)

  def testGradientTraining(self):
    # TODO(b/64270657): Use gradient_checker here in addition to comparing
    # with this reference implementation.
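    # Feeds random values for grad_y, x, scale, mean, and var into
    # FusedBatchNormGrad in training mode and checks its outputs against
    # _reference_grad; the random mean/var stand in for the batch statistics
    # that a forward pass would normally produce.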
    channel = 3
    x_shape = [2, 2, 6, channel]
    scale_shape = [channel]
    grad_val = np.random.random_sample(x_shape).astype(np.float32)
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    scale_val = np.random.random_sample(scale_shape).astype(np.float32)
    mean_val = np.random.random_sample(scale_shape).astype(np.float32)
    var_val = np.random.random_sample(scale_shape).astype(np.float32)
    epsilon = 0.001

    with self.test_session() as sess, self.test_scope():
      grad = array_ops.placeholder(np.float32, shape=x_shape, name="grad")
      x = array_ops.placeholder(np.float32, shape=x_shape, name="x")
      mean = array_ops.placeholder(np.float32, shape=scale_shape, name="mean")
      var = array_ops.placeholder(np.float32, shape=scale_shape, name="var")
      scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale")
      grad_x, grad_scale, grad_offset, _, _ = gen_nn_ops.fused_batch_norm_grad(
          grad, x, scale, mean, var, data_format="NHWC", is_training=True)

      grad_x_val, grad_scale_val, grad_offset_val = sess.run(
          [grad_x, grad_scale, grad_offset], {
              grad: grad_val,
              x: x_val,
              mean: mean_val,
              var: var_val,
              scale: scale_val
          })

      grad_x_ref, grad_scale_ref, grad_offset_ref = self._reference_grad(
          x_val, grad_val, scale_val, mean_val, var_val, epsilon, "NHWC")

      self.assertAllClose(grad_x_val, grad_x_ref, atol=1e-2)
      self.assertAllClose(grad_scale_val, grad_scale_ref, atol=1e-2)
      self.assertAllClose(grad_offset_val, grad_offset_ref, atol=1e-3)

  def testGradientInference(self):
    # TODO(b/64270657): Use gradient_checker here in addition to comparing
    # with this reference implementation.
    channel = 3
    x_shape = [2, 2, 6, channel]
    scale_shape = [channel]
    grad_val = np.random.random_sample(x_shape).astype(np.float32)
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    scale_val = np.random.random_sample(scale_shape).astype(np.float32)
    mean_val = np.random.random_sample(scale_shape).astype(np.float32)
    var_val = np.random.random_sample(scale_shape).astype(np.float32)

    with self.test_session() as sess:
      grad = array_ops.placeholder(np.float32, shape=x_shape, name="grad")
      x = array_ops.placeholder(np.float32, shape=x_shape, name="x")
      mean = array_ops.placeholder(np.float32, shape=scale_shape, name="mean")
      var = array_ops.placeholder(np.float32, shape=scale_shape, name="var")
      scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale")
      # Build the op under test inside the XLA scope; the reference below is
      # built outside it so that it runs through the regular TF kernel.
      with self.test_scope():
        out = gen_nn_ops.fused_batch_norm_grad(
            grad, x, scale, mean, var, data_format="NHWC", is_training=False)
        grad_x, grad_scale, grad_offset, _, _ = out

      ref_x, ref_scale, ref_offset, _, _ = gen_nn_ops.fused_batch_norm_grad(
          grad, x, scale, mean, var, data_format="NHWC", is_training=False)

      grad_x_val, grad_scale_val, grad_offset_val = sess.run(
          [grad_x, grad_scale, grad_offset], {
              grad: grad_val,
              x: x_val,
              mean: mean_val,
              var: var_val,
              scale: scale_val
          })
      grad_x_ref, grad_scale_ref, grad_offset_ref = sess.run(
          [ref_x, ref_scale, ref_offset], {
              grad: grad_val,
              x: x_val,
              mean: mean_val,
              var: var_val,
              scale: scale_val
          })

      self.assertAllClose(grad_x_val, grad_x_ref, atol=1e-2)
      self.assertAllClose(grad_scale_val, grad_scale_ref, atol=1e-2)
      self.assertAllClose(grad_offset_val, grad_offset_ref, atol=1e-3)

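# Note: each test builds the op under test inside XLATestCase.test_scope(),
# which targets the XLA test device so the op is compiled, and validates the
# result against either a NumPy reference or the regular TensorFlow kernel
# built outside that scope.
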
"__main__": 245 test.main() 246