1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include <cmath> 17 #include <memory> 18 #include <vector> 19 20 #include "tensorflow/compiler/xla/array2d.h" 21 #include "tensorflow/compiler/xla/array4d.h" 22 #include "tensorflow/compiler/xla/client/computation.h" 23 #include "tensorflow/compiler/xla/client/computation_builder.h" 24 #include "tensorflow/compiler/xla/client/lib/arithmetic.h" 25 #include "tensorflow/compiler/xla/client/local_client.h" 26 #include "tensorflow/compiler/xla/literal_util.h" 27 #include "tensorflow/compiler/xla/reference_util.h" 28 #include "tensorflow/compiler/xla/service/hlo_computation.h" 29 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 30 #include "tensorflow/compiler/xla/service/hlo_module.h" 31 #include "tensorflow/compiler/xla/shape_util.h" 32 #include "tensorflow/compiler/xla/statusor.h" 33 #include "tensorflow/compiler/xla/test.h" 34 #include "tensorflow/compiler/xla/test_helpers.h" 35 #include "tensorflow/compiler/xla/tests/client_library_test_base.h" 36 #include "tensorflow/compiler/xla/tests/hlo_test_base.h" 37 #include "tensorflow/compiler/xla/tests/literal_test_util.h" 38 #include "tensorflow/compiler/xla/tests/test_macros.h" 39 #include "tensorflow/compiler/xla/tests/test_utils.h" 40 #include "tensorflow/compiler/xla/util.h" 41 #include 
"tensorflow/compiler/xla/xla_data.pb.h" 42 #include "tensorflow/core/lib/math/math_util.h" 43 #include "tensorflow/core/lib/strings/str_util.h" 44 #include "tensorflow/core/platform/logging.h" 45 #include "tensorflow/core/platform/test.h" 46 #include "tensorflow/core/platform/types.h" 47 48 namespace xla { 49 namespace { 50 51 class BatchNormalizationTest 52 : public ClientLibraryTestBase, 53 public ::testing::WithParamInterface<bool /*use_cudnn_batchnorm*/> { 54 protected: 55 BatchNormalizationTest() : input_array_(kSamples, kZ, kY, kX) { 56 mutable_debug_options()->set_xla_gpu_use_cudnn_batchnorm(GetParam()); 57 58 Array2D<float> pz({ 59 // z0 z1 60 {-1.0f, 4.1f}, // p0 61 {2.0f, 4.1f}, // p1 62 {5.0f, 4.4f}, // p2 63 }); 64 input_array_.FillWithPZ(pz); 65 input_literal_ = std::move(*Literal::CreateR4FromArray4D(input_array_)); 66 CHECK_EQ(kSamples, input_array_.planes()); 67 CHECK_EQ(kZ, input_array_.depth()); 68 CHECK_EQ(kY, input_array_.height()); 69 CHECK_EQ(kY, input_array_.width()); 70 } 71 72 static constexpr int64 kSamples = 3; 73 static constexpr int64 kX = 1; 74 static constexpr int64 kY = 1; 75 static constexpr int64 kZ = 2; 76 77 Array4D<float> input_array_; 78 Literal input_literal_; 79 const ErrorSpec error_spec_{0.001, 0.001}; 80 }; 81 82 // If testing the GPU backend, run the tests twice, with and without cudnn 83 // batchnorm. Otherwise, just run the tests once -- the value of this flag 84 // doesn't matter. 
#ifdef XLA_TEST_BACKEND_GPU
INSTANTIATE_TEST_CASE_P(BatchNormalizationTestInstance, BatchNormalizationTest,
                        ::testing::Bool());
#else
INSTANTIATE_TEST_CASE_P(BatchNormalizationTestInstance, BatchNormalizationTest,
                        ::testing::Values(false));
#endif

// Subtracts a per-feature constant from the input by broadcasting along the Z
// dimension (dimension 1) and compares against a hand-computed result.
XLA_TEST_P(BatchNormalizationTest, SubtractInZ) {
  ComputationBuilder builder(client_, "subtract_in_z_one_sample");
  auto x = builder.ConstantLiteral(input_literal_);
  auto y = builder.ConstantR1<float>({3.14, 4.25});
  builder.Sub(x, y, /*broadcast_dimensions=*/{1});

  Array4D<float> expected(kSamples, kZ, kY, kX);
  Array2D<float> pz({
      {-1.0f - 3.14f, 4.1f - 4.25f},  // p0
      {2.0f - 3.14f, 4.1f - 4.25f},   // p1
      {5.0f - 3.14f, 4.4f - 4.25f},   // p2
  });
  expected.FillWithPZ(pz);
  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}

