/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cmath>
#include <memory>
#include <vector>

#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace {

class BatchNormalizationTest
    : public ClientLibraryTestBase,
      public ::testing::WithParamInterface<bool /*use_cudnn_batchnorm*/> {
 protected:
  BatchNormalizationTest() : input_array_(kSamples, kZ, kY, kX) {
    mutable_debug_options()->set_xla_gpu_use_cudnn_batchnorm(GetParam());

    Array2D<float> pz({
        // z0 z1
        {-1.0f, 4.1f},  // p0
        {2.0f, 4.1f},   // p1
        {5.0f, 4.4f},   // p2
    });
    input_array_.FillWithPZ(pz);
    input_literal_ = std::move(*Literal::CreateR4FromArray4D(input_array_));
    CHECK_EQ(kSamples, input_array_.planes());
    CHECK_EQ(kZ, input_array_.depth());
    CHECK_EQ(kY, input_array_.height());
    CHECK_EQ(kX, input_array_.width());
  }

  static constexpr int64 kSamples = 3;
  static constexpr int64 kX = 1;
  static constexpr int64 kY = 1;
  static constexpr int64 kZ = 2;

  Array4D<float> input_array_;
  Literal input_literal_;
  const ErrorSpec error_spec_{0.001, 0.001};
};
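
// For reference, the forward batch normalization exercised throughout this
// file computes, per feature,
//   y = (x - mean) / sqrt(var + epsilon) * scale + offset,
// where mean and var are taken over all non-feature dimensions of x.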

// If testing the GPU backend, run the tests twice, with and without cudnn
// batchnorm.  Otherwise, just run the tests once -- the value of this flag
// doesn't matter.
#ifdef XLA_TEST_BACKEND_GPU
INSTANTIATE_TEST_CASE_P(BatchNormalizationTestInstance, BatchNormalizationTest,
                        ::testing::Bool());
#else
INSTANTIATE_TEST_CASE_P(BatchNormalizationTestInstance, BatchNormalizationTest,
                        ::testing::Values(false));
#endif

XLA_TEST_P(BatchNormalizationTest, SubtractInZ) {
  ComputationBuilder builder(client_, "subtract_in_z_one_sample");
  auto x = builder.ConstantLiteral(input_literal_);
  auto y = builder.ConstantR1<float>({3.14, 4.25});
  builder.Sub(x, y, /*broadcast_dimensions=*/{1});

  Array4D<float> expected(kSamples, kZ, kY, kX);
  Array2D<float> pz({
      {-1.0f - 3.14f, 4.1f - 4.25f},  // p0
      {2.0f - 3.14f, 4.1f - 4.25f},   // p1
      {5.0f - 3.14f, 4.4f - 4.25f},   // p2
  });
  expected.FillWithPZ(pz);
  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
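
// Note on /*broadcast_dimensions=*/{1} above: the length-2 vector is matched
// to dimension 1 (the Z dimension) of the (sample, Z, Y, X) operand, so 3.14
// is subtracted from every z0 element and 4.25 from every z1 element.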

XLA_TEST_P(BatchNormalizationTest, SquareTesseractElementwise) {
  ComputationBuilder builder(client_, "square_tesseract_elementwise");
  auto x = builder.ConstantLiteral(input_literal_);
  builder.SquareF32(x);

  using tensorflow::MathUtil;

  Array4D<float> expected(kSamples, kZ, kY, kX);
  Array2D<float> expected_pz({
      {MathUtil::IPow(-1.0f, 2), MathUtil::IPow(4.1f, 2)},
      {MathUtil::IPow(2.0f, 2), MathUtil::IPow(4.1f, 2)},
      {MathUtil::IPow(5.0f, 2), MathUtil::IPow(4.4f, 2)},
  });
  expected.FillWithPZ(expected_pz);
  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}

XLA_TEST_P(BatchNormalizationTest, SumToZ) {
  ComputationBuilder builder(client_, "sum_to_z");
  auto input_activations = builder.ConstantLiteral(input_literal_);
  Computation add = CreateScalarAddComputation(F32, &builder);
  // Reduce all but the Z dimension.
  builder.Reduce(input_activations, builder.ConstantR0<float>(0.0f), add,
                 {0, 2, 3});

  std::vector<float> expected = {6, 12.6};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
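
// Hand-computed check for SumToZ: reducing over {sample, Y, X} leaves
//   z0: -1.0 + 2.0 + 5.0 = 6.0 and z1: 4.1 + 4.1 + 4.4 = 12.6.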

XLA_TEST_P(BatchNormalizationTest, SquareAndReduce) {
  ComputationBuilder builder(client_, "square_and_reduce");
  auto input_activations = builder.ConstantLiteral(input_literal_);
  auto set_means = builder.ConstantR1<float>({2.f, 4.2f});
  auto activation_deviations = builder.Sub(input_activations, set_means,
                                           /*broadcast_dimensions=*/{1});
  Computation add = CreateScalarAddComputation(F32, &builder);
  auto dev_squares = builder.SquareF32(activation_deviations);
  auto sum_of_squares = builder.Reduce(
      dev_squares, builder.ConstantR0<float>(0.0f), add, {0, 2, 3});

  std::vector<float> expected = {18, 0.06};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
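
// Hand-computed check for SquareAndReduce: deviations from {2, 4.2} are
//   z0: {-3, 0, 3}      -> sum of squares 9 + 0 + 9 = 18,
//   z1: {-.1, -.1, .2}  -> sum of squares .01 + .01 + .04 = 0.06.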

XLA_TEST_P(BatchNormalizationTest, VarianceToStddev) {
  ComputationBuilder builder(client_, "variance_to_stddev");
  auto variance = builder.ConstantR1<float>({6.f, .02f});
  auto sqrt = builder.SqrtF32(variance);

  std::vector<float> expected = {2.44948974f, 0.14142136f};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}

// Compare against a forward batch normalization example in the NN spec
// reference.
XLA_TEST_P(BatchNormalizationTest, SpecComparisonForward) {
  ComputationBuilder builder(client_, "batch_normalize_per_spec");
  auto input_activations =
      builder.CheckShape(builder.ConstantLiteral(input_literal_),
                         ShapeUtil::MakeShape(F32, {3, 2, 1, 1}));
  auto gamma = builder.ConstantR1<float>({1.0, 1.0});
  auto beta = builder.ConstantR1<float>({0.0, 0.0});
  Computation add = CreateScalarAddComputation(F32, &builder);
  // Reduce all dimensions except dimension 1.
  Shape TwoElementVectorF32 = ShapeUtil::MakeShape(F32, {2});
  auto sum = builder.CheckShape(
      builder.Reduce(input_activations, builder.ConstantR0<float>(0.0f), add,
                     /*dimensions_to_reduce=*/{0, 2, 3}),
      TwoElementVectorF32);
  auto input_shape = builder.GetShape(input_activations).ConsumeValueOrDie();
  auto sum_shape = builder.GetShape(sum).ConsumeValueOrDie();
  auto count = builder.ConstantR0<float>(ShapeUtil::ElementsIn(*input_shape) /
                                         ShapeUtil::ElementsIn(*sum_shape));
  auto set_means = builder.Div(sum, count);

  const float kEpsilon = 1e-9f;
  auto epsilon = builder.ConstantR0<float>(kEpsilon);
  auto epsilon2 = builder.ConstantR1<float>({kEpsilon, kEpsilon});
  auto activation_deviations = builder.Sub(input_activations, set_means,
                                           /*broadcast_dimensions=*/{1});
  auto dev_squares = builder.SquareF32(activation_deviations);
  auto sum_of_squares = builder.CheckShape(
      builder.Reduce(dev_squares, builder.ConstantR0<float>(0.0f), add,
                     /*dimensions_to_reduce=*/{0, 2, 3}),
      TwoElementVectorF32);
  auto variance = builder.Div(sum_of_squares, count);
  auto standard_deviation = builder.SqrtF32(variance);
  auto standard_deviation_above_epsilon = builder.CheckShape(
      builder.Gt(standard_deviation, epsilon), ShapeUtil::MakeShape(PRED, {2}));
  auto gt_eps = builder.Select(standard_deviation_above_epsilon,
                               standard_deviation, epsilon2);
  auto normalization_factors = builder.ReciprocalF32(gt_eps);
  auto normalized_input_activations =
      builder.Mul(activation_deviations, normalization_factors,
                  /*broadcast_dimensions=*/{1});
  /* auto output_activations = */ builder.Add(
      builder.Mul(normalized_input_activations, gamma,
                  /*broadcast_dimensions=*/{1}),
      beta, /*broadcast_dimensions=*/{1});

  Array4D<float> expected(kSamples, kZ, kY, kX);
  Array2D<float> pz({
      {-3.f / std::sqrt(6.f), -.1f / std::sqrt(.02f)},
      {0.f, -.1f / std::sqrt(.02f)},
      {3.f / std::sqrt(6.f), .2f / std::sqrt(.02f)},
  });
  expected.FillWithPZ(pz);

  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
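
// Hand-computed check for SpecComparisonForward: the per-feature means are
// {2, 4.2} and the variances {18/3, .06/3} = {6, .02}, so with gamma = 1 and
// beta = 0 each expected entry is (x - mean) / sqrt(var).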

XLA_TEST_P(BatchNormalizationTest, BasicTraining) {
  const int kFeatureIndex = 3;
  ComputationBuilder builder(client_, TestName());

  auto operand = builder.ConstantR4FromArray4D<float>(
      {{{{1.f, 2.f}}, {{3.f, 4.f}}}, {{{5.f, 6.f}}, {{7.f, 8.f}}}});

  auto scale = builder.ConstantR1<float>({2.0f, 3.0f});

  auto offset = builder.ConstantR1<float>({1.0f, 2.0f});

  auto tuple = builder.BatchNormTraining(operand, scale, offset,
                                         /*epsilon=*/0.001, kFeatureIndex);

  auto expected = Literal::MakeTuple(
      {Literal::CreateR4<float>({{{{-1.6f, -2.0f}}, {{0.1f, 0.6f}}},
                                 {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}})
           .get(),
       Literal::CreateR1<float>({4, 5}).get(),
       Literal::CreateR1<float>({5, 5}).get()});

  ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
}
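
// Hand-computed check for BasicTraining: with feature index 3 the two
// features see {1, 3, 5, 7} and {2, 4, 6, 8}, so the returned means are
// {4, 5} and the variances E[x^2] - mean^2 = {21 - 16, 30 - 25} = {5, 5};
// e.g. the first output is (1 - 4) / sqrt(5.001) * 2 + 1 ~= -1.68.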

XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnSublane) {
  const int kFeatureIndex = 2;
  ComputationBuilder builder(client_, TestName());

  auto operand = builder.ConstantR4FromArray4D<float>(
      {{{{1.f}, {2.f}}, {{3.f}, {4.f}}}, {{{5.f}, {6.f}}, {{7.f}, {8.f}}}});

  auto scale = builder.ConstantR1<float>({2.0f, 3.0f});

  auto offset = builder.ConstantR1<float>({1.0f, 2.0f});

  auto tuple = builder.BatchNormTraining(operand, scale, offset,
                                         /*epsilon=*/0.001, kFeatureIndex);

  auto expected = Literal::MakeTuple(
      {Literal::CreateR4<float>({{{{-1.6f}, {-2.0f}}, {{0.1f}, {0.6f}}},
                                 {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}})
           .get(),
       Literal::CreateR1<float>({4, 5}).get(),
       Literal::CreateR1<float>({5, 5}).get()});

  ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
}

XLA_TEST_P(BatchNormalizationTest, TrainingWithFeatureOnLowDimension) {
  // Use dimension 0 as the feature dimension; this exercises the layout
  // analyzer.
  const int kFeatureIndex = 0;
  ComputationBuilder builder(client_, TestName());

  ComputationDataHandle h0;
  auto operand = CreateR3Parameter<float>(Array3D<float>(260, 2, 2, 1.0f),
                                          /*parameter_number=*/0, "operand",
                                          &builder, &h0);
  ComputationDataHandle h1;
  auto scale =
      CreateR1Parameter<float>(std::vector<float>(260, 1.0f),
                               /*parameter_number=*/1, "scale", &builder, &h1);
  ComputationDataHandle h2;
  auto offset =
      CreateR1Parameter<float>(std::vector<float>(260, 1.0f),
                               /*parameter_number=*/2, "offset", &builder, &h2);

  auto tuple = builder.BatchNormTraining(h0, h1, h2,
                                         /*epsilon=*/1, kFeatureIndex);

  auto expected = Literal::MakeTuple(
      {Literal::CreateR3FromArray3D<float>(Array3D<float>(260, 2, 2, 1.0f))
           .get(),
       Literal::CreateR1<float>(std::vector<float>(260, 1.0f)).get(),
       Literal::CreateR1<float>(std::vector<float>(260, 0.0f)).get()});

  ComputeAndCompareTuple(&builder, *expected,
                         {operand.get(), scale.get(), offset.get()},
                         ErrorSpec(0.1));
}
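
// Hand-computed check for TrainingWithFeatureOnLowDimension: the input is the
// constant 1.0, so mean = 1 and var = 0; with epsilon = 1 every normalized
// value is 0 / sqrt(1) = 0, and 0 * scale + offset = 1.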

XLA_TEST_P(BatchNormalizationTest, LargeEpsilonTest) {
  // Test correctness when a large-magnitude (here negative) epsilon is
  // chosen, such that var + epsilon is still positive.
  const int kFeatureIndex = 2;
  ComputationBuilder builder(client_, TestName());

  ComputationDataHandle h0;
  auto operand = CreateR3Parameter<float>({{{0.0f}, {10.0f}, {20.0f}, {30.0f}}},
                                          /*parameter_number=*/0, "operand",
                                          &builder, &h0);
  ComputationDataHandle h1;
  auto scale =
      CreateR1Parameter<float>(std::vector<float>(1, 1.0f),
                               /*parameter_number=*/1, "scale", &builder, &h1);
  ComputationDataHandle h2;
  auto offset =
      CreateR1Parameter<float>(std::vector<float>(1, 0.0f),
                               /*parameter_number=*/2, "offset", &builder, &h2);

  // var = 125, mean = 15, epsilon = -100
  auto tuple = builder.BatchNormTraining(h0, h1, h2,
                                         /*epsilon=*/-100, kFeatureIndex);

  auto expected = Literal::MakeTuple(
      {Literal::CreateR3FromArray3D<float>({{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}})
           .get(),
       Literal::CreateR1<float>(std::vector<float>(1, 15.0f)).get(),
       Literal::CreateR1<float>(std::vector<float>(1, 125.0f)).get()});

  ComputeAndCompareTuple(&builder, *expected,
                         {operand.get(), scale.get(), offset.get()},
                         ErrorSpec(0.1));
}
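
// Hand-computed check for LargeEpsilonTest: var + epsilon = 125 - 100 = 25,
// so stddev = 5 and the normalized values are (x - 15) / 5 = {-3, -1, 1, 3}.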

XLA_TEST_P(BatchNormalizationTest, BatchNormGradBasic) {
  const int kFeatureIndex = 2;
  ComputationBuilder builder(client_, TestName());

  auto operand =
      builder.ConstantR4FromArray4D<float>(Array4D<float>(2, 2, 2, 1, 0.0f));

  auto scale = builder.ConstantR1<float>({1.0f, 1.0f});

  auto mean = builder.ConstantR1<float>({0.0f, 0.0f});

  auto var = builder.ConstantR1<float>({1.0f, 1.0f});

  auto grad_output = builder.ConstantR4FromArray4D<float>(
      {{{{1.f}, {2.f}}, {{3.f}, {4.f}}}, {{{5.f}, {6.f}}, {{7.f}, {8.f}}}});

  builder.BatchNormGrad(operand, scale, mean, var, grad_output,
                        /*epsilon=*/0.0, kFeatureIndex);

  auto expected = Literal::MakeTuple(
      {Literal::CreateR4<float>({{{{-3.f}, {-3.f}}, {{-1.f}, {-1.f}}},
                                 {{{1.f}, {1.f}}, {{3.f}, {3.f}}}})
           .get(),
       Literal::CreateR1<float>({0, 0}).get(),
       Literal::CreateR1<float>({16, 20}).get()});

  ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
}
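
// Hand-computed check for BatchNormGradBasic: the operand equals the mean
// (all zeros) and var + epsilon = 1, so grad_scale = sum(grad * 0) = {0, 0},
// grad_offset = per-feature sum of grad_output = {1+3+5+7, 2+4+6+8} =
// {16, 20}, and grad_activation reduces to grad_output minus its per-feature
// mean {4, 5}.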

struct BatchNormTestParam {
  std::vector<int64> bounds;
  int64 feature_index;
  float random_value_mean;
  float random_value_var;
  bool use_cudnn_batchnorm;

  friend ::std::ostream& operator<<(::std::ostream& os,
                                    const BatchNormTestParam& p) {
    os << "bounds={" << tensorflow::str_util::Join(p.bounds, ", ") << "}, ";
    os << "feature_index=" << p.feature_index << ", ";
    os << "random_value_mean=" << p.random_value_mean << ", ";
    os << "random_value_var=" << p.random_value_var;

    // Don't print use_cudnn_batchnorm when it's false, because most backends
    // never set it to true.
    if (p.use_cudnn_batchnorm) {
      os << ", use_cudnn_batchnorm=true";
    }
    return os;
  }
};

// Tests for the fused BatchNorm operations over many shapes and feature
// indices.
class BatchNormTestManySizes
    : public ClientLibraryTestBase,
      public ::testing::WithParamInterface<BatchNormTestParam> {
 public:
  BatchNormTestManySizes() {
    mutable_debug_options()->set_xla_gpu_use_cudnn_batchnorm(
        GetParam().use_cudnn_batchnorm);
  }
};

std::vector<BatchNormTestParam> BuildBatchNormTestParams() {
  std::vector<BatchNormTestParam> params;

  auto add_testcase = [&](std::vector<int64> bounds, int64 feature_index,
                          float random_value_mean, float random_value_var) {
    BatchNormTestParam p{bounds, feature_index, random_value_mean,
                         random_value_var, /*use_cudnn_batchnorm=*/false};
    params.push_back(p);

    // If testing the GPU backend, also run with cudnn batchnorm enabled.
#ifdef XLA_TEST_BACKEND_GPU
    p.use_cudnn_batchnorm = true;
    params.push_back(p);
#endif
  };

  add_testcase({2, 2, 2, 2}, 0, 100.2f, 200.0f);
  add_testcase({2, 2, 2, 2}, 3, 300.f, 400.0f);

  add_testcase({1, 10, 1, 1}, 0, 10.1f, 20.1f);
  add_testcase({10, 10, 10, 10}, 1, 3.14f, 314.15f);
  add_testcase({10, 10, 10, 10}, 2, 666.6f, 777.7f);
  add_testcase({10, 10, 10, 10}, 1, -666.6f, 777.7f);
  add_testcase({10, 10, 10, 10}, 2, 0.f, 777.7f);
  add_testcase({1, 1, 10, 130}, 2, 0.f, 777.7f);
  add_testcase({1, 1, 130, 11}, 2, 0.f, 777.7f);
  add_testcase({1, 1, 10, 1}, 3, 888.8f, 9.9f);

  add_testcase({24, 129, 1, 2}, 2, 10000, 10000);
  add_testcase({24, 129, 1, 2}, 3, 10000, 10000);

  // Put the feature on a low dimension to trigger relayout; checks that the
  // internal logical-to-physical dimension calculation is correct after
  // relayout.
  add_testcase({1, 2, 3, 4}, 0, 100, 100);

  // Zero-sized tensor.
  add_testcase({1, 0, 100, 42}, 0, 100, 100);

  return params;
}

INSTANTIATE_TEST_CASE_P(BatchNormTest_Instantiation, BatchNormTestManySizes,
                        ::testing::ValuesIn(BuildBatchNormTestParams()));

XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) {
  float epsilon = 0.001;
  ComputationBuilder builder(client_, TestName());
  const std::vector<int64>& bounds = GetParam().bounds;
  Array4D<float> input_array(bounds[0], bounds[1], bounds[2], bounds[3]);
  input_array.FillRandom(GetParam().random_value_var,
                         GetParam().random_value_mean);

  const int64 feature_index = GetParam().feature_index;
  const int64 num_elements_per_feature =
      Product(bounds) / bounds[feature_index];
  const int64 feature_bound = bounds[feature_index];
  std::vector<float> offset(feature_bound, 1);
  std::vector<float> scale(feature_bound, 2);

  auto input_squared =
      ReferenceUtil::MapArray4D(input_array, [](float a) { return a * a; });
  std::vector<int64> reduce_dims;
  for (int64 i = 0; i < static_cast<int64>(bounds.size()); ++i) {
    if (i != feature_index) {
      reduce_dims.push_back(i);
    }
  }

  auto sum =
      ReferenceUtil::Reduce4DTo1D(input_array, /*init=*/0.0f, reduce_dims,
                                  [](float a, float b) { return a + b; });

  auto sum_squared =
      ReferenceUtil::Reduce4DTo1D(*input_squared, /*init=*/0.0f, reduce_dims,
                                  [](float a, float b) { return a + b; });

  std::vector<float> mean(feature_bound);

  for (int64 i = 0; i < feature_bound; ++i) {
    mean[i] = sum[i] / num_elements_per_feature;
  }

  std::vector<float> mean_square(feature_bound);
  for (int64 i = 0; i < feature_bound; ++i) {
    mean_square[i] = mean[i] * mean[i];
  }

  std::vector<float> square_mean(feature_bound);
  for (int64 i = 0; i < feature_bound; ++i) {
    square_mean[i] = sum_squared[i] / num_elements_per_feature;
  }

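  // The reference variance uses the identity Var[x] = E[x^2] - (E[x])^2.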
  std::vector<float> var(feature_bound);
  for (int64 i = 0; i < feature_bound; ++i) {
    var[i] = square_mean[i] - mean_square[i];
  }

  Array4D<float> mean4D =
      *ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index);
  auto var4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index);
  auto scale4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index);
  auto offset4D =
      *ReferenceUtil::Broadcast1DTo4D(offset, bounds, feature_index);

  auto normalized = *ReferenceUtil::BatchNorm4D(input_array, mean4D, var4D,
                                                scale4D, offset4D, epsilon);

  auto expected_normalized = Literal::CreateR4FromArray4D<float>(normalized);

  auto offset_literal = Literal::CreateR1<float>(offset);
  auto scale_literal = Literal::CreateR1<float>(scale);
  auto input_literal = Literal::CreateR4FromArray4D<float>(input_array);

  auto input_activations =
      builder.Parameter(0, input_literal->shape(), "input");
  auto scale_activations =
      builder.Parameter(1, scale_literal->shape(), "scale");
  auto offset_activations =
      builder.Parameter(2, offset_literal->shape(), "offset");

  auto expected = Literal::MakeTuple({expected_normalized.get(),
                                      Literal::CreateR1<float>(mean).get(),
                                      Literal::CreateR1<float>(var).get()});

  std::unique_ptr<GlobalData> input_data =
      client_->TransferToServer(*input_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> scale_data =
      client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> offset_data =
      client_->TransferToServer(*offset_literal).ConsumeValueOrDie();

  builder.BatchNormTraining(input_activations, scale_activations,
                            offset_activations, epsilon, feature_index);

  // Run all HLO passes during this test.  In particular, ClientLibraryTestBase
  // disables constant folding, but we want it enabled for our zero-sized tensor
  // testcase.
  execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes();
  ComputeAndCompareTuple(
      &builder, *expected,
      {input_data.get(), scale_data.get(), offset_data.get()},
      ErrorSpec(0.01, 1));
}

XLA_TEST_P(BatchNormTestManySizes, RandomizedInferencingTests) {
  float epsilon = 0.001;
  ComputationBuilder builder(client_, TestName());
  const std::vector<int64>& bounds = GetParam().bounds;
  Array4D<float> input_array(bounds[0], bounds[1], bounds[2], bounds[3]);
  input_array.FillRandom(GetParam().random_value_var,
                         GetParam().random_value_mean);

  const int64 feature_index = GetParam().feature_index;
  const int64 num_elements_per_feature =
      Product(bounds) / bounds[feature_index];
  const int64 feature_bound = bounds[feature_index];
  std::vector<float> offset(feature_bound, 1);
  std::vector<float> scale(feature_bound, 2);

  auto input_squared =
      ReferenceUtil::MapArray4D(input_array, [](float a) { return a * a; });
  std::vector<int64> reduce_dims;
  for (int64 i = 0; i < static_cast<int64>(bounds.size()); ++i) {
    if (i != feature_index) {
      reduce_dims.push_back(i);
    }
  }

  auto sum =
      ReferenceUtil::Reduce4DTo1D(input_array, /*init=*/0.0f, reduce_dims,
                                  [](float a, float b) { return a + b; });

  auto sum_squared =
      ReferenceUtil::Reduce4DTo1D(*input_squared, /*init=*/0.0f, reduce_dims,
                                  [](float a, float b) { return a + b; });

  std::vector<float> mean(feature_bound);

  for (int64 i = 0; i < feature_bound; ++i) {
    mean[i] = sum[i] / num_elements_per_feature;
  }

  std::vector<float> mean_square(feature_bound);
  for (int64 i = 0; i < feature_bound; ++i) {
    mean_square[i] = mean[i] * mean[i];
  }

  std::vector<float> square_mean(feature_bound);
  for (int64 i = 0; i < feature_bound; ++i) {
    square_mean[i] = sum_squared[i] / num_elements_per_feature;
  }

  std::vector<float> var(feature_bound);
  for (int64 i = 0; i < feature_bound; ++i) {
    var[i] = square_mean[i] - mean_square[i];
  }

  Array4D<float> mean4D =
      *ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index);
  auto var4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index);
  auto scale4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index);
  auto offset4D =
      *ReferenceUtil::Broadcast1DTo4D(offset, bounds, feature_index);

  auto normalized = *ReferenceUtil::BatchNorm4D(input_array, mean4D, var4D,
                                                scale4D, offset4D, epsilon);

  auto offset_literal = Literal::CreateR1<float>(offset);
  auto scale_literal = Literal::CreateR1<float>(scale);
  auto mean_literal = Literal::CreateR1<float>(mean);
  auto var_literal = Literal::CreateR1<float>(var);
  auto input_literal = Literal::CreateR4FromArray4D<float>(input_array);

  auto input_activations =
      builder.Parameter(0, input_literal->shape(), "input");
  auto scale_activations =
      builder.Parameter(1, scale_literal->shape(), "scale");
  auto offset_activations =
      builder.Parameter(2, offset_literal->shape(), "offset");
  auto mean_activations = builder.Parameter(3, mean_literal->shape(), "mean");
  auto variance_activations =
      builder.Parameter(4, var_literal->shape(), "variance");

  Array4D<float> expected = normalized;

  std::unique_ptr<GlobalData> input_data =
      client_->TransferToServer(*input_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> scale_data =
      client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> offset_data =
      client_->TransferToServer(*offset_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> mean_data =
      client_->TransferToServer(*mean_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> variance_data =
      client_->TransferToServer(*var_literal).ConsumeValueOrDie();

  builder.BatchNormInference(input_activations, scale_activations,
                             offset_activations, mean_activations,
                             variance_activations, epsilon, feature_index);

  // Run all HLO passes during this test.  In particular, ClientLibraryTestBase
  // disables constant folding, but we want it enabled for our zero-sized tensor
  // testcase.
  execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes();

  ComputeAndCompareR4<float>(
      &builder, expected,
      {input_data.get(), scale_data.get(), offset_data.get(), mean_data.get(),
       variance_data.get()},
      ErrorSpec(0.01, 1));
}

XLA_TEST_P(BatchNormTestManySizes, RandomizedGradTests) {
  float epsilon = 0.001;
  ComputationBuilder builder(client_, TestName());
  const std::vector<int64>& bounds = GetParam().bounds;
  Array4D<float> input_array(bounds[0], bounds[1], bounds[2], bounds[3]);
  input_array.FillRandom(GetParam().random_value_var,
                         GetParam().random_value_mean);

  Array4D<float> grad_output_array(bounds[0], bounds[1], bounds[2], bounds[3]);
  grad_output_array.FillRandom(GetParam().random_value_var,
                               GetParam().random_value_mean);

  const int64 feature_index = GetParam().feature_index;
  const int64 num_elements_per_feature =
      Product(bounds) / bounds[feature_index];
  const int64 feature_bound = bounds[feature_index];
  std::vector<float> scale(feature_bound, 2);

  auto input_squared =
      ReferenceUtil::MapArray4D(input_array, [](float a) { return a * a; });
  std::vector<int64> reduce_dims;
  for (int64 i = 0; i < static_cast<int64>(bounds.size()); ++i) {
    if (i != feature_index) {
      reduce_dims.push_back(i);
    }
  }

  auto sum =
      ReferenceUtil::Reduce4DTo1D(input_array, /*init=*/0.0f, reduce_dims,
                                  [](float a, float b) { return a + b; });

  auto sum_squared =
      ReferenceUtil::Reduce4DTo1D(*input_squared, /*init=*/0.0f, reduce_dims,
                                  [](float a, float b) { return a + b; });

  std::vector<float> mean(feature_bound);

  for (int64 i = 0; i < feature_bound; ++i) {
    if (num_elements_per_feature > 0) {
      mean[i] = sum[i] / num_elements_per_feature;
    } else {
      mean[i] = 0;
    }
  }

  std::vector<float> mean_square(feature_bound);
  for (int64 i = 0; i < feature_bound; ++i) {
    mean_square[i] = mean[i] * mean[i];
  }

  std::vector<float> square_mean(feature_bound);
  for (int64 i = 0; i < feature_bound; ++i) {
    if (num_elements_per_feature > 0) {
      square_mean[i] = sum_squared[i] / num_elements_per_feature;
    } else {
      square_mean[i] = 0;
    }
  }

  std::vector<float> var(feature_bound);
  for (int64 i = 0; i < feature_bound; ++i) {
    var[i] = square_mean[i] - mean_square[i];
  }

  Array4D<float> mean4D =
      *ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index);
  auto var4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index);
  auto scale4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index);

  auto var_add_epsilon = *ReferenceUtil::MapArray4D(
      var4D, [epsilon](float a) { return a + epsilon; });

  auto rsqrt_var_add_epsilon = *ReferenceUtil::MapArray4D(
      var_add_epsilon, [](float a) { return 1 / std::sqrt(a); });

  auto grad_output_times_var =
      *ReferenceUtil::MapArray4D(grad_output_array, var_add_epsilon,
                                 [](float a, float b) { return a * b; });

  auto activation_shifted = *ReferenceUtil::MapArray4D(
      input_array, mean4D, [](float a, float b) { return a - b; });

  auto activation_shifted_times_grad_output =
      *ReferenceUtil::MapArray4D(grad_output_array, activation_shifted,
                                 [](float a, float b) { return a * b; });

  auto grad_scale_before_reduction = *ReferenceUtil::MapArray4D(
      activation_shifted_times_grad_output, rsqrt_var_add_epsilon,
      [](float a, float b) { return a * b; });

  auto grad_scale = ReferenceUtil::Reduce4DTo1D(
      grad_scale_before_reduction, /*init=*/0.0f, reduce_dims,
      [](float a, float b) { return a + b; });

  auto grad_offset =
      ReferenceUtil::Reduce4DTo1D(grad_output_array, /*init=*/0.0f, reduce_dims,
                                  [](float a, float b) { return a + b; });

  auto scale_times_rsqrt_var_add_epsilon = *ReferenceUtil::MapArray4D(
      scale4D, rsqrt_var_add_epsilon, [](float a, float b) { return a * b; });

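  // The terms I1..I5 below assemble the standard batch-norm input gradient:
  //   grad_activation = scale * rsqrt(var + eps) / N *
  //       (N * grad_output - sum(grad_output)
  //        - (x - mean) * sum(grad_output * (x - mean)) / (var + eps)),
  // where the sums run over the non-feature dimensions and N is
  // num_elements_per_feature.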
  auto I1 = *ReferenceUtil::MapArray4D(
      grad_output_array, [&](float a) { return num_elements_per_feature * a; });

  auto I2 = *ReferenceUtil::Broadcast1DTo4D(grad_offset, bounds, feature_index);

  // I3 = sum(output_grad * (activation - mean(activation)))
  auto I3 = *ReferenceUtil::Broadcast1DTo4D(
      ReferenceUtil::Reduce4DTo1D(activation_shifted_times_grad_output,
                                  /*init=*/0.0f, reduce_dims,
                                  [](float a, float b) { return a + b; }),
      bounds, feature_index);

  // I4 = (activation - mean(activation)) *
  //   sum(output_grad * (activation - mean(activation)))
  auto I4 = *ReferenceUtil::MapArray4D(I3, activation_shifted,
                                       [](float a, float b) { return a * b; });

  // I5 = (activation - mean(activation)) *
  //   sum(output_grad * (activation - mean(activation))) / (variance + epsilon)
  auto I5 = *ReferenceUtil::MapArray4D(I4, var_add_epsilon,
                                       [](float a, float b) { return a / b; });

  auto grad_activation = *ReferenceUtil::MapArray4D(
      I1, I2, [](float a, float b) { return a - b; });

  grad_activation = *ReferenceUtil::MapArray4D(
      grad_activation, I5, [](float a, float b) { return a - b; });

  grad_activation = *ReferenceUtil::MapArray4D(
      grad_activation, scale4D, [](float a, float b) { return a * b; });

  grad_activation = *ReferenceUtil::MapArray4D(
      grad_activation, rsqrt_var_add_epsilon, [=](float a, float b) {
        if (num_elements_per_feature > 0) {
          return a * b / num_elements_per_feature;
        }
        return 0.f;
      });

  auto expected_grad_activation =
      Literal::CreateR4FromArray4D<float>(grad_activation);

  auto input_literal = Literal::CreateR4FromArray4D<float>(input_array);
  auto scale_literal = Literal::CreateR1<float>(scale);
  auto mean_literal = Literal::CreateR1<float>(mean);
  auto var_literal = Literal::CreateR1<float>(var);
  auto grad_output_literal =
      Literal::CreateR4FromArray4D<float>(grad_output_array);

  auto input_parameter = builder.Parameter(0, input_literal->shape(), "input");
  auto scale_parameter = builder.Parameter(1, scale_literal->shape(), "scale");
  auto mean_parameter = builder.Parameter(2, mean_literal->shape(), "mean");
  auto var_parameter = builder.Parameter(3, var_literal->shape(), "variance");
  auto grad_output_parameter =
      builder.Parameter(4, grad_output_literal->shape(), "grad_output");

  std::unique_ptr<GlobalData> input_data =
      client_->TransferToServer(*input_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> scale_data =
      client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> mean_data =
      client_->TransferToServer(*mean_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> var_data =
      client_->TransferToServer(*var_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> grad_output_data =
      client_->TransferToServer(*grad_output_literal).ConsumeValueOrDie();

  auto t = builder.BatchNormGrad(input_parameter, scale_parameter,
                                 mean_parameter, var_parameter,
                                 grad_output_parameter, epsilon, feature_index);

  auto expected =
      Literal::MakeTuple({expected_grad_activation.get(),
                          Literal::CreateR1<float>(grad_scale).get(),
                          Literal::CreateR1<float>(grad_offset).get()});

  // Run all HLO passes during this test.  In particular, ClientLibraryTestBase
  // disables constant folding, but we want it enabled for our zero-sized tensor
  // testcase.
  execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes();

  ComputeAndCompareTuple(&builder, *expected,
                         {input_data.get(), scale_data.get(), mean_data.get(),
                          var_data.get(), grad_output_data.get()},
                         ErrorSpec(0.01, 1));
}

}  // namespace
}  // namespace xla