// Squares every element of the 4D input; the expectation is computed on the
// host with MathUtil::IPow.
XLA_TEST_P(BatchNormalizationTest, SquareTesseractElementwise) {
  ComputationBuilder builder(client_, "square_tesseract_elementwise");
  auto x = builder.ConstantLiteral(input_literal_);
  builder.SquareF32(x);

  using tensorflow::MathUtil;

  Array4D<float> expected(kSamples, kZ, kY, kX);
  Array2D<float> expected_pz({
      {MathUtil::IPow(-1.0f, 2), MathUtil::IPow(4.1f, 2)},
      {MathUtil::IPow(2.0f, 2), MathUtil::IPow(4.1f, 2)},
      {MathUtil::IPow(5.0f, 2), MathUtil::IPow(4.4f, 2)},
  });
  expected.FillWithPZ(expected_pz);
  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}

// Reduces the input over every dimension except Z, producing the per-feature
// sums used by a batch-norm mean computation.
XLA_TEST_P(BatchNormalizationTest, SumToZ) {
  ComputationBuilder builder(client_, "sum_to_z");
  auto input_activations = builder.ConstantLiteral(input_literal_);
  Computation add = CreateScalarAddComputation(F32, &builder);
  // Reduce all but the Z dimension.
  builder.Reduce(input_activations, builder.ConstantR0<float>(0.0f), add,
                 {0, 2, 3});

  std::vector<float> expected = {6, 12.6};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}

// Computes the per-feature sum of squared deviations from fixed means -- the
// numerator of a batch-norm variance computation.
XLA_TEST_P(BatchNormalizationTest, SquareAndReduce) {
  ComputationBuilder builder(client_, "square_and_reduce");
  auto input_activations = builder.ConstantLiteral(input_literal_);
  auto set_means = builder.ConstantR1<float>({2.f, 4.2f});
  auto activation_deviations = builder.Sub(input_activations, set_means,
                                           /*broadcast_dimensions=*/{1});
  Computation add = CreateScalarAddComputation(F32, &builder);
  auto dev_squares = builder.SquareF32(activation_deviations);
  auto sum_of_squares = builder.Reduce(
      dev_squares, builder.ConstantR0<float>(0.0f), add, {0, 2, 3});

  std::vector<float> expected = {18, 0.06};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}

// Elementwise sqrt turning a variance vector into standard deviations.
XLA_TEST_P(BatchNormalizationTest, VarianceToStddev) {
  ComputationBuilder builder(client_, "variance_to_stddev");
  auto variance = builder.ConstantR1<float>({6.f, .02f});
  auto sqrt = builder.SqrtF32(variance);

  std::vector<float> expected = {2.44948974f, 0.14142136f};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}

// Compare against a forward batch normalization example in the NN spec
// reference. Builds the whole forward pass (mean, variance, normalization,
// scale/offset) out of primitive ops, checking intermediate shapes along the
// way with CheckShape.
XLA_TEST_P(BatchNormalizationTest, SpecComparisonForward) {
  ComputationBuilder builder(client_, "batch_normalize_per_spec");
  auto input_activations =
      builder.CheckShape(builder.ConstantLiteral(input_literal_),
                         ShapeUtil::MakeShape(F32, {3, 2, 1, 1}));
  auto gamma = builder.ConstantR1<float>({1.0, 1.0});
  auto beta = builder.ConstantR1<float>({0.0, 0.0});
  Computation add = CreateScalarAddComputation(F32, &builder);
  // Reduce all dimensions except dimension 1.
  Shape TwoElementVectorF32 = ShapeUtil::MakeShape(F32, {2});
  auto sum = builder.CheckShape(
      builder.Reduce(input_activations, builder.ConstantR0<float>(0.0f), add,
                     /*dimensions_to_reduce=*/{0, 2, 3}),
      TwoElementVectorF32);
  auto input_shape = builder.GetShape(input_activations).ConsumeValueOrDie();
  auto sum_shape = builder.GetShape(sum).ConsumeValueOrDie();
  // Number of elements contributing to each per-feature sum.
  auto count = builder.ConstantR0<float>(ShapeUtil::ElementsIn(*input_shape) /
                                         ShapeUtil::ElementsIn(*sum_shape));
  auto set_means = builder.Div(sum, count);

  const float kEpsilon = 1e-9f;
  auto epsilon = builder.ConstantR0<float>(kEpsilon);
  auto epsilon2 = builder.ConstantR1<float>({kEpsilon, kEpsilon});
  auto activation_deviations = builder.Sub(input_activations, set_means,
                                           /*broadcast_dimensions=*/{1});
  auto dev_squares = builder.SquareF32(activation_deviations);
  auto sum_of_squares = builder.CheckShape(
      builder.Reduce(dev_squares, builder.ConstantR0<float>(0.0f), add,
                     /*dimensions_to_reduce=*/{0, 2, 3}),
      TwoElementVectorF32);
  auto variance = builder.Div(sum_of_squares, count);
  auto standard_deviation = builder.SqrtF32(variance);
  // Clamp the denominator to epsilon to avoid dividing by a tiny stddev.
  auto standard_deviation_above_epsilon = builder.CheckShape(
      builder.Gt(standard_deviation, epsilon), ShapeUtil::MakeShape(PRED, {2}));
  auto gt_eps = builder.Select(standard_deviation_above_epsilon,
                               standard_deviation, epsilon2);
  auto normalization_factors = builder.ReciprocalF32(gt_eps);
  auto normalized_input_activations =
      builder.Mul(activation_deviations, normalization_factors,
                  /*broadcast_dimensions=*/{1});
  /* auto output_activations = */ builder.Add(
      builder.Mul(normalized_input_activations, gamma,
                  /*broadcast_dimensions=*/{1}),
      beta, /*broadcast_dimensions=*/{1});

  Array4D<float> expected(kSamples, kZ, kY, kX);
  Array2D<float> pz({
      {-3.f / std::sqrt(6.f), -.1f / std::sqrt(.02f)},
      {0.f, -.1f / std::sqrt(.02f)},
      {3.f / std::sqrt(6.f), .2f / std::sqrt(.02f)},
  });
  expected.FillWithPZ(pz);

  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}

// Exercises the fused BatchNormTraining op with the feature dimension at
// index 3, comparing the (normalized, mean, variance) tuple to hand-computed
// values.
XLA_TEST_P(BatchNormalizationTest, BasicTraining) {
  const int kFeatureIndex = 3;
  ComputationBuilder builder(client_, TestName());

  auto operand = builder.ConstantR4FromArray4D<float>(
      {{{{1.f, 2.f}}, {{3.f, 4.f}}}, {{{5.f, 6.f}}, {{7.f, 8.f}}}});

  auto scale = builder.ConstantR1<float>({2.0f, 3.0f});

  auto offset = builder.ConstantR1<float>({1.0f, 2.0f});

  auto tuple = builder.BatchNormTraining(operand, scale, offset,
                                         /*epsilon=*/0.001, kFeatureIndex);

  auto expected = Literal::MakeTuple(
      {Literal::CreateR4<float>({{{{-1.6f, -2.0f}}, {{0.1f, 0.6f}}},
                                 {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}})
           .get(),
       Literal::CreateR1<float>({4, 5}).get(),
       Literal::CreateR1<float>({5, 5}).get()});

  ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
}

// Same data as BasicTraining but with the feature dimension at index 2.
XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnSublane) {
  const int kFeatureIndex = 2;
  ComputationBuilder builder(client_, TestName());

  auto operand = builder.ConstantR4FromArray4D<float>(
      {{{{1.f}, {2.f}}, {{3.f}, {4.f}}}, {{{5.f}, {6.f}}, {{7.f}, {8.f}}}});

  auto scale = builder.ConstantR1<float>({2.0f, 3.0f});

  auto offset = builder.ConstantR1<float>({1.0f, 2.0f});

  auto tuple = builder.BatchNormTraining(operand, scale, offset,
                                         /*epsilon=*/0.001, kFeatureIndex);

  auto expected = Literal::MakeTuple(
      {Literal::CreateR4<float>({{{{-1.6f}, {-2.0f}}, {{0.1f}, {0.6f}}},
                                 {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}})
           .get(),
       Literal::CreateR1<float>({4, 5}).get(),
       Literal::CreateR1<float>({5, 5}).get()});

  ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
}

XLA_TEST_P(BatchNormalizationTest, TrainingWithFeatureOnLowDimension) {
  // Use 0 dimension as feature, tests layout analyzer.
  const int kFeatureIndex = 0;
  ComputationBuilder builder(client_, TestName());

  ComputationDataHandle h0;
  // All-ones input with 260 features; with epsilon 1 the normalized output is
  // all ones as well (variance is 0).
  auto operand = CreateR3Parameter<float>(Array3D<float>(260, 2, 2, 1.0f),
                                          /*parameter_number=*/0, "operand",
                                          &builder, &h0);
  ComputationDataHandle h1;
  auto scale =
      CreateR1Parameter<float>(std::vector<float>(260, 1.0f),
                               /*parameter_number=*/1, "scale", &builder, &h1);
  ComputationDataHandle h2;
  auto offset =
      CreateR1Parameter<float>(std::vector<float>(260, 1.0f),
                               /*parameter_number=*/2, "offset", &builder, &h2);

  auto tuple = builder.BatchNormTraining(h0, h1, h2,
                                         /*epsilon=*/1, kFeatureIndex);

  auto expected = Literal::MakeTuple(
      {Literal::CreateR3FromArray3D<float>(Array3D<float>(260, 2, 2, 1.0f))
           .get(),
       Literal::CreateR1<float>(std::vector<float>(260, 1.0f)).get(),
       Literal::CreateR1<float>(std::vector<float>(260, 0.0f)).get()});

  ComputeAndCompareTuple(&builder, *expected,
                         {operand.get(), scale.get(), offset.get()},
                         ErrorSpec(0.1));
}

XLA_TEST_P(BatchNormalizationTest, LargeEpsilonTest) {
  // Test the correctness of choosing a large epsilon value.
  const int kFeatureIndex = 2;
  ComputationBuilder builder(client_, TestName());

  ComputationDataHandle h0;
  auto operand = CreateR3Parameter<float>({{{0.0f}, {10.0f}, {20.0f}, {30.0f}}},
                                          /*parameter_number=*/0, "operand",
                                          &builder, &h0);
  ComputationDataHandle h1;
  auto scale =
      CreateR1Parameter<float>(std::vector<float>(1, 1.0f),
                               /*parameter_number=*/1, "scale", &builder, &h1);
  ComputationDataHandle h2;
  auto offset =
      CreateR1Parameter<float>(std::vector<float>(1, 0.0f),
                               /*parameter_number=*/2, "offset", &builder, &h2);

  // var = 125, mean = 15, epsilon = -100
  // Negative epsilon shrinks (var + epsilon) to 25, so the normalization
  // divides by sqrt(25) = 5.
  auto tuple = builder.BatchNormTraining(h0, h1, h2,
                                         /*epsilon=*/-100, kFeatureIndex);

  auto expected = Literal::MakeTuple(
      {Literal::CreateR3FromArray3D<float>({{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}})
           .get(),
       Literal::CreateR1<float>(std::vector<float>(1, 15.0f)).get(),
       Literal::CreateR1<float>(std::vector<float>(1, 125.0f)).get()});

  ComputeAndCompareTuple(&builder, *expected,
                         {operand.get(), scale.get(), offset.get()},
                         ErrorSpec(0.1));
}

// Exercises the fused BatchNormGrad op on a zero activation / unit variance
// input, comparing the (grad_activation, grad_scale, grad_offset) tuple to
// hand-computed values.
XLA_TEST_P(BatchNormalizationTest, BatchNormGradBasic) {
  const int kFeatureIndex = 2;
  ComputationBuilder builder(client_, TestName());

  auto operand =
      builder.ConstantR4FromArray4D<float>(Array4D<float>(2, 2, 2, 1, 0.0f));

  auto scale = builder.ConstantR1<float>({1.0f, 1.0f});

  auto mean = builder.ConstantR1<float>({0.0f, 0.0f});

  auto var = builder.ConstantR1<float>({1.0f, 1.0f});

  auto grad_output = builder.ConstantR4FromArray4D<float>(
      {{{{1.f}, {2.f}}, {{3.f}, {4.f}}}, {{{5.f}, {6.f}}, {{7.f}, {8.f}}}});

  builder.BatchNormGrad(operand, scale, mean, var, grad_output,
                        /*epsilon=*/0.0, kFeatureIndex);

  auto expected = Literal::MakeTuple(
      {Literal::CreateR4<float>({{{{-3.f}, {-3.f}}, {{-1.f}, {-1.f}}},
                                 {{{1.f}, {1.f}}, {{3.f}, {3.f}}}})
           .get(),
       Literal::CreateR1<float>({0, 0}).get(),
       Literal::CreateR1<float>({16, 20}).get()});

  ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
}

// Parameters for the randomized batch-norm tests below: tensor bounds, the
// feature dimension, the random fill distribution, and whether to enable
// cudnn batchnorm.
struct BatchNormTestParam {
  std::vector<int64> bounds;
  int64 feature_index;
  float random_value_mean;
  float random_value_var;
  bool use_cudnn_batchnorm;

  friend ::std::ostream& operator<<(::std::ostream& os,
                                    const BatchNormTestParam& p) {
    os << "bounds={" << tensorflow::str_util::Join(p.bounds, ", ") << "}, ";
    os << "feature_index=" << p.feature_index << ", ";
    os << "random_value_mean=" << p.random_value_mean << ", ";
    os << "random_value_var=" << p.random_value_var;

    // Don't print use_cudnn_batchnorm when it's false, because most backends
    // never set it to true.
    if (p.use_cudnn_batchnorm) {
      os << ", use_cudnn_batchnorm=true";
    }
    return os;
  }
};

// Tests to test the fused operation of BatchNorm.
class BatchNormTestManySizes
    : public ClientLibraryTestBase,
      public ::testing::WithParamInterface<BatchNormTestParam> {
 public:
  BatchNormTestManySizes() {
    mutable_debug_options()->set_xla_gpu_use_cudnn_batchnorm(
        GetParam().use_cudnn_batchnorm);
  }
};

// Builds the full parameter list for BatchNormTestManySizes, covering varied
// shapes, feature dimensions, and random fill distributions.
std::vector<BatchNormTestParam> BuildBatchNormTestParams() {
  std::vector<BatchNormTestParam> params;

  auto add_testcase = [&](std::vector<int64> bounds, int64 feature_index,
                          float random_value_mean, float random_value_var) {
    BatchNormTestParam p{bounds, feature_index, random_value_mean,
                         random_value_var, /*use_cudnn_batchnorm=*/false};
    params.push_back(p);

    // If testing the GPU backend, also run with cudnn batchnorm enabled.
#ifdef XLA_TEST_BACKEND_GPU
    p.use_cudnn_batchnorm = true;
    params.push_back(p);
#endif
  };

  add_testcase({2, 2, 2, 2}, 0, 100.2f, 200.0f);
  add_testcase({2, 2, 2, 2}, 3, 300.f, 400.0f);

  add_testcase({1, 10, 1, 1}, 0, 10.1f, 20.1f);
  add_testcase({10, 10, 10, 10}, 1, 3.14f, 314.15f);
  add_testcase({10, 10, 10, 10}, 2, 666.6f, 777.7f);
  add_testcase({10, 10, 10, 10}, 1, -666.6f, 777.7f);
  add_testcase({10, 10, 10, 10}, 2, 0.f, 777.7f);
  add_testcase({1, 1, 10, 130}, 2, 0.f, 777.7f);
  add_testcase({1, 1, 130, 11}, 2, 0.f, 777.7f);
  add_testcase({1, 1, 10, 1}, 3, 888.8f, 9.9f);

  add_testcase({24, 129, 1, 2}, 2, 10000, 10000);
  add_testcase({24, 129, 1, 2}, 3, 10000, 10000);

  // Feature on low dimension to trigger relayout, check that internal logical
  // to physical dimension calculation is correct after relayout.
  add_testcase({1, 2, 3, 4}, 0, 100, 100);

  // Zero-sized tensor.
  add_testcase({1, 0, 100, 42}, 0, 100, 100);

  return params;
}

INSTANTIATE_TEST_CASE_P(BatchNormTest_Instantiation, BatchNormTestManySizes,
                        ::testing::ValuesIn(BuildBatchNormTestParams()));

// Compares BatchNormTraining against a host-side reference computation of
// mean, variance, and the normalized output for randomized inputs.
XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) {
  float epsilon = 0.001;
  ComputationBuilder builder(client_, TestName());
  const std::vector<int64>& bounds = GetParam().bounds;
  Array4D<float> input_array(bounds[0], bounds[1], bounds[2], bounds[3]);
  input_array.FillRandom(GetParam().random_value_var,
                         GetParam().random_value_mean);

  const int64 feature_index = GetParam().feature_index;
  const int64 num_elements_per_feature =
      Product(bounds) / bounds[feature_index];
  const int64 feature_bound = bounds[feature_index];
  std::vector<float> offset(feature_bound, 1);
  std::vector<float> scale(feature_bound, 2);

  auto input_squared =
      ReferenceUtil::MapArray4D(input_array, [](float a) { return a * a; });
  // All dimensions except the feature dimension are reduced away.
  std::vector<int64> reduce_dims;
  for (int64 i = 0; i < static_cast<int64>(bounds.size()); ++i) {
    if (i != feature_index) {
      reduce_dims.push_back(i);
    }
  }

  auto sum =
      ReferenceUtil::Reduce4DTo1D(input_array, /*init=*/0.0f, reduce_dims,
                                  [](float a, float b) { return a + b; });

  auto sum_squared =
      ReferenceUtil::Reduce4DTo1D(*input_squared, /*init=*/0.0f, reduce_dims,
                                  [](float a, float b) { return a + b; });

  std::vector<float> mean(feature_bound);

  for (int64 i = 0; i < feature_bound; ++i) {
    mean[i] = sum[i] / num_elements_per_feature;
  }

  std::vector<float> mean_square(feature_bound);
  for (int64 i = 0; i < feature_bound; ++i) {
    mean_square[i] = mean[i] * mean[i];
  }

  std::vector<float> square_mean(feature_bound);
  for (int64 i = 0; i < feature_bound; ++i) {
    square_mean[i] = sum_squared[i] / num_elements_per_feature;
  }

  // var = E[x^2] - E[x]^2, per feature.
  std::vector<float> var(feature_bound);
  for (int64 i = 0; i < feature_bound; ++i) {
    var[i] = square_mean[i] - mean_square[i];
  }

  Array4D<float> mean4D =
      *ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index);
  auto var4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index);
  auto scale4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index);
  auto offset4D =
      *ReferenceUtil::Broadcast1DTo4D(offset, bounds, feature_index);

  auto normalized = *ReferenceUtil::BatchNorm4D(input_array, mean4D, var4D,
                                                scale4D, offset4D, epsilon);

  auto expected_normalized = Literal::CreateR4FromArray4D<float>(normalized);

  auto offset_literal = Literal::CreateR1<float>(offset);
  auto scale_literal = Literal::CreateR1<float>(scale);
  auto input_literal = Literal::CreateR4FromArray4D<float>(input_array);

  // NOTE(review): the "offset" and "scale" name strings below look swapped
  // relative to the handles they label; parameter names are cosmetic only, so
  // behavior is unaffected -- confirm intent before renaming.
  auto input_activations =
      builder.Parameter(0, input_literal->shape(), "input");
  auto scale_activations =
      builder.Parameter(1, scale_literal->shape(), "offset");
  auto offset_activations =
      builder.Parameter(2, offset_literal->shape(), "scale");

  auto expected = Literal::MakeTuple({expected_normalized.get(),
                                      Literal::CreateR1<float>(mean).get(),
                                      Literal::CreateR1<float>(var).get()});

  std::unique_ptr<GlobalData> input_data =
      client_->TransferToServer(*input_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> scale_data =
      client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> offset_data =
      client_->TransferToServer(*offset_literal).ConsumeValueOrDie();

  builder.BatchNormTraining(input_activations, scale_activations,
                            offset_activations, epsilon, feature_index);

  // Run all HLO passes during this test. In particular, ClientLibraryTestBase
  // disables constant folding, but we want it enabled for our zero-sized tensor
  // testcase.
  execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes();
  ComputeAndCompareTuple(
      &builder, *expected,
      {input_data.get(), scale_data.get(), offset_data.get()},
      ErrorSpec(0.01, 1));
}

// Compares BatchNormInference (externally supplied mean/variance) against the
// host-side reference normalization for randomized inputs.
XLA_TEST_P(BatchNormTestManySizes, RandomizedInferencingTests) {
  float epsilon = 0.001;
  ComputationBuilder builder(client_, TestName());
  const std::vector<int64>& bounds = GetParam().bounds;
  Array4D<float> input_array(bounds[0], bounds[1], bounds[2], bounds[3]);
  input_array.FillRandom(GetParam().random_value_var,
                         GetParam().random_value_mean);

  const int64 feature_index = GetParam().feature_index;
  const int64 num_elements_per_feature =
      Product(bounds) / bounds[feature_index];
  const int64 feature_bound = bounds[feature_index];
  std::vector<float> offset(feature_bound, 1);
  std::vector<float> scale(feature_bound, 2);

  auto input_squared =
      ReferenceUtil::MapArray4D(input_array, [](float a) { return a * a; });
  // All dimensions except the feature dimension are reduced away.
  std::vector<int64> reduce_dims;
  for (int64 i = 0; i < static_cast<int64>(bounds.size()); ++i) {
    if (i != feature_index) {
      reduce_dims.push_back(i);
    }
  }

  auto sum =
      ReferenceUtil::Reduce4DTo1D(input_array, /*init=*/0.0f, reduce_dims,
                                  [](float a, float b) { return a + b; });

  auto sum_squared =
      ReferenceUtil::Reduce4DTo1D(*input_squared, /*init=*/0.0f, reduce_dims,
                                  [](float a, float b) { return a + b; });

  std::vector<float> mean(feature_bound);

  for (int64 i = 0; i < feature_bound; ++i) {
    mean[i] = sum[i] / num_elements_per_feature;
  }

  std::vector<float> mean_square(feature_bound);
  for (int64 i = 0; i < feature_bound; ++i) {
    mean_square[i] = mean[i] * mean[i];
  }

  std::vector<float> square_mean(feature_bound);
  for (int64 i = 0; i < feature_bound; ++i) {
    square_mean[i] = sum_squared[i] / num_elements_per_feature;
  }

  // var = E[x^2] - E[x]^2, per feature.
  std::vector<float> var(feature_bound);
  for (int64 i = 0; i < feature_bound; ++i) {
    var[i] = square_mean[i] - mean_square[i];
  }

  Array4D<float> mean4D =
      *ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index);
  auto var4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index);
  auto scale4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index);
  auto offset4D =
      *ReferenceUtil::Broadcast1DTo4D(offset, bounds, feature_index);

  auto normalized = *ReferenceUtil::BatchNorm4D(input_array, mean4D, var4D,
                                                scale4D, offset4D, epsilon);

  auto offset_literal = Literal::CreateR1<float>(offset);
  auto scale_literal = Literal::CreateR1<float>(scale);
  auto mean_literal = Literal::CreateR1<float>(mean);
  auto var_literal = Literal::CreateR1<float>(var);
  auto input_literal = Literal::CreateR4FromArray4D<float>(input_array);

  // NOTE(review): "offset"/"scale" name strings appear swapped here as well;
  // names are cosmetic only -- confirm intent before renaming.
  auto input_activations =
      builder.Parameter(0, input_literal->shape(), "input");
  auto scale_activations =
      builder.Parameter(1, scale_literal->shape(), "offset");
  auto offset_activations =
      builder.Parameter(2, offset_literal->shape(), "scale");
  auto mean_activations = builder.Parameter(3, mean_literal->shape(), "mean");
  auto variance_activations =
      builder.Parameter(4, var_literal->shape(), "variance");

  Array4D<float> expected = normalized;

  std::unique_ptr<GlobalData> input_data =
      client_->TransferToServer(*input_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> scale_data =
      client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> offset_data =
      client_->TransferToServer(*offset_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> mean_data =
      client_->TransferToServer(*mean_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> variance_data =
      client_->TransferToServer(*var_literal).ConsumeValueOrDie();

  builder.BatchNormInference(input_activations, scale_activations,
                             offset_activations, mean_activations,
                             variance_activations, epsilon, feature_index);

  // Run all HLO passes during this test. In particular, ClientLibraryTestBase
  // disables constant folding, but we want it enabled for our zero-sized tensor
  // testcase.
  execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes();

  ComputeAndCompareR4<float>(
      &builder, expected,
      {input_data.get(), scale_data.get(), offset_data.get(), mean_data.get(),
       variance_data.get()},
      ErrorSpec(0.01, 1));
}

// Compares BatchNormGrad against a host-side reference computation of the
// activation, scale, and offset gradients for randomized inputs.
XLA_TEST_P(BatchNormTestManySizes, RandomizedGradTests) {
  float epsilon = 0.001;
  ComputationBuilder builder(client_, TestName());
  const std::vector<int64>& bounds = GetParam().bounds;
  Array4D<float> input_array(bounds[0], bounds[1], bounds[2], bounds[3]);
  input_array.FillRandom(GetParam().random_value_var,
                         GetParam().random_value_mean);

  Array4D<float> grad_output_array(bounds[0], bounds[1], bounds[2], bounds[3]);
  grad_output_array.FillRandom(GetParam().random_value_var,
                               GetParam().random_value_mean);

  const int64 feature_index = GetParam().feature_index;
  const int64 num_elements_per_feature =
      Product(bounds) / bounds[feature_index];
  const int64 feature_bound = bounds[feature_index];
  std::vector<float> scale(feature_bound, 2);

  auto input_squared =
      ReferenceUtil::MapArray4D(input_array, [](float a) { return a * a; });
  // All dimensions except the feature dimension are reduced away.
  std::vector<int64> reduce_dims;
  for (int64 i = 0; i < static_cast<int64>(bounds.size()); ++i) {
    if (i != feature_index) {
      reduce_dims.push_back(i);
    }
  }

  auto sum =
      ReferenceUtil::Reduce4DTo1D(input_array, /*init=*/0.0f, reduce_dims,
                                  [](float a, float b) { return a + b; });

  auto sum_squared =
      ReferenceUtil::Reduce4DTo1D(*input_squared, /*init=*/0.0f, reduce_dims,
                                  [](float a, float b) { return a + b; });

  std::vector<float> mean(feature_bound);

  // Guard against division by zero for the zero-sized tensor testcase.
  for (int64 i = 0; i < feature_bound; ++i) {
    if (num_elements_per_feature > 0) {
      mean[i] = sum[i] / num_elements_per_feature;
    } else {
      mean[i] = 0;
    }
  }

  std::vector<float> mean_square(feature_bound);
  for (int64 i = 0; i < feature_bound; ++i) {
    mean_square[i] = mean[i] * mean[i];
  }

  std::vector<float> square_mean(feature_bound);
  for (int64 i = 0; i < feature_bound; ++i) {
    if (num_elements_per_feature > 0) {
      square_mean[i] = sum_squared[i] / num_elements_per_feature;
    } else {
      square_mean[i] = 0;
    }
  }

  // var = E[x^2] - E[x]^2, per feature.
  std::vector<float> var(feature_bound);
  for (int64 i = 0; i < feature_bound; ++i) {
    var[i] = square_mean[i] - mean_square[i];
  }

  Array4D<float> mean4D =
      *ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index);
  auto var4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index);
  auto scale4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index);

  auto var_add_epsilon = *ReferenceUtil::MapArray4D(
      var4D, [epsilon](float a) { return a + epsilon; });

  // NOTE(review): the [epsilon] capture below is unused in the lambda body.
  auto rsqrt_var_add_epsilon = *ReferenceUtil::MapArray4D(
      var_add_epsilon, [epsilon](float a) { return 1 / std::sqrt(a); });

  auto grad_output_times_var =
      *ReferenceUtil::MapArray4D(grad_output_array, var_add_epsilon,
                                 [](float a, float b) { return a * b; });

  auto activation_shifted = *ReferenceUtil::MapArray4D(
      input_array, mean4D, [](float a, float b) { return a - b; });

  auto activation_shifted_times_grad_output =
      *ReferenceUtil::MapArray4D(grad_output_array, activation_shifted,
                                 [](float a, float b) { return a * b; });

  // grad_scale = sum(grad_output * (x - mean) * rsqrt(var + epsilon)).
  auto grad_scale_before_reduction = *ReferenceUtil::MapArray4D(
      activation_shifted_times_grad_output, rsqrt_var_add_epsilon,
      [](float a, float b) { return a * b; });

  auto grad_scale = ReferenceUtil::Reduce4DTo1D(
      grad_scale_before_reduction, /*init=*/0.0f, reduce_dims,
      [](float a, float b) { return a + b; });

  // grad_offset = sum(grad_output) over the non-feature dimensions.
  auto grad_offset =
      ReferenceUtil::Reduce4DTo1D(grad_output_array, /*init=*/0.0f, reduce_dims,
                                  [](float a, float b) { return a + b; });

  auto scale_times_rsqrt_var_add_epsilon = *ReferenceUtil::MapArray4D(
      scale4D, rsqrt_var_add_epsilon, [](float a, float b) { return a * b; });

  auto I1 = *ReferenceUtil::MapArray4D(
      grad_output_array, [&](float a) { return num_elements_per_feature * a; });

  auto I2 = *ReferenceUtil::Broadcast1DTo4D(grad_offset, bounds, feature_index);

  // I3 = sum(output_grad * (activation - mean(activation)))
  auto I3 = *ReferenceUtil::Broadcast1DTo4D(
      ReferenceUtil::Reduce4DTo1D(activation_shifted_times_grad_output,
                                  /*init=*/0.0f, reduce_dims,
                                  [](float a, float b) { return a + b; }),
      bounds, feature_index);

  // I4 = (activation - mean(activation)) *
  //   sum(output_grad * (activation - mean(activation)))
  auto I4 = *ReferenceUtil::MapArray4D(I3, activation_shifted,
                                       [](float a, float b) { return a * b; });

  // I5 = (activation - mean(activation)) *
  //   sum(output_grad * (activation - mean(activation))) / (variance +
  //   epsilon))
  auto I5 = *ReferenceUtil::MapArray4D(I4, var_add_epsilon,
                                       [](float a, float b) { return a / b; });

  auto grad_activation = *ReferenceUtil::MapArray4D(
      I1, I2, [](float a, float b) { return a - b; });

  grad_activation = *ReferenceUtil::MapArray4D(
      grad_activation, I5, [](float a, float b) { return a - b; });

  grad_activation = *ReferenceUtil::MapArray4D(
      grad_activation, scale4D, [](float a, float b) { return a * b; });

  // Final scaling; guard against division by zero for zero-sized tensors.
  grad_activation = *ReferenceUtil::MapArray4D(
      grad_activation, rsqrt_var_add_epsilon, [=](float a, float b) {
        if (num_elements_per_feature > 0) {
          return a * b / num_elements_per_feature;
        }
        return 0.f;
      });

  auto expected_grad_activation =
      Literal::CreateR4FromArray4D<float>(grad_activation);

  auto input_literal = Literal::CreateR4FromArray4D<float>(input_array);
  auto scale_literal = Literal::CreateR1<float>(scale);
  auto mean_literal = Literal::CreateR1<float>(mean);
  auto var_literal = Literal::CreateR1<float>(var);
  auto grad_output_literal =
      Literal::CreateR4FromArray4D<float>(grad_output_array);

  auto input_parameter = builder.Parameter(0, input_literal->shape(), "input");
  auto scale_parameter = builder.Parameter(1, scale_literal->shape(), "scale");
  auto mean_parameter = builder.Parameter(2, mean_literal->shape(), "mean");
  auto var_parameter = builder.Parameter(3, var_literal->shape(), "variance");
  auto grad_output_parameter =
      builder.Parameter(4, grad_output_literal->shape(), "grad_output");

  std::unique_ptr<GlobalData> input_data =
      client_->TransferToServer(*input_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> scale_data =
      client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> mean_data =
      client_->TransferToServer(*mean_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> var_data =
      client_->TransferToServer(*var_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> grad_output_data =
      client_->TransferToServer(*grad_output_literal).ConsumeValueOrDie();

  auto t = builder.BatchNormGrad(input_parameter, scale_parameter,
                                 mean_parameter, var_parameter,
                                 grad_output_parameter, epsilon, feature_index);

  auto expected =
      Literal::MakeTuple({expected_grad_activation.get(),
                          Literal::CreateR1<float>(grad_scale).get(),
                          Literal::CreateR1<float>(grad_offset).get()});

  // Run all HLO passes during this test. In particular, ClientLibraryTestBase
  // disables constant folding, but we want it enabled for our zero-sized tensor
  // testcase.
  execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes();

  ComputeAndCompareTuple(&builder, *expected,
                         {input_data.get(), scale_data.get(), mean_data.get(),
                          var_data.get(), grad_output_data.get()},
                         ErrorSpec(0.01, 1));
}

}  // namespace
}  // namespace xla