Home | History | Annotate | Download | only in tests
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include <cmath>
     17 #include <limits>
     18 #include <memory>
     19 #include <numeric>
     20 #include <vector>
     21 
     22 #include "tensorflow/compiler/xla/array2d.h"
     23 #include "tensorflow/compiler/xla/array3d.h"
     24 #include "tensorflow/compiler/xla/array4d.h"
     25 #include "tensorflow/compiler/xla/client/computation_builder.h"
     26 #include "tensorflow/compiler/xla/client/global_data.h"
     27 #include "tensorflow/compiler/xla/client/local_client.h"
     28 #include "tensorflow/compiler/xla/layout_util.h"
     29 #include "tensorflow/compiler/xla/literal_util.h"
     30 #include "tensorflow/compiler/xla/statusor.h"
     31 #include "tensorflow/compiler/xla/test.h"
     32 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
     33 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
     34 #include "tensorflow/compiler/xla/tests/test_macros.h"
     35 #include "tensorflow/compiler/xla/types.h"
     36 #include "tensorflow/compiler/xla/xla_data.pb.h"
     37 #include "tensorflow/core/lib/core/casts.h"
     38 #include "tensorflow/core/platform/types.h"
     39 
     40 namespace xla {
     41 namespace {
     42 
     43 class ArrayElementwiseOpTest : public ClientLibraryTestBase {
     44  public:
     45   ErrorSpec error_spec_{0.0001, 0.0001};
     46 };
     47 
     48 class ArrayElementwiseOpTestParamCount
     49     : public ArrayElementwiseOpTest,
     50       public ::testing::WithParamInterface<int> {};
     51 
     52 XLA_TEST_F(ArrayElementwiseOpTest, NegConstantZeroElementF32) {
     53   ComputationBuilder builder(client_, TestName());
     54   auto a = builder.ConstantR1<float>({});
     55   auto result = builder.Neg(a);
     56 
     57   ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
     58 }
     59 
     60 XLA_TEST_F(ArrayElementwiseOpTest, NegConstantF32) {
     61   ComputationBuilder builder(client_, TestName());
     62   auto a = builder.ConstantR1<float>({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f});
     63   auto result = builder.Neg(a);
     64 
     65   ComputeAndCompareR1<float>(&builder, {2.5f, -3.14f, -2.25f, 10.0f, -6.0f}, {},
     66                              error_spec_);
     67 }
     68 
     69 XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS32) {
     70   ComputationBuilder builder(client_, TestName());
     71   auto a = builder.ConstantR1<int32>({-1, 0, 1, 324,
     72                                       std::numeric_limits<int32>::min(),
     73                                       std::numeric_limits<int32>::max()});
     74   auto result = builder.Neg(a);
     75 
     76   // -min == min for int32 due to an overflow. In C++ it is undefined behavior
     77   // to do this calculation. For XLA we have not specified that, so it
     78   // ought to work.
     79   ComputeAndCompareR1<int32>(&builder,
     80                              {1, 0, -1, -324, std::numeric_limits<int32>::min(),
     81                               -std::numeric_limits<int32>::max()},
     82                              {});
     83 }
     84 
     85 XLA_TEST_F(ArrayElementwiseOpTest, NegConstantZeroElementC64) {
     86   ComputationBuilder builder(client_, TestName());
     87   auto a = builder.ConstantR1<complex64>({});
     88   auto result = builder.Neg(a);
     89 
     90   ComputeAndCompareR1<complex64>(&builder, {}, {}, error_spec_);
     91 }
     92 
     93 XLA_TEST_F(ArrayElementwiseOpTest, NegConstantC64) {
     94   ComputationBuilder builder(client_, TestName());
     95   auto a = builder.ConstantR1<complex64>(
     96       {{-2.5f, 1.0f}, {0.0f, 3.14f}, {2.25f, -1.0f}, {-10.0f, 0.0f}});
     97   auto result = builder.Neg(a);
     98 
     99   ComputeAndCompareR1<complex64>(
    100       &builder, {{2.5f, -1.0f}, {0.0f, -3.14f}, {-2.25f, 1.0f}, {10.0f, 0.0f}},
    101       {}, error_spec_);
    102 }
    103 
    104 XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteZeroElementF32s) {
    105   ComputationBuilder builder(client_, TestName());
    106   auto a = builder.ConstantR1<float>({});
    107   auto result = builder.IsFinite(a);
    108 
    109   ComputeAndCompareR1<bool>(&builder, {}, {});
    110 }
    111 
    112 // A non-canonical quiet NaN value.
    113 static const float kNonCanonicalNaN = tensorflow::bit_cast<float>(0x7FD01234);
    114 
    115 XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteScalarF32) {
    116   ComputationBuilder builder(client_, TestName());
    117   auto result = builder.IsFinite(builder.ConstantR0<float>(NAN));
    118   ComputeAndCompareR0<bool>(&builder, false, {});
    119 
    120   EXPECT_TRUE(std::isnan(kNonCanonicalNaN));
    121   auto result_non_canonical =
    122       builder.IsFinite(builder.ConstantR0<float>(kNonCanonicalNaN));
    123   ComputeAndCompareR0<bool>(&builder, false, {});
    124 
    125   const float inf = std::numeric_limits<float>::infinity();
    126   auto result_inf = builder.IsFinite(builder.ConstantR0<float>(inf));
    127   ComputeAndCompareR0<bool>(&builder, false, {});
    128 
    129   auto result_neg_inf = builder.IsFinite(builder.ConstantR0<float>(-inf));
    130   ComputeAndCompareR0<bool>(&builder, false, {});
    131 
    132   auto result_zero = builder.IsFinite(builder.ConstantR0<float>(0.0f));
    133   ComputeAndCompareR0<bool>(&builder, true, {});
    134 }
    135 
    136 XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteR1F32s) {
    137   ComputationBuilder builder(client_, TestName());
    138   const float inf = std::numeric_limits<float>::infinity();
    139   EXPECT_TRUE(std::isnan(kNonCanonicalNaN));
    140   auto a = builder.ConstantR1<float>(
    141       {{NAN, 7.0f, kNonCanonicalNaN, -1.0f, inf, -inf}});
    142   auto result = builder.IsFinite(a);
    143 
    144   ComputeAndCompareR1<bool>(&builder, {false, true, false, true, false, false},
    145                             {});
    146 }
    147 
    148 XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantF32s) {
    149   ComputationBuilder builder(client_, TestName());
    150   auto a = builder.ConstantR1<float>({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f});
    151   auto b = builder.ConstantR1<float>({100.0f, 3.13f, 2.75f, 10.5f, -999.0f});
    152   auto add = builder.Add(a, b);
    153 
    154   ComputeAndCompareR1<float>(&builder, {97.5f, 6.27f, 5.0f, 0.5f, -993.0f}, {},
    155                              error_spec_);
    156 }
    157 
    158 XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantZeroElementF32s) {
    159   ComputationBuilder builder(client_, TestName());
    160   auto a = builder.ConstantR1<float>({});
    161   auto b = builder.ConstantR1<float>({});
    162   auto add = builder.Add(a, b);
    163 
    164   ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
    165 }
    166 
    167 XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantC64s) {
    168   ComputationBuilder builder(client_, TestName());
    169   auto a = builder.ConstantR1<complex64>(
    170       {{-2.5f, 0.0f}, {0.0f, 3.14f}, {2.25f, 0.0f}, {1.0f, -10.0f}});
    171   auto b = builder.ConstantR1<complex64>(
    172       {{100.0f, 0.0f}, {3.13f, 0.0f}, {2.75f, 1.0f}, {-2.0f, 10.5f}});
    173   auto add = builder.Add(a, b);
    174 
    175   ComputeAndCompareR1<complex64>(
    176       &builder, {97.5f, {3.13f, 3.14f}, {5.0f, 1.0f}, {-1.0f, 0.5f}}, {},
    177       error_spec_);
    178 }
    179 
    180 XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantZeroElementC64s) {
    181   ComputationBuilder builder(client_, TestName());
    182   auto a = builder.ConstantR1<complex64>({});
    183   auto b = builder.ConstantR1<complex64>({});
    184   auto add = builder.Add(a, b);
    185 
    186   ComputeAndCompareR1<complex64>(&builder, {}, {}, error_spec_);
    187 }
    188 
    189 TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) {
    190   const int count = GetParam();
    191   ComputationBuilder builder(client_, TestName());
    192   std::vector<float> a_values;
    193   std::vector<float> b_values;
    194   for (int i = 0; i < count; ++i) {
    195     a_values.push_back(i / static_cast<float>(count));
    196     b_values.push_back(2 * i / static_cast<float>(count + 2));
    197   }
    198 
    199   std::unique_ptr<Literal> a_literal = Literal::CreateR1<float>({a_values});
    200   std::unique_ptr<GlobalData> a_data =
    201       client_->TransferToServer(*a_literal).ConsumeValueOrDie();
    202   auto a_constant = builder.ConstantR1<float>(a_values);
    203   auto a_param = builder.Parameter(0, a_literal->shape(), "a_param");
    204 
    205   std::unique_ptr<Literal> b_literal = Literal::CreateR1<float>({b_values});
    206   std::unique_ptr<GlobalData> b_data =
    207       client_->TransferToServer(*b_literal).ConsumeValueOrDie();
    208   auto b_constant = builder.Parameter(1, a_literal->shape(), "b_param");
    209   auto b_param = builder.ConstantR1<float>(b_values);
    210 
    211   auto sum1 = builder.Add(a_constant, b_constant);
    212   auto sum2 = builder.Add(a_constant, b_param);
    213   auto sum3 = builder.Add(a_param, b_constant);
    214   auto sum4 = builder.Add(a_param, b_param);
    215 
    216   auto sum = builder.Add(sum1, sum2);
    217   sum = builder.Add(sum, sum3);
    218   sum = builder.Add(sum, sum4);
    219 
    220   std::vector<float> expected;
    221   for (int64 i = 0; i < count; ++i) {
    222     expected.push_back(4 * (a_values[i] + b_values[i]));
    223   }
    224 
    225   ComputeAndCompareR1<float>(&builder, expected, {a_data.get(), b_data.get()},
    226                              error_spec_);
    227 }
    228 
    229 XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantF32s) {
    230   ComputationBuilder builder(client_, TestName());
    231   auto a = builder.ConstantR1<float>({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f});
    232   auto b = builder.ConstantR1<float>({100.0f, 3.13f, 2.75f, 10.5f, -999.0f});
    233   auto add = builder.Sub(a, b);
    234 
    235   ComputeAndCompareR1<float>(&builder, {-102.5f, 0.01f, -0.5f, -20.5f, 1005.0f},
    236                              {}, error_spec_);
    237 }
    238 
    239 XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementF32s) {
    240   ComputationBuilder builder(client_, TestName());
    241   auto a = builder.ConstantR1<float>({});
    242   auto b = builder.ConstantR1<float>({});
    243   auto add = builder.Sub(a, b);
    244 
    245   ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
    246 }
    247 
    248 XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS32s) {
    249   ComputationBuilder builder(client_, TestName());
    250   auto a = builder.ConstantR1<int32>({-1, 0, 2, 1000000000});
    251   auto b = builder.ConstantR1<int32>({-1, 2, 1, -1});
    252   auto add = builder.Sub(a, b);
    253 
    254   ComputeAndCompareR1<int32>(&builder, {0, -2, 1, 1000000001}, {});
    255 }
    256 
    257 XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementS32s) {
    258   ComputationBuilder builder(client_, TestName());
    259   auto a = builder.ConstantR1<int32>({});
    260   auto b = builder.ConstantR1<int32>({});
    261   auto add = builder.Sub(a, b);
    262 
    263   ComputeAndCompareR1<int32>(&builder, {}, {});
    264 }
    265 
    266 XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantC64s) {
    267   ComputationBuilder builder(client_, TestName());
    268   auto a = builder.ConstantR1<complex64>(
    269       {{-2.5f, 0.0f}, {0.0f, 3.14f}, {3.0f, 2.25f}});
    270   auto b = builder.ConstantR1<complex64>(
    271       {{0.0f, 10.0f}, {3.13f, 0.0f}, {2.75f, -0.25f}});
    272   auto add = builder.Sub(a, b);
    273 
    274   ComputeAndCompareR1<complex64>(
    275       &builder, {{-2.5f, -10.0f}, {-3.13f, 3.14f}, {0.25f, 2.5f}}, {},
    276       error_spec_);
    277 }
    278 
    279 XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementC64s) {
    280   ComputationBuilder builder(client_, TestName());
    281   auto a = builder.ConstantR1<complex64>({});
    282   auto b = builder.ConstantR1<complex64>({});
    283   auto add = builder.Sub(a, b);
    284 
    285   ComputeAndCompareR1<complex64>(&builder, {}, {}, error_spec_);
    286 }
    287 
    288 XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantF32s) {
    289   ComputationBuilder builder(client_, TestName());
    290   auto a = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
    291   auto b = builder.ConstantR1<float>({10.0f, 5.1f, 1.0f, 10.0f, -6.0f});
    292   auto add = builder.Div(a, b);
    293 
    294   ComputeAndCompareR1<float>(&builder, {-0.25f, 5.0f, 2.25f, -1.0f, -1.0f}, {},
    295                              error_spec_);
    296 }
    297 
    298 XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementF32s) {
    299   ComputationBuilder builder(client_, TestName());
    300   auto a = builder.ConstantR1<float>({});
    301   auto b = builder.ConstantR1<float>({});
    302   auto add = builder.Div(a, b);
    303 
    304   ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
    305 }
    306 
    307 XLA_TEST_F(ArrayElementwiseOpTest, DivS32s) {
    308   // clang-format off
    309   // Some interesting values to test.
    310   std::vector<int32> vals = {
    311     INT32_MIN, INT32_MIN + 1, INT32_MIN + 2, -0x40000000, -0x3fffffff,
    312     -271181, -1309, -17, -10, -5, -3, -2, -1, 0, 1, 2, 3, 5, 10, 17, 26, 101,
    313     7919, 0x40000000, INT32_MAX - 2, INT32_MAX - 1, INT32_MAX};
    314   // clang-format on
    315 
    316   std::vector<int32> dividends, divisors, quotients, remainders;
    317   for (int32 divisor : vals) {
    318     if (divisor != 0) {
    319       for (int32 dividend : vals) {
    320         // Avoid integer overflow.
    321         if (dividend != INT32_MIN || divisor != -1) {
    322           dividends.push_back(dividend);
    323           divisors.push_back(divisor);
    324           quotients.push_back(dividend / divisor);
    325           remainders.push_back(dividend % divisor);
    326         }
    327       }
    328     }
    329   }
    330 
    331   {
    332     ComputationBuilder builder(client_, TestName());
    333     ComputationDataHandle dividend;
    334     ComputationDataHandle divisor;
    335     auto dividend_data =
    336         CreateR1Parameter<int32>(dividends, 0, "dividend", &builder, &dividend);
    337     auto divisor_data =
    338         CreateR1Parameter<int32>(divisors, 1, "divisor", &builder, &divisor);
    339     builder.Div(dividend, divisor);
    340 
    341     ComputeAndCompareR1<int32>(&builder, quotients,
    342                                {dividend_data.get(), divisor_data.get()});
    343   }
    344 
    345   // Test with a compile-time constant divisor.
    346   {
    347     ComputationBuilder builder(client_, TestName());
    348     ComputationDataHandle dividend;
    349     auto dividend_data =
    350         CreateR1Parameter<int32>(dividends, 0, "dividend", &builder, &dividend);
    351     builder.Div(dividend, builder.ConstantR1<int32>(divisors));
    352 
    353     ComputeAndCompareR1<int32>(&builder, quotients, {dividend_data.get()});
    354   }
    355 
    356   {
    357     ComputationBuilder builder(client_, TestName());
    358     ComputationDataHandle dividend;
    359     ComputationDataHandle divisor;
    360     auto dividend_data =
    361         CreateR1Parameter<int32>(dividends, 0, "dividend", &builder, &dividend);
    362     auto divisor_data =
    363         CreateR1Parameter<int32>(divisors, 1, "divisor", &builder, &divisor);
    364     builder.Rem(dividend, divisor);
    365 
    366     ComputeAndCompareR1<int32>(&builder, remainders,
    367                                {dividend_data.get(), divisor_data.get()});
    368   }
    369 
    370   // Test with a compile-time constant divisor.
    371   {
    372     ComputationBuilder builder(client_, TestName());
    373     ComputationDataHandle dividend;
    374     auto dividend_data =
    375         CreateR1Parameter<int32>(dividends, 0, "dividend", &builder, &dividend);
    376     builder.Rem(dividend, builder.ConstantR1<int32>(divisors));
    377 
    378     ComputeAndCompareR1<int32>(&builder, remainders, {dividend_data.get()});
    379   }
    380 }
    381 
    382 XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) {
    383   // clang-format off
    384   // Some interesting values to test.
    385   std::vector<uint32> vals = {
    386     0, 1, 2, 17, 101, 3333, 0x7FFFFFFF, 0xABCDEF12, 0xCAFEBEEF, 0x80000000,
    387     0x80000001, UINT32_MAX - 2, UINT32_MAX - 1, UINT32_MAX};
    388   // clang-format on
    389 
    390   std::vector<uint32> dividends, divisors, quotients, remainders;
    391   for (uint32 divisor : vals) {
    392     if (divisor != 0) {
    393       for (uint32 dividend : vals) {
    394         dividends.push_back(dividend);
    395         divisors.push_back(divisor);
    396         quotients.push_back(dividend / divisor);
    397         remainders.push_back(dividend % divisor);
    398       }
    399     }
    400   }
    401 
    402   {
    403     ComputationBuilder builder(client_, TestName());
    404     ComputationDataHandle dividend;
    405     ComputationDataHandle divisor;
    406     auto dividend_data = CreateR1Parameter<uint32>(dividends, 0, "dividend",
    407                                                    &builder, &dividend);
    408     auto divisor_data =
    409         CreateR1Parameter<uint32>(divisors, 1, "divisor", &builder, &divisor);
    410     builder.Div(dividend, divisor);
    411 
    412     ComputeAndCompareR1<uint32>(&builder, quotients,
    413                                 {dividend_data.get(), divisor_data.get()});
    414   }
    415 
    416   {
    417     ComputationBuilder builder(client_, TestName());
    418     ComputationDataHandle dividend;
    419     auto dividend_data = CreateR1Parameter<uint32>(dividends, 0, "dividend",
    420                                                    &builder, &dividend);
    421     builder.Div(dividend, builder.ConstantR1<uint32>(divisors));
    422 
    423     ComputeAndCompareR1<uint32>(&builder, quotients, {dividend_data.get()});
    424   }
    425 
    426   {
    427     ComputationBuilder builder(client_, TestName());
    428     ComputationDataHandle dividend;
    429     ComputationDataHandle divisor;
    430     auto dividend_data = CreateR1Parameter<uint32>(dividends, 0, "dividend",
    431                                                    &builder, &dividend);
    432     auto divisor_data =
    433         CreateR1Parameter<uint32>(divisors, 1, "divisor", &builder, &divisor);
    434     builder.Rem(dividend, divisor);
    435 
    436     ComputeAndCompareR1<uint32>(&builder, remainders,
    437                                 {dividend_data.get(), divisor_data.get()});
    438   }
    439 
    440   {
    441     ComputationBuilder builder(client_, TestName());
    442     ComputationDataHandle dividend;
    443     auto dividend_data = CreateR1Parameter<uint32>(dividends, 0, "dividend",
    444                                                    &builder, &dividend);
    445     builder.Rem(dividend, builder.ConstantR1<uint32>(divisors));
    446 
    447     ComputeAndCompareR1<uint32>(&builder, remainders, {dividend_data.get()});
    448   }
    449 }
    450 
    451 XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantC64s) {
    452   ComputationBuilder builder(client_, TestName());
    453   auto a = builder.ConstantR1<complex64>(
    454       {{-2.5f, 1.0f}, {-25.5f, 0.0f}, {2.0f, -1.0f}});
    455   auto b = builder.ConstantR1<complex64>(
    456       {{10.0f, 0.0f}, {0.0f, 1.0f}, {2.0f, -1.0f}});
    457   auto div = builder.Div(a, b);
    458 
    459   ComputeAndCompareR1<complex64>(
    460       &builder, {{-0.25f, 0.1f}, {0.0f, 25.5f}, {1.0f, 0.0f}}, {}, error_spec_);
    461 }
    462 
    463 XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementC64s) {
    464   ComputationBuilder builder(client_, TestName());
    465   auto a = builder.ConstantR1<complex64>({});
    466   auto b = builder.ConstantR1<complex64>({});
    467   auto div = builder.Div(a, b);
    468 
    469   ComputeAndCompareR1<complex64>(&builder, {}, {}, error_spec_);
    470 }
    471 
    472 XLA_TEST_F(ArrayElementwiseOpTest, RemF32s) {
    473   ComputationBuilder builder(client_, TestName());
    474   auto a = builder.ConstantR1<float>(
    475       {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f, 3.0f, 3.0f, -1.0f, -8.0f});
    476   auto b = builder.ConstantR1<float>(
    477       {10.0f, 5.1f, 1.0f, 10.0f, -6.0f, 2.0f, -2.0f, 7.0f, -4.0f});
    478   auto add = builder.Rem(a, b);
    479 
    480   ComputeAndCompareR1<float>(
    481       &builder, {-2.5f, 0.0f, 0.25f, 0.0f, -0.0f, 1.0f, 1.0f, -1.0f, -0.0f}, {},
    482       error_spec_);
    483 }
    484 
    485 XLA_TEST_F(ArrayElementwiseOpTest, RemZeroElementF32s) {
    486   ComputationBuilder builder(client_, TestName());
    487   auto a = builder.ConstantR1<float>({});
    488   auto b = builder.ConstantR1<float>({});
    489   auto add = builder.Rem(a, b);
    490 
    491   ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
    492 }
    493 
    494 XLA_TEST_F(ArrayElementwiseOpTest, RemF64s) {
    495   ComputationBuilder builder(client_, TestName());
    496   auto a = builder.ConstantR1<double>(
    497       {-2.5, 25.5, 2.25, -10.0, 6.0, 3.0, 3.0, -1.0, -8.0});
    498   auto b = builder.ConstantR1<double>(
    499       {10.0, 5.1, 1.0, 10.0, -6.0, 2.0, -2.0, 7.0, -4.0});
    500   auto add = builder.Rem(a, b);
    501 
    502   ComputeAndCompareR1<double>(
    503       &builder, {-2.5, 0.0, 0.25, 0.0, -0.0, 1.0, 1.0, -1.0, -0.0}, {},
    504       error_spec_);
    505 }
    506 
    507 XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantF32s) {
    508   ComputationBuilder builder(client_, TestName());
    509   auto a = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
    510   auto b = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, -6.0f});
    511   auto add = builder.Mul(a, b);
    512 
    513   ComputeAndCompareR1<float>(&builder, {-25.0f, 127.5f, 2.25f, -100.0f, -36.0f},
    514                              {}, error_spec_);
    515 }
    516 
    517 XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementF32s) {
    518   ComputationBuilder builder(client_, TestName());
    519   auto a = builder.ConstantR1<float>({});
    520   auto b = builder.ConstantR1<float>({});
    521   auto add = builder.Mul(a, b);
    522 
    523   ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
    524 }
    525 
    526 XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantS32s) {
    527   std::vector<int32> data = {0,
    528                              1,
    529                              -1,
    530                              1234,
    531                              0x1a243514,
    532                              std::numeric_limits<int32>::max(),
    533                              std::numeric_limits<int32>::min()};
    534   // Form the test data set using all products of 'data' with itself.
    535   std::vector<int32> a_data, b_data, expected;
    536   for (int32 a : data) {
    537     for (int32 b : data) {
    538       a_data.push_back(a);
    539       b_data.push_back(b);
    540       expected.push_back(static_cast<uint32>(a) * static_cast<uint32>(b));
    541     }
    542   }
    543 
    544   ComputationBuilder builder(client_, TestName());
    545   auto a = builder.ConstantR1<int32>(a_data);
    546   auto b = builder.ConstantR1<int32>(b_data);
    547   auto add = builder.Mul(a, b);
    548 
    549   ComputeAndCompareR1<int32>(&builder, expected, {});
    550 }
    551 
    552 XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementS32s) {
    553   ComputationBuilder builder(client_, TestName());
    554   auto a = builder.ConstantR1<int32>({});
    555   auto b = builder.ConstantR1<int32>({});
    556   auto add = builder.Mul(a, b);
    557 
    558   ComputeAndCompareR1<int32>(&builder, {}, {});
    559 }
    560 
    561 XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantU32s) {
    562   std::vector<uint32> data = {0,          1,          0xDEADBEEF, 1234,
    563                               0x1a243514, 0xFFFFFFFF, 0x80808080};
    564 
    565   // Form the test data set using all products of 'data' with itself.
    566   std::vector<uint32> a_data, b_data, expected;
    567   for (uint32 a : data) {
    568     for (uint32 b : data) {
    569       a_data.push_back(a);
    570       b_data.push_back(b);
    571       expected.push_back(a * b);
    572     }
    573   }
    574 
    575   ComputationBuilder builder(client_, TestName());
    576   auto a = builder.ConstantR1<uint32>(a_data);
    577   auto b = builder.ConstantR1<uint32>(b_data);
    578   auto add = builder.Mul(a, b);
    579 
    580   ComputeAndCompareR1<uint32>(&builder, expected, {});
    581 }
    582 
    583 XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantC64s) {
    584   ComputationBuilder builder(client_, TestName());
    585   auto a = builder.ConstantR1<complex64>(
    586       {{-2.5f, 0.0f}, {0.0f, 25.5f}, {2.0f, -10.0f}});
    587   auto b = builder.ConstantR1<complex64>(
    588       {{0.0f, 10.0f}, {5.0f, 1.0f}, {10.0f, -6.0f}});
    589   auto add = builder.Mul(a, b);
    590 
    591   ComputeAndCompareR1<complex64>(
    592       &builder, {{0.0f, -25.0f}, {-25.5f, 127.5f}, {-40.0f, -112.0}}, {},
    593       error_spec_);
    594 }
    595 
    596 XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementC64s) {
    597   ComputationBuilder builder(client_, TestName());
    598   auto a = builder.ConstantR1<complex64>({});
    599   auto b = builder.ConstantR1<complex64>({});
    600   auto add = builder.Mul(a, b);
    601 
    602   ComputeAndCompareR1<complex64>(&builder, {}, {}, error_spec_);
    603 }
    604 
    605 XLA_TEST_F(ArrayElementwiseOpTest, AndPredR1) {
    606   ComputationBuilder builder(client_, TestName());
    607   auto a = builder.ConstantR1<bool>({false, false, true, true});
    608   auto b = builder.ConstantR1<bool>({false, true, false, true});
    609   auto out = builder.And(a, b);
    610 
    611   ComputeAndCompareR1<bool>(&builder, {false, false, false, true}, {});
    612 }
    613 
    614 XLA_TEST_F(ArrayElementwiseOpTest, AndPredR2) {
    615   ComputationBuilder builder(client_, TestName());
    616   auto a = builder.ConstantR2<bool>({{false, false}, {true, true}});
    617   auto b = builder.ConstantR2<bool>({{false, true}, {false, true}});
    618   auto out = builder.And(a, b);
    619 
    620   Array2D<bool> expected_array({{false, false}, {false, true}});
    621   ComputeAndCompareR2<bool>(&builder, expected_array, {});
    622 }
    623 
    624 XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementPredR1) {
    625   ComputationBuilder builder(client_, TestName());
    626   auto a = builder.ConstantR1<bool>({});
    627   auto b = builder.ConstantR1<bool>({});
    628   auto out = builder.And(a, b);
    629 
    630   ComputeAndCompareR1<bool>(&builder, {}, {});
    631 }
    632 
    633 XLA_TEST_F(ArrayElementwiseOpTest, AndS32R1) {
    634   ComputationBuilder builder(client_, TestName());
    635   auto a = builder.ConstantR1<int32>({0, -1, -8});
    636   auto b = builder.ConstantR1<int32>({5, -7, 12});
    637   auto out = builder.And(a, b);
    638 
    639   ComputeAndCompareR1<int32>(&builder, {0, -7, 8}, {});
    640 }
    641 
    642 XLA_TEST_F(ArrayElementwiseOpTest, AndS32R2) {
    643   ComputationBuilder builder(client_, TestName());
    644   auto a = builder.ConstantR2<int32>({{0, -5}, {-1, 5}});
    645   auto b = builder.ConstantR2<int32>({{1, -6}, {4, 5}});
    646   auto out = builder.And(a, b);
    647 
    648   Array2D<int32> expected_array({{0, -6}, {4, 5}});
    649   ComputeAndCompareR2<int32>(&builder, expected_array, {});
    650 }
    651 
    652 XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementS32R1) {
    653   ComputationBuilder builder(client_, TestName());
    654   auto a = builder.ConstantR1<int32>({});
    655   auto b = builder.ConstantR1<int32>({});
    656   auto out = builder.And(a, b);
    657 
    658   ComputeAndCompareR1<int32>(&builder, {}, {});
    659 }
    660 
    661 XLA_TEST_F(ArrayElementwiseOpTest, AndU32R1) {
    662   ComputationBuilder builder(client_, TestName());
    663   auto a = builder.ConstantR1<int32>({0, 1, 8});
    664   auto b = builder.ConstantR1<int32>({5, 7, 12});
    665   auto out = builder.And(a, b);
    666 
    667   ComputeAndCompareR1<int32>(&builder, {0, 1, 8}, {});
    668 }
    669 
    670 XLA_TEST_F(ArrayElementwiseOpTest, AndU32R2) {
    671   ComputationBuilder builder(client_, TestName());
    672   auto a = builder.ConstantR2<uint32>({{0, 1}, {3, 8}});
    673   auto b = builder.ConstantR2<uint32>({{1, 0}, {7, 6}});
    674   auto out = builder.And(a, b);
    675 
    676   Array2D<uint32> expected_array({{0, 0}, {3, 0}});
    677   ComputeAndCompareR2<uint32>(&builder, expected_array, {});
    678 }
    679 
    680 XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementU32R1) {
    681   ComputationBuilder builder(client_, TestName());
    682   auto a = builder.ConstantR1<uint32>({});
    683   auto b = builder.ConstantR1<uint32>({});
    684   auto out = builder.And(a, b);
    685 
    686   ComputeAndCompareR1<uint32>(&builder, {}, {});
    687 }
    688 
    689 XLA_TEST_F(ArrayElementwiseOpTest, OrPredR1) {
    690   ComputationBuilder builder(client_, TestName());
    691   auto a = builder.ConstantR1<bool>({false, false, true, true});
    692   auto b = builder.ConstantR1<bool>({false, true, false, true});
    693   auto out = builder.Or(a, b);
    694 
    695   ComputeAndCompareR1<bool>(&builder, {false, true, true, true}, {});
    696 }
    697 
    698 XLA_TEST_F(ArrayElementwiseOpTest, OrPredR2) {
    699   ComputationBuilder builder(client_, TestName());
    700   auto a = builder.ConstantR2<bool>({{false, false}, {true, true}});
    701   auto b = builder.ConstantR2<bool>({{false, true}, {false, true}});
    702   auto out = builder.Or(a, b);
    703 
    704   Array2D<bool> expected_array({{false, true}, {true, true}});
    705   ComputeAndCompareR2<bool>(&builder, expected_array, {});
    706 }
    707 
    708 XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementPredR1) {
    709   ComputationBuilder builder(client_, TestName());
    710   auto a = builder.ConstantR1<bool>({});
    711   auto b = builder.ConstantR1<bool>({});
    712   auto out = builder.Or(a, b);
    713 
    714   ComputeAndCompareR1<bool>(&builder, {}, {});
    715 }
    716 
    717 XLA_TEST_F(ArrayElementwiseOpTest, OrS32R1) {
    718   ComputationBuilder builder(client_, TestName());
    719   auto a = builder.ConstantR1<int32>({0, -1, 8});
    720   auto b = builder.ConstantR1<int32>({5, -7, 4});
    721   auto out = builder.Or(a, b);
    722 
    723   ComputeAndCompareR1<int32>(&builder, {5, -1, 12}, {});
    724 }
    725 
    726 XLA_TEST_F(ArrayElementwiseOpTest, OrS32R2) {
    727   ComputationBuilder builder(client_, TestName());
    728   auto a = builder.ConstantR2<int32>({{0, -1}, {8, 8}});
    729   auto b = builder.ConstantR2<int32>({{5, -7}, {4, 1}});
    730   auto out = builder.Or(a, b);
    731 
    732   Array2D<int32> expected_array({{5, -1}, {12, 9}});
    733   ComputeAndCompareR2<int32>(&builder, expected_array, {});
    734 }
    735 
    736 XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementS32R1) {
    737   ComputationBuilder builder(client_, TestName());
    738   auto a = builder.ConstantR1<int32>({});
    739   auto b = builder.ConstantR1<int32>({});
    740   auto out = builder.Or(a, b);
    741 
    742   ComputeAndCompareR1<int32>(&builder, {}, {});
    743 }
    744 
    745 XLA_TEST_F(ArrayElementwiseOpTest, OrU32R1) {
    746   ComputationBuilder builder(client_, TestName());
    747   auto a = builder.ConstantR1<uint32>({0, 1, 8});
    748   auto b = builder.ConstantR1<uint32>({5, 7, 4});
    749   auto out = builder.Or(a, b);
    750 
    751   ComputeAndCompareR1<uint32>(&builder, {5, 7, 12}, {});
    752 }
    753 
    754 XLA_TEST_F(ArrayElementwiseOpTest, OrU32R2) {
    755   ComputationBuilder builder(client_, TestName());
    756   auto a = builder.ConstantR2<uint32>({{0, 1}, {8, 8}});
    757   auto b = builder.ConstantR2<uint32>({{5, 7}, {4, 1}});
    758   auto out = builder.Or(a, b);
    759 
    760   Array2D<uint32> expected_array({{5, 7}, {12, 9}});
    761   ComputeAndCompareR2<uint32>(&builder, expected_array, {});
    762 }
    763 
    764 XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementU32R1) {
    765   ComputationBuilder builder(client_, TestName());
    766   auto a = builder.ConstantR1<uint32>({});
    767   auto b = builder.ConstantR1<uint32>({});
    768   auto out = builder.Or(a, b);
    769 
    770   ComputeAndCompareR1<uint32>(&builder, {}, {});
    771 }
    772 
    773 XLA_TEST_F(ArrayElementwiseOpTest, NotPredR1) {
    774   ComputationBuilder builder(client_, TestName());
    775   auto a = builder.ConstantR1<bool>({false, true, true, false});
    776   auto out = builder.Not(a);
    777 
    778   ComputeAndCompareR1<bool>(&builder, {true, false, false, true}, {});
    779 }
    780 
    781 XLA_TEST_F(ArrayElementwiseOpTest, NotPredR2) {
    782   ComputationBuilder builder(client_, TestName());
    783   auto a = builder.ConstantR2<bool>({{false, true}, {true, false}});
    784   auto out = builder.Not(a);
    785 
    786   Array2D<bool> expected_array({{true, false}, {false, true}});
    787   ComputeAndCompareR2<bool>(&builder, expected_array, {});
    788 }
    789 
    790 XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementPredR1) {
    791   ComputationBuilder builder(client_, TestName());
    792   auto a = builder.ConstantR1<bool>({});
    793   auto out = builder.Not(a);
    794 
    795   ComputeAndCompareR1<bool>(&builder, {}, {});
    796 }
    797 
    798 XLA_TEST_F(ArrayElementwiseOpTest, NotS32R1) {
    799   ComputationBuilder builder(client_, TestName());
    800   auto a = builder.ConstantR1<int32>({-1, 0, 1});
    801   auto out = builder.Not(a);
    802 
    803   ComputeAndCompareR1<int32>(&builder, {0, -1, -2}, {});
    804 }
    805 
    806 XLA_TEST_F(ArrayElementwiseOpTest, NotS32R2) {
    807   ComputationBuilder builder(client_, TestName());
    808   auto a = builder.ConstantR2<int32>({{-1, 0}, {1, 8}});
    809   auto out = builder.Not(a);
    810 
    811   Array2D<int32> expected_array({{0, -1}, {-2, -9}});
    812   ComputeAndCompareR2<int32>(&builder, expected_array, {});
    813 }
    814 
    815 XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementS32R1) {
    816   ComputationBuilder builder(client_, TestName());
    817   auto a = builder.ConstantR1<int32>({});
    818   auto out = builder.Not(a);
    819 
    820   ComputeAndCompareR1<int32>(&builder, {}, {});
    821 }
    822 
    823 XLA_TEST_F(ArrayElementwiseOpTest, NotU32R1) {
    824   ComputationBuilder builder(client_, TestName());
    825   auto a = builder.ConstantR1<uint32>({0, 4294967295});
    826   auto out = builder.Not(a);
    827 
    828   ComputeAndCompareR1<uint32>(&builder, {4294967295, 0}, {});
    829 }
    830 
    831 XLA_TEST_F(ArrayElementwiseOpTest, NotU32R2) {
    832   ComputationBuilder builder(client_, TestName());
    833   auto a = builder.ConstantR2<uint32>({{0, 4294967295}, {1, 4294967294}});
    834   auto out = builder.Not(a);
    835 
    836   Array2D<uint32> expected_array({{4294967295, 0}, {4294967294, 1}});
    837   ComputeAndCompareR2<uint32>(&builder, expected_array, {});
    838 }
    839 
    840 XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementU32R1) {
    841   ComputationBuilder builder(client_, TestName());
    842   auto a = builder.ConstantR1<uint32>({});
    843   auto out = builder.Not(a);
    844 
    845   ComputeAndCompareR1<uint32>(&builder, {}, {});
    846 }
    847 
    848 XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftS32) {
    849   ComputationBuilder builder(client_, TestName());
    850   auto a =
    851       builder.ConstantR1<int32>({static_cast<int32>(0x12345678),
    852                                  static_cast<int32>(0xF0001000), 1, 3, 77});
    853   auto b = builder.ConstantR1<int32>({4, 8, 2, 7, 15});
    854   auto out = builder.ShiftLeft(a, b);
    855 
    856   ComputeAndCompareR1<int32>(
    857       &builder,
    858       {static_cast<int32>(0x23456780), 0x00100000, 0x4, 0x180, 2523136}, {});
    859 }
    860 
    861 XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticS32) {
    862   ComputationBuilder builder(client_, TestName());
    863   auto a =
    864       builder.ConstantR1<int32>({static_cast<int32>(0x92345678),
    865                                  static_cast<int32>(0x10001000), 1, 3, 77});
    866   auto b = builder.ConstantR1<int32>({4, 8, 2, 7, 2});
    867   auto out = builder.ShiftRightArithmetic(a, b);
    868 
    869   ComputeAndCompareR1<int32>(&builder,
    870                              {static_cast<int32>(0xF9234567),
    871                               static_cast<int32>(0x00100010), 0, 0, 19},
    872                              {});
    873 }
    874 
    875 XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalS32) {
    876   ComputationBuilder builder(client_, TestName());
    877   auto a =
    878       builder.ConstantR1<int32>({static_cast<int32>(0x92345678),
    879                                  static_cast<int32>(0x10001000), 1, 3, 77});
    880   auto b = builder.ConstantR1<int32>({4, 8, 2, 7, 5});
    881   auto out = builder.ShiftRightLogical(a, b);
    882 
    883   ComputeAndCompareR1<int32>(&builder, {0x09234567, 0x00100010, 0, 0, 2}, {});
    884 }
    885 
    886 XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftU32) {
    887   ComputationBuilder builder(client_, TestName());
    888   auto a = builder.ConstantR1<uint32>({0x12345678, 0xF0001000, 1, 3, 77});
    889   auto b = builder.ConstantR1<uint32>({4, 8, 2, 7, 15});
    890   auto out = builder.ShiftLeft(a, b);
    891 
    892   ComputeAndCompareR1<uint32>(
    893       &builder, {0x23456780, 0x00100000, 0x4, 0x180, 2523136}, {});
    894 }
    895 
    896 XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticU32) {
    897   ComputationBuilder builder(client_, TestName());
    898   auto a = builder.ConstantR1<uint32>({0x92345678, 0x10001000, 1, 3, 77});
    899   auto b = builder.ConstantR1<uint32>({4, 8, 2, 7, 2});
    900   auto out = builder.ShiftRightArithmetic(a, b);
    901 
    902   ComputeAndCompareR1<uint32>(&builder, {0xF9234567, 0x00100010, 0, 0, 19}, {});
    903 }
    904 
    905 XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalU32) {
    906   ComputationBuilder builder(client_, TestName());
    907   auto a = builder.ConstantR1<uint32>({0x92345678, 0x10001000, 1, 3, 77});
    908   auto b = builder.ConstantR1<uint32>({4, 8, 2, 7, 5});
    909   auto out = builder.ShiftRightLogical(a, b);
    910 
    911   ComputeAndCompareR1<uint32>(&builder, {0x09234567, 0x00100010, 0, 0, 2}, {});
    912 }
    913 
    914 XLA_TEST_F(ArrayElementwiseOpTest, CompareEqF32s) {
    915   SetFastMathDisabled(true);
    916   ComputationBuilder builder(client_, TestName());
    917   auto lhs = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, NAN, 6.0f});
    918   auto rhs = builder.ConstantR1<float>({10.0f, 5.0f, 2.25f, 10.0f, NAN});
    919   auto compare = builder.Eq(lhs, rhs);
    920 
    921   ComputeAndCompareR1<bool>(&builder, {false, false, true, false, false}, {});
    922 }
    923 
    924 XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementF32s) {
    925   ComputationBuilder builder(client_, TestName());
    926   auto lhs = builder.ConstantR1<float>({});
    927   auto rhs = builder.ConstantR1<float>({});
    928   auto compare = builder.Eq(lhs, rhs);
    929 
    930   ComputeAndCompareR1<bool>(&builder, {}, {});
    931 }
    932 
    933 XLA_TEST_F(ArrayElementwiseOpTest, CompareGeF32s) {
    934   SetFastMathDisabled(true);
    935   ComputationBuilder builder(client_, TestName());
    936   auto lhs = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, NAN, 6.0f});
    937   auto rhs = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, NAN});
    938   auto compare = builder.Ge(lhs, rhs);
    939 
    940   ComputeAndCompareR1<bool>(&builder, {false, true, true, false, false}, {});
    941 }
    942 
    943 XLA_TEST_F(ArrayElementwiseOpTest, CompareGtF32s) {
    944   SetFastMathDisabled(true);
    945   ComputationBuilder builder(client_, TestName());
    946   auto lhs = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, NAN, 6.0f});
    947   auto rhs = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, NAN});
    948   auto compare = builder.Gt(lhs, rhs);
    949 
    950   ComputeAndCompareR1<bool>(&builder, {false, true, true, false, false}, {});
    951 }
    952 
    953 XLA_TEST_F(ArrayElementwiseOpTest, CompareLeF32s) {
    954   SetFastMathDisabled(true);
    955   ComputationBuilder builder(client_, TestName());
    956   auto lhs = builder.ConstantR1<float>({-2.5f, 5.0f, 2.25f, NAN, 6.0f});
    957   auto rhs = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, NAN});
    958   auto compare = builder.Le(lhs, rhs);
    959 
    960   ComputeAndCompareR1<bool>(&builder, {true, true, false, false, false}, {});
    961 }
    962 
    963 XLA_TEST_F(ArrayElementwiseOpTest, CompareLtF32s) {
    964   SetFastMathDisabled(true);
    965   ComputationBuilder builder(client_, TestName());
    966   auto lhs = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, NAN, 6.0f});
    967   auto rhs = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, NAN});
    968   auto compare = builder.Lt(lhs, rhs);
    969 
    970   ComputeAndCompareR1<bool>(&builder, {true, false, false, false, false}, {});
    971 }
    972 
    973 XLA_TEST_F(ArrayElementwiseOpTest, CompareEqS32s) {
    974   const int32 min = std::numeric_limits<int32>::min();
    975   const int32 max = std::numeric_limits<int32>::max();
    976   ComputationBuilder builder(client_, TestName());
    977   auto lhs = builder.ConstantR1<int32>({min, min, min, 0, 0, 0, max, max, max});
    978   auto rhs = builder.ConstantR1<int32>({min, 0, max, -1, 0, 1, min, 0, max});
    979   auto compare = builder.Eq(lhs, rhs);
    980 
    981   ComputeAndCompareR1<bool>(
    982       &builder, {true, false, false, false, true, false, false, false, true},
    983       {});
    984 }
    985 
    986 XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementS32s) {
    987   ComputationBuilder builder(client_, TestName());
    988   auto lhs = builder.ConstantR1<int32>({});
    989   auto rhs = builder.ConstantR1<int32>({});
    990   auto compare = builder.Eq(lhs, rhs);
    991 
    992   ComputeAndCompareR1<bool>(&builder, {}, {});
    993 }
    994 
    995 XLA_TEST_F(ArrayElementwiseOpTest, CompareEqC64s) {
    996   SetFastMathDisabled(true);
    997   ComputationBuilder builder(client_, TestName());
    998   auto lhs = builder.ConstantR1<complex64>({{-2.5f, 10.0f},
    999                                             {1.0f, 25.5f},
   1000                                             {2.25f, -3.0f},
   1001                                             {NAN, 0.0f},
   1002                                             {1.0f, 6.0f}});
   1003   auto rhs = builder.ConstantR1<complex64>({{0.0f, 10.0f},
   1004                                             {1.0f, 5.0f},
   1005                                             {2.25f, -3.0f},
   1006                                             {10.0f, 0.0f},
   1007                                             {1.0f, NAN}});
   1008   auto compare = builder.Eq(lhs, rhs);
   1009 
   1010   ComputeAndCompareR1<bool>(&builder, {false, false, true, false, false}, {});
   1011 }
   1012 
   1013 XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementC64s) {
   1014   ComputationBuilder builder(client_, TestName());
   1015   auto lhs = builder.ConstantR1<complex64>({});
   1016   auto rhs = builder.ConstantR1<complex64>({});
   1017   auto compare = builder.Eq(lhs, rhs);
   1018 
   1019   ComputeAndCompareR1<bool>(&builder, {}, {});
   1020 }
   1021 
   1022 XLA_TEST_F(ArrayElementwiseOpTest, CompareNeC64s) {
   1023   // Disable fast-math because we're operating on NaNs.
   1024   SetFastMathDisabled(true);
   1025 
   1026   ComputationBuilder builder(client_, TestName());
   1027   auto lhs = builder.ConstantR1<complex64>({{-2.5f, 10.0f},
   1028                                             {1.0f, 25.5f},
   1029                                             {2.25f, -3.0f},
   1030                                             {NAN, 0.0f},
   1031                                             {1.0f, 6.0f}});
   1032   auto rhs = builder.ConstantR1<complex64>({{0.0f, 10.0f},
   1033                                             {1.0f, 5.0f},
   1034                                             {2.25f, -3.0f},
   1035                                             {10.0f, 0.0f},
   1036                                             {1.0f, NAN}});
   1037   auto compare = builder.Ne(lhs, rhs);
   1038 
   1039   ComputeAndCompareR1<bool>(&builder, {true, true, false, true, true}, {});
   1040 }
   1041 
   1042 XLA_TEST_F(ArrayElementwiseOpTest, CompareNeF32s) {
   1043   // Disable fast-math because we're operating on NaNs.
   1044   SetFastMathDisabled(true);
   1045 
   1046   ComputationBuilder builder(client_, TestName());
   1047   auto lhs = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, NAN, 6.0f});
   1048   auto rhs = builder.ConstantR1<float>({10.0f, 25.5f, 1.0f, 10.0f, NAN});
   1049   auto compare = builder.Ne(lhs, rhs);
   1050 
   1051   ComputeAndCompareR1<bool>(&builder, {true, false, true, true, true}, {});
   1052 }
   1053 
   1054 XLA_TEST_F(ArrayElementwiseOpTest, CompareNeS32s) {
   1055   const int32 min = std::numeric_limits<int32>::min();
   1056   const int32 max = std::numeric_limits<int32>::max();
   1057   ComputationBuilder builder(client_, TestName());
   1058   auto lhs = builder.ConstantR1<int32>({min, min, min, 0, 0, 0, max, max, max});
   1059   auto rhs = builder.ConstantR1<int32>({min, 0, max, -1, 0, 1, min, 0, max});
   1060   auto compare = builder.Ne(lhs, rhs);
   1061 
   1062   ComputeAndCompareR1<bool>(
   1063       &builder, {false, true, true, true, false, true, true, true, false}, {});
   1064 }
   1065 
   1066 XLA_TEST_F(ArrayElementwiseOpTest, CompareGeS32s) {
   1067   const int32 min = std::numeric_limits<int32>::min();
   1068   const int32 max = std::numeric_limits<int32>::max();
   1069   ComputationBuilder builder(client_, TestName());
   1070   auto lhs = builder.ConstantR1<int32>({min, min, min, 0, 0, 0, max, max, max});
   1071   auto rhs = builder.ConstantR1<int32>({min, 0, max, -1, 0, 1, min, 0, max});
   1072   auto compare = builder.Ge(lhs, rhs);
   1073 
   1074   ComputeAndCompareR1<bool>(
   1075       &builder, {true, false, false, true, true, false, true, true, true}, {});
   1076 }
   1077 
   1078 XLA_TEST_F(ArrayElementwiseOpTest, CompareGtS32s) {
   1079   const int32 min = std::numeric_limits<int32>::min();
   1080   const int32 max = std::numeric_limits<int32>::max();
   1081   ComputationBuilder builder(client_, TestName());
   1082   auto lhs = builder.ConstantR1<int32>({min, min, min, 0, 0, 0, max, max, max});
   1083   auto rhs = builder.ConstantR1<int32>({min, 0, max, -1, 0, 1, min, 0, max});
   1084   auto compare = builder.Gt(lhs, rhs);
   1085 
   1086   ComputeAndCompareR1<bool>(
   1087       &builder, {false, false, false, true, false, false, true, true, false},
   1088       {});
   1089 }
   1090 
   1091 XLA_TEST_F(ArrayElementwiseOpTest, CompareLeS32s) {
   1092   const int32 min = std::numeric_limits<int32>::min();
   1093   const int32 max = std::numeric_limits<int32>::max();
   1094   ComputationBuilder builder(client_, TestName());
   1095   auto lhs = builder.ConstantR1<int32>({min, min, min, 0, 0, 0, max, max, max});
   1096   auto rhs = builder.ConstantR1<int32>({min, 0, max, -1, 0, 1, min, 0, max});
   1097   auto compare = builder.Le(lhs, rhs);
   1098 
   1099   ComputeAndCompareR1<bool>(
   1100       &builder, {true, true, true, false, true, true, false, false, true}, {});
   1101 }
   1102 
   1103 XLA_TEST_F(ArrayElementwiseOpTest, CompareLtS32s) {
   1104   const int32 min = std::numeric_limits<int32>::min();
   1105   const int32 max = std::numeric_limits<int32>::max();
   1106   ComputationBuilder builder(client_, TestName());
   1107   auto lhs = builder.ConstantR1<int32>({min, min, min, 0, 0, 0, max, max, max});
   1108   auto rhs = builder.ConstantR1<int32>({min, 0, max, -1, 0, 1, min, 0, max});
   1109   auto compare = builder.Lt(lhs, rhs);
   1110 
   1111   ComputeAndCompareR1<bool>(
   1112       &builder, {false, true, true, false, false, true, false, false, false},
   1113       {});
   1114 }
   1115 
   1116 XLA_TEST_F(ArrayElementwiseOpTest, CompareEqU32s) {
   1117   const uint32 max = std::numeric_limits<uint32>::max();
   1118   ComputationBuilder builder(client_, TestName());
   1119   auto lhs = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, max, max, max});
   1120   auto rhs = builder.ConstantR1<uint32>({0, 1, max, 4, 5, 6, 0, 1, max});
   1121   auto compare = builder.Eq(lhs, rhs);
   1122 
   1123   ComputeAndCompareR1<bool>(
   1124       &builder, {true, false, false, false, true, false, false, false, true},
   1125       {});
   1126 }
   1127 
   1128 XLA_TEST_F(ArrayElementwiseOpTest, CompareNeU32s) {
   1129   const uint32 max = std::numeric_limits<uint32>::max();
   1130   ComputationBuilder builder(client_, TestName());
   1131   auto lhs = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, max, max, max});
   1132   auto rhs = builder.ConstantR1<uint32>({0, 1, max, 4, 5, 6, 0, 1, max});
   1133   auto compare = builder.Ne(lhs, rhs);
   1134 
   1135   ComputeAndCompareR1<bool>(
   1136       &builder, {false, true, true, true, false, true, true, true, false}, {});
   1137 }
   1138 
   1139 XLA_TEST_F(ArrayElementwiseOpTest, CompareGeU32s) {
   1140   const uint32 max = std::numeric_limits<uint32>::max();
   1141   ComputationBuilder builder(client_, TestName());
   1142   auto lhs = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, max, max, max});
   1143   auto rhs = builder.ConstantR1<uint32>({0, 1, max, 4, 5, 6, 0, 1, max});
   1144   auto compare = builder.Ge(lhs, rhs);
   1145 
   1146   ComputeAndCompareR1<bool>(
   1147       &builder, {true, false, false, true, true, false, true, true, true}, {});
   1148 }
   1149 
   1150 XLA_TEST_F(ArrayElementwiseOpTest, CompareGtU32s) {
   1151   const uint32 max = std::numeric_limits<uint32>::max();
   1152   ComputationBuilder builder(client_, TestName());
   1153   auto lhs = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, max, max, max});
   1154   auto rhs = builder.ConstantR1<uint32>({0, 1, max, 4, 5, 6, 0, 1, max});
   1155   auto compare = builder.Gt(lhs, rhs);
   1156 
   1157   ComputeAndCompareR1<bool>(
   1158       &builder, {false, false, false, true, false, false, true, true, false},
   1159       {});
   1160 }
   1161 
   1162 XLA_TEST_F(ArrayElementwiseOpTest, CompareLeU32s) {
   1163   const uint32 max = std::numeric_limits<uint32>::max();
   1164   ComputationBuilder builder(client_, TestName());
   1165   auto lhs = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, max, max, max});
   1166   auto rhs = builder.ConstantR1<uint32>({0, 1, max, 4, 5, 6, 0, 1, max});
   1167   auto compare = builder.Le(lhs, rhs);
   1168 
   1169   ComputeAndCompareR1<bool>(
   1170       &builder, {true, true, true, false, true, true, false, false, true}, {});
   1171 }
   1172 
   1173 XLA_TEST_F(ArrayElementwiseOpTest, CompareLtU32s) {
   1174   const uint32 max = std::numeric_limits<uint32>::max();
   1175   ComputationBuilder builder(client_, TestName());
   1176   auto lhs = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, max, max, max});
   1177   auto rhs = builder.ConstantR1<uint32>({0, 1, max, 4, 5, 6, 0, 1, max});
   1178   auto compare = builder.Lt(lhs, rhs);
   1179 
   1180   ComputeAndCompareR1<bool>(
   1181       &builder, {false, true, true, false, false, true, false, false, false},
   1182       {});
   1183 }
   1184 
   1185 XLA_TEST_F(ArrayElementwiseOpTest, PowF32s) {
   1186   SetFastMathDisabled(true);
   1187   ComputationBuilder builder(client_, TestName());
   1188   auto lhs =
   1189       builder.ConstantR1<float>({4.0f, 2.0f, 2.0f, NAN, 6.0f, -2.0f, -2.0f});
   1190   auto rhs =
   1191       builder.ConstantR1<float>({2.0f, -2.0f, 3.0f, 10.0f, NAN, 3.0f, 4.0f});
   1192   auto minimum = builder.Pow(lhs, rhs);
   1193 
   1194   ComputeAndCompareR1<float>(
   1195       &builder, {16.0f, 0.25f, 8.0f, NAN, NAN, -8.0f, 16.0f}, {}, error_spec_);
   1196 }
   1197 
   1198 XLA_TEST_F(ArrayElementwiseOpTest, PowNonIntegerF32s) {
   1199   SetFastMathDisabled(true);
   1200   ComputationBuilder builder(client_, TestName());
   1201   auto lhs = builder.ConstantR1<float>({-2.0f, -0.6f, -0.6f, 0.0f});
   1202   auto rhs = builder.ConstantR1<float>({0.5f, 0.6f, -0.6f, -0.6f});
   1203   auto minimum = builder.Pow(lhs, rhs);
   1204 
   1205   ComputeAndCompareR1<float>(&builder, {NAN, NAN, NAN, INFINITY}, {},
   1206                              error_spec_);
   1207 }
   1208 
   1209 XLA_TEST_F(ArrayElementwiseOpTest, PowZeroElementF32s) {
   1210   ComputationBuilder builder(client_, TestName());
   1211   auto lhs = builder.ConstantR1<float>({});
   1212   auto rhs = builder.ConstantR1<float>({});
   1213   auto minimum = builder.Pow(lhs, rhs);
   1214 
   1215   ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
   1216 }
   1217 
   1218 // Some Pow cases that can be implemented more efficiently.
   1219 XLA_TEST_F(ArrayElementwiseOpTest, PowSpecialF32) {
   1220   ComputationBuilder b(client_, TestName());
   1221 
   1222   std::vector<float> values = {1.0f, 2.0f, 3.2f, -4.0f};
   1223   std::vector<float> exponents = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
   1224 
   1225   std::unique_ptr<Literal> param_literal = Literal::CreateR1<float>(values);
   1226   std::unique_ptr<GlobalData> param_data =
   1227       client_->TransferToServer(*param_literal).ConsumeValueOrDie();
   1228 
   1229   auto sum = b.ConstantR0<float>(0.0f);
   1230   auto param = b.Parameter(0, param_literal->shape(), "param");
   1231   for (float exponent : exponents) {
   1232     sum = b.Add(sum, b.Pow(param, b.ConstantR0<float>(exponent)));
   1233   }
   1234 
   1235   std::vector<float> expected;
   1236   for (auto value : values) {
   1237     float sum = 0.0f;
   1238     for (float exponent : exponents) {
   1239       sum += std::pow(value, exponent);
   1240     }
   1241     expected.push_back(sum);
   1242   }
   1243 
   1244   ComputeAndCompareR1<float>(&b, expected, {param_data.get()}, error_spec_);
   1245 }
   1246 
   1247 XLA_TEST_F(ArrayElementwiseOpTest, PowOfExpF32) {
   1248   ComputationBuilder b(client_, TestName());
   1249 
   1250   std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
   1251   std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
   1252 
   1253   std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>(values0);
   1254   std::unique_ptr<GlobalData> data0 =
   1255       client_->TransferToServer(*literal0).ConsumeValueOrDie();
   1256   std::unique_ptr<Literal> literal1 = Literal::CreateR1<float>(values1);
   1257   std::unique_ptr<GlobalData> data1 =
   1258       client_->TransferToServer(*literal1).ConsumeValueOrDie();
   1259   auto param0 = b.Parameter(0, literal0->shape(), "param0");
   1260   auto param1 = b.Parameter(1, literal1->shape(), "param1");
   1261   b.Pow(b.Exp(param0), param1);
   1262 
   1263   std::vector<float> expected(values0.size());
   1264   for (int64 i = 0; i < values0.size(); ++i) {
   1265     expected[i] = std::pow(std::exp(values0[i]), values1[i]);
   1266   }
   1267 
   1268   ComputeAndCompareR1<float>(&b, expected, {data0.get(), data1.get()},
   1269                              error_spec_);
   1270 }
   1271 
   1272 XLA_TEST_F(ArrayElementwiseOpTest, LogOfPowerF32) {
   1273   ComputationBuilder b(client_, TestName());
   1274 
   1275   std::vector<float> values0 = {1.0f, 2.0f, 3.2f, 4.0f, 0.5f, 5.7f};
   1276   std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
   1277 
   1278   std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>(values0);
   1279   std::unique_ptr<GlobalData> data0 =
   1280       client_->TransferToServer(*literal0).ConsumeValueOrDie();
   1281   std::unique_ptr<Literal> literal1 = Literal::CreateR1<float>(values1);
   1282   std::unique_ptr<GlobalData> data1 =
   1283       client_->TransferToServer(*literal1).ConsumeValueOrDie();
   1284   auto param0 = b.Parameter(0, literal0->shape(), "param0");
   1285   auto param1 = b.Parameter(1, literal1->shape(), "param1");
   1286   b.Log(b.Pow(param0, param1));
   1287 
   1288   std::vector<float> expected(values0.size());
   1289   for (int64 i = 0; i < values0.size(); ++i) {
   1290     expected[i] = std::log(std::pow(values0[i], values1[i]));
   1291   }
   1292 
   1293   ComputeAndCompareR1<float>(&b, expected, {data0.get(), data1.get()},
   1294                              error_spec_);
   1295 }
   1296 
   1297 XLA_TEST_F(ArrayElementwiseOpTest, MulOfExpF32) {
   1298   ComputationBuilder b(client_, TestName());
   1299 
   1300   std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
   1301   std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
   1302 
   1303   std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>(values0);
   1304   std::unique_ptr<GlobalData> data0 =
   1305       client_->TransferToServer(*literal0).ConsumeValueOrDie();
   1306   std::unique_ptr<Literal> literal1 = Literal::CreateR1<float>(values1);
   1307   std::unique_ptr<GlobalData> data1 =
   1308       client_->TransferToServer(*literal1).ConsumeValueOrDie();
   1309   auto param0 = b.Parameter(0, literal0->shape(), "param0");
   1310   auto param1 = b.Parameter(1, literal1->shape(), "param1");
   1311   b.Mul(b.Exp(param0), b.Exp(param1));
   1312 
   1313   std::vector<float> expected(values0.size());
   1314   for (int64 i = 0; i < values0.size(); ++i) {
   1315     expected[i] = std::exp(values0[i]) * std::exp(values1[i]);
   1316   }
   1317 
   1318   ComputeAndCompareR1<float>(&b, expected, {data0.get(), data1.get()},
   1319                              error_spec_);
   1320 }
   1321 
   1322 XLA_TEST_F(ArrayElementwiseOpTest, DivOfExpF32) {
   1323   ComputationBuilder b(client_, TestName());
   1324 
   1325   std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
   1326   std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
   1327 
   1328   std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>(values0);
   1329   std::unique_ptr<GlobalData> data0 =
   1330       client_->TransferToServer(*literal0).ConsumeValueOrDie();
   1331   std::unique_ptr<Literal> literal1 = Literal::CreateR1<float>(values1);
   1332   std::unique_ptr<GlobalData> data1 =
   1333       client_->TransferToServer(*literal1).ConsumeValueOrDie();
   1334   auto param0 = b.Parameter(0, literal0->shape(), "param0");
   1335   auto param1 = b.Parameter(1, literal1->shape(), "param1");
   1336   b.Div(param0, b.Exp(param1));
   1337 
   1338   std::vector<float> expected(values0.size());
   1339   for (int64 i = 0; i < values0.size(); ++i) {
   1340     expected[i] = values0[i] / std::exp(values1[i]);
   1341   }
   1342 
   1343   ComputeAndCompareR1<float>(&b, expected, {data0.get(), data1.get()},
   1344                              error_spec_);
   1345 }
   1346 
   1347 XLA_TEST_F(ArrayElementwiseOpTest, Div3_lhs_F32) {
   1348   ComputationBuilder b(client_, TestName());
   1349 
   1350   std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f};
   1351   std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
   1352   std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f};
   1353 
   1354   std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>(values0);
   1355   std::unique_ptr<GlobalData> data0 =
   1356       client_->TransferToServer(*literal0).ConsumeValueOrDie();
   1357 
   1358   std::unique_ptr<Literal> literal1 = Literal::CreateR1<float>(values1);
   1359   std::unique_ptr<GlobalData> data1 =
   1360       client_->TransferToServer(*literal1).ConsumeValueOrDie();
   1361 
   1362   std::unique_ptr<Literal> literal2 = Literal::CreateR1<float>(values2);
   1363   std::unique_ptr<GlobalData> data2 =
   1364       client_->TransferToServer(*literal2).ConsumeValueOrDie();
   1365   auto param0 = b.Parameter(0, literal0->shape(), "param0");
   1366   auto param1 = b.Parameter(1, literal1->shape(), "param1");
   1367   auto param2 = b.Parameter(2, literal2->shape(), "param2");
   1368   b.Div(b.Div(param0, param1), param2);
   1369 
   1370   std::vector<float> expected(values0.size());
   1371   for (int64 i = 0; i < values0.size(); ++i) {
   1372     expected[i] = (values0[i] / values1[i]) / values2[i];
   1373   }
   1374 
   1375   ComputeAndCompareR1<float>(
   1376       &b, expected, {data0.get(), data1.get(), data2.get()}, error_spec_);
   1377 }
   1378 
   1379 XLA_TEST_F(ArrayElementwiseOpTest, Div3_rhs_F32) {
   1380   ComputationBuilder b(client_, TestName());
   1381 
   1382   std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f};
   1383   std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
   1384   std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f};
   1385 
   1386   std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>(values0);
   1387   std::unique_ptr<GlobalData> data0 =
   1388       client_->TransferToServer(*literal0).ConsumeValueOrDie();
   1389 
   1390   std::unique_ptr<Literal> literal1 = Literal::CreateR1<float>(values1);
   1391   std::unique_ptr<GlobalData> data1 =
   1392       client_->TransferToServer(*literal1).ConsumeValueOrDie();
   1393 
   1394   std::unique_ptr<Literal> literal2 = Literal::CreateR1<float>(values2);
   1395   std::unique_ptr<GlobalData> data2 =
   1396       client_->TransferToServer(*literal2).ConsumeValueOrDie();
   1397 
   1398   auto param0 = b.Parameter(0, literal0->shape(), "param0");
   1399   auto param1 = b.Parameter(1, literal1->shape(), "param1");
   1400   auto param2 = b.Parameter(2, literal2->shape(), "param2");
   1401   b.Div(param0, b.Div(param1, param2));
   1402 
   1403   std::vector<float> expected(values0.size());
   1404   for (int64 i = 0; i < values0.size(); ++i) {
   1405     expected[i] = values0[i] / (values1[i] / values2[i]);
   1406   }
   1407 
   1408   ComputeAndCompareR1<float>(
   1409       &b, expected, {data0.get(), data1.get(), data2.get()}, error_spec_);
   1410 }
   1411 
   1412 XLA_TEST_F(ArrayElementwiseOpTest, DivOfPowerF32) {
   1413   ComputationBuilder b(client_, TestName());
   1414 
   1415   std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f};
   1416   std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, 1.0f, 0.5f};
   1417   std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 9.5f, -11.0f, -0.5f};
   1418 
   1419   std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>(values0);
   1420   std::unique_ptr<GlobalData> data0 =
   1421       client_->TransferToServer(*literal0).ConsumeValueOrDie();
   1422 
   1423   std::unique_ptr<Literal> literal1 = Literal::CreateR1<float>(values1);
   1424   std::unique_ptr<GlobalData> data1 =
   1425       client_->TransferToServer(*literal1).ConsumeValueOrDie();
   1426 
   1427   std::unique_ptr<Literal> literal2 = Literal::CreateR1<float>(values2);
   1428   std::unique_ptr<GlobalData> data2 =
   1429       client_->TransferToServer(*literal2).ConsumeValueOrDie();
   1430 
   1431   auto param0 = b.Parameter(0, literal0->shape(), "param0");
   1432   auto param1 = b.Parameter(1, literal1->shape(), "param1");
   1433   auto param2 = b.Parameter(2, literal2->shape(), "param2");
   1434   b.Div(param0, b.Pow(param1, param2));
   1435 
   1436   std::vector<float> expected(values0.size());
   1437   for (int64 i = 0; i < values0.size(); ++i) {
   1438     expected[i] = values0[i] / std::pow(values1[i], values2[i]);
   1439   }
   1440 
   1441   ComputeAndCompareR1<float>(
   1442       &b, expected, {data0.get(), data1.get(), data2.get()}, error_spec_);
   1443 }
   1444 
   1445 XLA_TEST_F(ArrayElementwiseOpTest, Div4F32) {
   1446   ComputationBuilder b(client_, TestName());
   1447 
   1448   std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f};
   1449   std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
   1450   std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f};
   1451   std::vector<float> values3 = {2.1f, 3.1f, 9.9f, -4.5f, -11.0f, -21.5f};
   1452 
   1453   std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>(values0);
   1454   std::unique_ptr<GlobalData> data0 =
   1455       client_->TransferToServer(*literal0).ConsumeValueOrDie();
   1456 
   1457   std::unique_ptr<Literal> literal1 = Literal::CreateR1<float>(values1);
   1458   std::unique_ptr<GlobalData> data1 =
   1459       client_->TransferToServer(*literal1).ConsumeValueOrDie();
   1460 
   1461   std::unique_ptr<Literal> literal2 = Literal::CreateR1<float>(values2);
   1462   std::unique_ptr<GlobalData> data2 =
   1463       client_->TransferToServer(*literal2).ConsumeValueOrDie();
   1464 
   1465   std::unique_ptr<Literal> literal3 = Literal::CreateR1<float>(values3);
   1466   std::unique_ptr<GlobalData> data3 =
   1467       client_->TransferToServer(*literal3).ConsumeValueOrDie();
   1468 
   1469   auto param0 = b.Parameter(0, literal0->shape(), "param0");
   1470   auto param1 = b.Parameter(1, literal1->shape(), "param1");
   1471   auto param2 = b.Parameter(2, literal2->shape(), "param2");
   1472   auto param3 = b.Parameter(3, literal3->shape(), "param2");
   1473   b.Div(b.Div(param0, param1), b.Div(param2, param3));
   1474 
   1475   std::vector<float> expected(values0.size());
   1476   for (int64 i = 0; i < values0.size(); ++i) {
   1477     expected[i] = (values0[i] / values1[i]) / (values2[i] / values3[i]);
   1478   }
   1479 
   1480   ComputeAndCompareR1<float>(
   1481       &b, expected, {data0.get(), data1.get(), data2.get(), data3.get()},
   1482       error_spec_);
   1483 }
   1484 
   1485 TEST_P(ArrayElementwiseOpTestParamCount, SquareManyValues) {
   1486   const int count = GetParam();
   1487   ComputationBuilder builder(client_, TestName());
   1488   std::vector<float> values;
   1489   values.reserve(count);
   1490   for (int i = 0; i < count; ++i) {
   1491     values.push_back(i / static_cast<float>(count));
   1492   }
   1493   auto x = builder.ConstantR1<float>(values);
   1494   auto exp = builder.Pow(x, builder.ConstantR0<float>(2.0f));
   1495 
   1496   std::vector<float> expected;
   1497   expected.reserve(values.size());
   1498   for (float value : values) {
   1499     expected.push_back(value * value);
   1500   }
   1501 
   1502   ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
   1503 }
   1504 
   1505 XLA_TEST_F(ArrayElementwiseOpTest, SquareIn4D) {
   1506   ComputationBuilder builder(client_, TestName());
   1507   Array4D<float> values(2, 2, 2, 2);
   1508 
   1509   std::vector<float> values_vector;
   1510   std::vector<float> expected_vector;
   1511   for (int i = 0; i < values.num_elements(); ++i) {
   1512     values_vector.push_back(static_cast<float>(i) / values.num_elements());
   1513     expected_vector.push_back(values_vector.back() * values_vector.back());
   1514   }
   1515   values.SetValues(values_vector);
   1516 
   1517   Array4D<float> expected(2, 2, 2, 2, expected_vector);
   1518 
   1519   auto x = builder.ConstantR4FromArray4D<float>(values);
   1520   auto exp = builder.Pow(x, builder.ConstantR0<float>(2.0f));
   1521 
   1522   ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
   1523 }
   1524 
   1525 XLA_TEST_F(ArrayElementwiseOpTest, SquareIn4DZeroElements) {
   1526   ComputationBuilder builder(client_, TestName());
   1527   Array4D<float> values(2, 2, 0, 2);
   1528   Array4D<float> expected(2, 2, 0, 2);
   1529 
   1530   auto x = builder.ConstantR4FromArray4D<float>(values);
   1531   auto exp = builder.Pow(x, builder.ConstantR0<float>(2.0f));
   1532 
   1533   ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
   1534 }
   1535 
   1536 // GPU backend emits nvvm intrinsic for fmin and fmax, whose semantics is NOT
   1537 // such
   1538 // * fmin(NaN, x) = x
   1539 // * fmax(NaN, x) = x
   1540 // so we only test NAN on CPU.
   1541 //
   1542 // TODO(b/28180546): Make this compile in a way that is consistent
   1543 // among backends.
   1544 XLA_TEST_F(ArrayElementwiseOpTest, MinF32s) {
   1545   ComputationBuilder builder(client_, TestName());
   1546 #if !defined(XLA_TEST_BACKEND_CPU)
   1547   auto lhs = builder.ConstantR1<float>({1.0f, 1.0f, 2.25f});
   1548   auto rhs = builder.ConstantR1<float>({2.0f, -5.0f, 1.0f});
   1549 #else
   1550   SetFastMathDisabled(true);
   1551   auto lhs = builder.ConstantR1<float>({1.0f, 1.0f, 2.25f, NAN, 6.0f});
   1552   auto rhs = builder.ConstantR1<float>({2.0f, -5.0f, 1.0f, 10.0f, NAN});
   1553 #endif
   1554   auto minimum = builder.Min(lhs, rhs);
   1555 
   1556   ComputeAndCompareR1<float>(&builder,
   1557 #if !defined(XLA_TEST_BACKEND_CPU)
   1558                              {1.0f, -5.0f, 1.0f},
   1559 #else
   1560                              {1.0f, -5.0f, 1.0f, 10.0f, 6.0f},
   1561 #endif
   1562                              {}, error_spec_);
   1563 }
   1564 
   1565 XLA_TEST_F(ArrayElementwiseOpTest, MinZeroElementF32s) {
   1566   ComputationBuilder builder(client_, TestName());
   1567   auto lhs = builder.ConstantR1<float>({});
   1568   auto rhs = builder.ConstantR1<float>({});
   1569   auto minimum = builder.Min(lhs, rhs);
   1570   ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
   1571 }
   1572 
   1573 // TODO(b/28180546): Make this compile in a way that is consistent
   1574 // among backends. See comment on MinF32s test above.
   1575 XLA_TEST_F(ArrayElementwiseOpTest, MinF64s) {
   1576   ComputationBuilder builder(client_, TestName());
   1577 #if !defined(XLA_TEST_BACKEND_CPU)
   1578   auto lhs = builder.ConstantR1<double>({1.0, 1.0, 2.25});
   1579   auto rhs = builder.ConstantR1<double>({2.0, -5.0, 1.0});
   1580 #else
   1581   SetFastMathDisabled(true);
   1582   auto lhs = builder.ConstantR1<double>({1.0, 1.0, 2.25, NAN, 6.0});
   1583   auto rhs = builder.ConstantR1<double>({2.0, -5.0, 1.0, 10.0, NAN});
   1584 #endif
   1585   auto minimum = builder.Min(lhs, rhs);
   1586 
   1587   ComputeAndCompareR1<double>(&builder,
   1588 #if !defined(XLA_TEST_BACKEND_CPU)
   1589                               {1.0, -5.0, 1.0},
   1590 #else
   1591                               {1.0, -5.0, 1.0, 10.0, 6.0},
   1592 #endif
   1593                               {}, error_spec_);
   1594 }
   1595 
   1596 // TODO(b/28180546): Make this compile in a way that is consistent
   1597 // among backends. See comment on MinF32s test above.
   1598 XLA_TEST_F(ArrayElementwiseOpTest, MaxF32s) {
   1599   ComputationBuilder builder(client_, TestName());
   1600 #if !defined(XLA_TEST_BACKEND_CPU)
   1601   auto lhs = builder.ConstantR1<float>({1.0f, 1.0f, 2.25f});
   1602   auto rhs = builder.ConstantR1<float>({2.0f, -5.0f, 1.0f});
   1603 #else
   1604   SetFastMathDisabled(true);
   1605   auto lhs = builder.ConstantR1<float>({1.0f, 1.0f, 2.25f, NAN, 6.0f});
   1606   auto rhs = builder.ConstantR1<float>({2.0f, -5.0f, 1.0f, 10.0f, NAN});
   1607 #endif
   1608   auto maximum = builder.Max(lhs, rhs);
   1609 
   1610   ComputeAndCompareR1<float>(&builder,
   1611 #if !defined(XLA_TEST_BACKEND_CPU)
   1612                              {2.0f, 1.0f, 2.25f},
   1613 #else
   1614                              {2.0f, 1.0f, 2.25f, 10.0f, 6.0f},
   1615 #endif
   1616                              {}, error_spec_);
   1617 }
   1618 
   1619 XLA_TEST_F(ArrayElementwiseOpTest, MaxZeroElementF32s) {
   1620   ComputationBuilder builder(client_, TestName());
   1621   auto lhs = builder.ConstantR1<float>({});
   1622   auto rhs = builder.ConstantR1<float>({});
   1623   auto minimum = builder.Max(lhs, rhs);
   1624   ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
   1625 }
   1626 
   1627 // TODO(b/28180546): Make this compile in a way that is consistent
   1628 // among backends. See comment on MinF32s test above.
   1629 XLA_TEST_F(ArrayElementwiseOpTest, MaxF64s) {
   1630   ComputationBuilder builder(client_, TestName());
   1631 #if !defined(XLA_TEST_BACKEND_CPU)
   1632   auto lhs = builder.ConstantR1<double>({1.0, 1.0, 2.25});
   1633   auto rhs = builder.ConstantR1<double>({2.0, -5.0, 1.0});
   1634 #else
   1635   SetFastMathDisabled(true);
   1636   auto lhs = builder.ConstantR1<double>({1.0, 1.0, 2.25, NAN, 6.0});
   1637   auto rhs = builder.ConstantR1<double>({2.0, -5.0, 1.0, 10.0, NAN});
   1638 #endif
   1639   auto maximum = builder.Max(lhs, rhs);
   1640 
   1641   ComputeAndCompareR1<double>(&builder,
   1642 #if !defined(XLA_TEST_BACKEND_CPU)
   1643                               {2.0, 1.0, 2.25},
   1644 #else
   1645                               {2.0, 1.0, 2.25, 10.0, 6.0},
   1646 #endif
   1647                               {}, error_spec_);
   1648 }
   1649 
   1650 XLA_TEST_F(ArrayElementwiseOpTest, MaxS32s) {
   1651   const int32 min = std::numeric_limits<int32>::min();
   1652   const int32 max = std::numeric_limits<int32>::max();
   1653   ComputationBuilder builder(client_, TestName());
   1654   auto x = builder.ConstantR1<int32>(
   1655       {min, min, min, -1, -1, 0, 0, 0, 1, 1, max, max, max});
   1656   auto y = builder.ConstantR1<int32>(
   1657       {min, max, 0, -10, 0, -1, 0, 1, 0, 10, 0, max, min});
   1658   builder.Max(x, y);
   1659 
   1660   std::vector<int32> expected = {min, max, 0,  -1,  0,   0,  0,
   1661                                  1,   1,   10, max, max, max};
   1662   ComputeAndCompareR1<int32>(&builder, expected, {});
   1663 }
   1664 
   1665 XLA_TEST_F(ArrayElementwiseOpTest, MinS32s) {
   1666   const int32 min = std::numeric_limits<int32>::min();
   1667   const int32 max = std::numeric_limits<int32>::max();
   1668   ComputationBuilder builder(client_, TestName());
   1669   auto x = builder.ConstantR1<int32>(
   1670       {min, min, min, -1, -1, 0, 0, 0, 1, 1, max, max, max});
   1671   auto y = builder.ConstantR1<int32>(
   1672       {min, max, 0, -10, 0, -1, 0, 1, 0, 10, 0, max, min});
   1673   builder.Min(x, y);
   1674 
   1675   std::vector<int32> expected = {min, min, min, -10, -1,  -1, 0,
   1676                                  0,   0,   1,   0,   max, min};
   1677   ComputeAndCompareR1<int32>(&builder, expected, {});
   1678 }
   1679 
   1680 XLA_TEST_F(ArrayElementwiseOpTest, MaxU32s) {
   1681   const uint32 max = std::numeric_limits<uint32>::max();
   1682   ComputationBuilder builder(client_, TestName());
   1683   auto x = builder.ConstantR1<uint32>({0, 0, 1, 1, 1, max, max, max});
   1684   auto y = builder.ConstantR1<uint32>({0, 1, 0, 1, 10, 0, 234234, max});
   1685   builder.Max(x, y);
   1686 
   1687   std::vector<uint32> expected = {0, 1, 1, 1, 10, max, max, max};
   1688   ComputeAndCompareR1<uint32>(&builder, expected, {});
   1689 }
   1690 
   1691 XLA_TEST_F(ArrayElementwiseOpTest, MinU32s) {
   1692   const uint32 max = std::numeric_limits<uint32>::max();
   1693   ComputationBuilder builder(client_, TestName());
   1694   auto x = builder.ConstantR1<uint32>({0, 0, 1, 1, 1, max, max, max});
   1695   auto y = builder.ConstantR1<uint32>({0, 1, 0, 1, 10, 0, 234234, max});
   1696   builder.Min(x, y);
   1697 
   1698   std::vector<uint32> expected = {0, 0, 0, 1, 1, 0, 234234, max};
   1699   ComputeAndCompareR1<uint32>(&builder, expected, {});
   1700 }
   1701 
   1702 XLA_TEST_F(ArrayElementwiseOpTest, MaxTenF32s) {
   1703   ComputationBuilder builder(client_, TestName());
   1704   auto x = builder.ConstantR1<float>(
   1705       {-0.0, 1.0, 2.0, -3.0, -4.0, 5.0, 6.0, -7.0, -8.0, 9.0});
   1706   auto y = builder.ConstantR1<float>(
   1707       {-0.0, -1.0, -2.0, 3.0, 4.0, -5.0, -6.0, 7.0, 8.0, -9.0});
   1708   builder.Max(x, y);
   1709 
   1710   std::vector<float> expected = {-0.0, 1.0, 2.0, 3.0, 4.0,
   1711                                  5.0,  6.0, 7.0, 8.0, 9.0};
   1712   ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
   1713 }
   1714 
   1715 XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S1AndR1S0F32s) {
   1716   ComputationBuilder builder(client_, TestName());
   1717   auto u = builder.ConstantR1<float>({3.5});
   1718   auto v = builder.ConstantR1<float>({});
   1719   builder.Max(u, v);
   1720 
   1721   ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
   1722 }
   1723 
   1724 XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S0AndR2S0x2F32s) {
   1725   for (int broadcast_dim : {0, 1}) {
   1726     ComputationBuilder builder(client_, TestName());
   1727     auto u = builder.ConstantR1<float>({3.5});
   1728     auto v = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 2));
   1729     builder.Max(u, v, /*broadcast_dimensions=*/{broadcast_dim});
   1730 
   1731     ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 2), {}, error_spec_);
   1732   }
   1733 }
   1734 
   1735 XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DF32s) {
   1736   ComputationBuilder builder(client_, TestName());
   1737   auto v = builder.ConstantR1<float>({2.0f, 3.0f, 4.0f});
   1738   auto m =
   1739       builder.ConstantR2<float>({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
   1740   builder.Max(v, m, /*broadcast_dimensions=*/{1});
   1741 
   1742   Array2D<float> expected({{2.0f, 3.14f, 4.0f}, {2.25f, 3.0f, 4.0f}});
   1743   ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
   1744 }
   1745 
   1746 XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DZeroElementF32s) {
   1747   ComputationBuilder builder(client_, TestName());
   1748   auto v = builder.ConstantR1<float>({});
   1749   auto m = builder.ConstantR2<float>({{}, {}});
   1750   builder.Max(v, m, /*broadcast_dimensions=*/{1});
   1751 
   1752   Array2D<float> expected({{}, {}});
   1753   ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
   1754 }
   1755 
   1756 XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarS32s) {
   1757   ComputationBuilder builder(client_, TestName());
   1758   auto scalar = builder.ConstantR0<int32>(2);
   1759   Array3D<int32> a_3d({{{3, 9, -1}, {2, -10, 3}}, {{-2, 2, 8}, {12, 10, 4}}});
   1760   auto array = builder.ConstantR3FromArray3D<int32>(a_3d);
   1761   builder.Max(array, scalar, /*broadcast_dimensions=*/{});
   1762 
   1763   Array3D<int32> expected({{{3, 9, 2}, {2, 2, 3}}, {{2, 2, 8}, {12, 10, 4}}});
   1764   ComputeAndCompareR3<int32>(&builder, expected, {});
   1765 }
   1766 
   1767 XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarZeroElementS32s) {
   1768   ComputationBuilder builder(client_, TestName());
   1769   auto scalar = builder.ConstantR0<int32>(2);
   1770   Array3D<int32> a_3d(2, 0, 3);
   1771   auto array = builder.ConstantR3FromArray3D<int32>(a_3d);
   1772   builder.Max(array, scalar, /*broadcast_dimensions=*/{});
   1773 
   1774   Array3D<int32> expected(2, 0, 3);
   1775   ComputeAndCompareR3<int32>(&builder, expected, {});
   1776 }
   1777 
   1778 XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DF32s) {
   1779   ComputationBuilder builder(client_, TestName());
   1780   auto m =
   1781       builder.ConstantR2<float>({{-10.4f, 64.0f, 6.0f}, {0.1f, 32.0f, 16.1f}});
   1782   auto v = builder.ConstantR1<float>({-10.2f, 16.4f});
   1783   builder.Min(m, v, /*broadcast_dimensions=*/{0});
   1784 
   1785   Array2D<float> expected({{-10.4f, -10.2f, -10.2f}, {0.1f, 16.4f, 16.1f}});
   1786   ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
   1787 }
   1788 
   1789 XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DZeroElementF32s) {
   1790   ComputationBuilder builder(client_, TestName());
   1791   auto m = builder.ConstantR2<float>({{}, {}});
   1792   auto v = builder.ConstantR1<float>({-10.2f, 16.4f});
   1793   builder.Min(m, v, /*broadcast_dimensions=*/{0});
   1794 
   1795   Array2D<float> expected({{}, {}});
   1796   ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
   1797 }
   1798 
   1799 XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DF32s) {
   1800   ComputationBuilder builder(client_, TestName());
   1801   auto array2d =
   1802       builder.ConstantR2<float>({{-12.2f, 64.3f, 6.1f}, {0.0f, 32.2f, 2.5f}});
   1803   auto array4d = builder.ConstantR4FromArray4D<float>(
   1804       {{{{-12.1f, 32.3f, 6.2f}}, {{0.0f, 32.5f, 3.0f}}},
   1805        {{{-2.5f, 64.29f, 6.5f}}, {{-0.01f, 32.25f, 2.6f}}}});
   1806   builder.Min(array2d, array4d, /*broadcast_dimensions=*/{1, 3});
   1807 
   1808   Array4D<float> expected(
   1809       {{{{-12.2f, 32.3f, 6.1f}}, {{0.0f, 32.2f, 2.5f}}},
   1810        {{{-12.2f, 64.29f, 6.1f}}, {{-0.01f, 32.2f, 2.5f}}}});
   1811   ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
   1812 }
   1813 
   1814 XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DZeroElementF32s) {
   1815   ComputationBuilder builder(client_, TestName());
   1816   auto array2d =
   1817       builder.ConstantR2<float>({{-12.2f, 64.3f, 6.1f}, {0.0f, 32.2f, 2.5f}});
   1818   Array4D<float> arg(2, 2, 0, 3);
   1819   auto array4d = builder.ConstantR4FromArray4D<float>(arg);
   1820   builder.Min(array2d, array4d, /*broadcast_dimensions=*/{1, 3});
   1821 
   1822   Array4D<float> expected(2, 2, 0, 3);
   1823   ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
   1824 }
   1825 
   1826 XLA_TEST_F(ArrayElementwiseOpTest, MinTenS32s) {
   1827   ComputationBuilder builder(client_, TestName());
   1828   auto x = builder.ConstantR1<int32>({0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
   1829   auto y = builder.ConstantR1<int32>({9, 8, 7, 6, 5, 4, 3, 2, 1, 0});
   1830   builder.Min(x, y);
   1831 
   1832   std::vector<int32> expected = {0, 1, 2, 3, 4, 4, 3, 2, 1, 0};
   1833   ComputeAndCompareR1<int32>(&builder, expected, {});
   1834 }
   1835 
   1836 XLA_TEST_F(ArrayElementwiseOpTest, MaxTenS32s) {
   1837   ComputationBuilder builder(client_, TestName());
   1838   auto x = builder.ConstantR1<int32>({0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
   1839   auto y = builder.ConstantR1<int32>({9, 8, 7, 6, 5, 4, 3, 2, 1, 0});
   1840   builder.Max(x, y);
   1841 
   1842   std::vector<int32> expected = {9, 8, 7, 6, 5, 5, 6, 7, 8, 9};
   1843   ComputeAndCompareR1<int32>(&builder, expected, {});
   1844 }
   1845 
   1846 XLA_TEST_F(ArrayElementwiseOpTest, RemTwoConstantS32s) {
   1847   ComputationBuilder builder(client_, TestName());
   1848   auto a = builder.ConstantR1<int32>({-3, 26, 2, -1, 1});
   1849   auto b = builder.ConstantR1<int32>({10, 5, 1, 10, -10});
   1850   auto add = builder.Rem(a, b);
   1851 
   1852   ComputeAndCompareR1<int32>(&builder, {-3, 1, 0, -1, 1}, {});
   1853 }
   1854 
   1855 XLA_TEST_F(ArrayElementwiseOpTest, NonNanClampF32) {
   1856   ComputationBuilder builder(client_, TestName());
   1857   auto minimum = builder.ConstantR1<float>({1.0f, -6.5f, 1.0f, 2.25f, 0.0f});
   1858   auto argument = builder.ConstantR1<float>({2.0f, 10.0f, -5.0f, 1.0f, 10.0f});
   1859   auto maximum = builder.ConstantR1<float>({3.0f, 0.5f, 25.5f, 5.0f, 123.0});
   1860   auto clamp = builder.Clamp(minimum, argument, maximum);
   1861 
   1862   ComputeAndCompareR1<float>(&builder, {2.0f, 0.5f, 1.0f, 2.25f, 10.0f}, {},
   1863                              error_spec_);
   1864 }
   1865 
   1866 XLA_TEST_F(ArrayElementwiseOpTest, ClampF32Scalar) {
   1867   ComputationBuilder builder(client_, TestName());
   1868   auto minimum = builder.ConstantR0<float>(0.0f);
   1869   auto argument = builder.ConstantR1<float>({2.0f, 10.0f, -5.0f, 1.0f, 4.0f});
   1870   auto maximum = builder.ConstantR0<float>(5.0f);
   1871   auto clamp = builder.Clamp(minimum, argument, maximum);
   1872 
   1873   ComputeAndCompareR1<float>(&builder, {2.0f, 5.0f, 0.0f, 1.0f, 4.0f}, {},
   1874                              error_spec_);
   1875 }
   1876 
   1877 XLA_TEST_F(ArrayElementwiseOpTest, ClampF32ScalarVector) {
   1878   ComputationBuilder builder(client_, TestName());
   1879   auto min_scalar = builder.ConstantR0<float>(0.0f);
   1880   auto min_vector = builder.ConstantR1<float>({1.0f, -6.5f, 1.0f, 2.25f, 0.0f});
   1881   auto arg_vector = builder.ConstantR1<float>({2.0f, 10.0f, -5.0f, 1.0f, 4.0f});
   1882   auto max_scalar = builder.ConstantR0<float>(3.0f);
   1883   auto max_vector = builder.ConstantR1<float>({3.0f, 0.5f, 25.5f, 5.0f, 123.0});
   1884   // Perform clamp with broadcasted scalar and vector.
   1885   auto clamp = builder.Add(
   1886       builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar),
   1887                   builder.Clamp(min_scalar, arg_vector, max_vector)),
   1888       builder.Add(builder.Clamp(min_vector, arg_vector, max_vector),
   1889                   builder.Clamp(min_scalar, arg_vector, max_scalar)));
   1890 
   1891   ComputeAndCompareR1<float>(&builder, {8.0f, 7.0f, 2.0f, 6.5f, 14.0f}, {},
   1892                              error_spec_);
   1893 }
   1894 
   1895 XLA_TEST_F(ArrayElementwiseOpTest, ClampS32Vector) {
   1896   ComputationBuilder builder(client_, TestName());
   1897   auto min_vector = builder.ConstantR1<int32>({1, -6, 1, 2, 0, -5});
   1898   auto arg_vector = builder.ConstantR1<int32>({2, 10, -5, 1, 4, 10});
   1899   auto max_vector = builder.ConstantR1<int32>({3, 0, 25, 5, 123, -1});
   1900   auto clamp = builder.Clamp(min_vector, arg_vector, max_vector);
   1901 
   1902   ComputeAndCompareR1<int32>(&builder, {2, 0, 1, 2, 4, -1}, {});
   1903 }
   1904 
   1905 XLA_TEST_F(ArrayElementwiseOpTest, ClampS32ScalarVector) {
   1906   ComputationBuilder builder(client_, TestName());
   1907   auto min_scalar = builder.ConstantR0<int32>(0);
   1908   auto min_vector = builder.ConstantR1<int32>({1, -6, 1, 2, 0});
   1909   auto arg_vector = builder.ConstantR1<int32>({2, 10, -5, 1, 4});
   1910   auto max_scalar = builder.ConstantR0<int32>(3);
   1911   auto max_vector = builder.ConstantR1<int32>({3, 1, 25, 5, 123});
   1912   // Perform clamp with broadcasted scalar and vector.
   1913   auto clamp = builder.Add(
   1914       builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar),
   1915                   builder.Clamp(min_scalar, arg_vector, max_vector)),
   1916       builder.Add(builder.Clamp(min_vector, arg_vector, max_vector),
   1917                   builder.Clamp(min_scalar, arg_vector, max_scalar)));
   1918 
   1919   ComputeAndCompareR1<int32>(&builder, {8, 8, 2, 6, 14}, {});
   1920 }
   1921 
   1922 XLA_TEST_F(ArrayElementwiseOpTest, ClampU32Vector) {
   1923   ComputationBuilder builder(client_, TestName());
   1924   auto min_vector = builder.ConstantR1<uint32>({1, 2, 1, 2, 0, ~0u - 4});
   1925   auto arg_vector = builder.ConstantR1<uint32>({2, 10, 5, 1, 4, 10});
   1926   auto max_vector = builder.ConstantR1<uint32>({3, 5, 25, 5, 123, ~0u});
   1927   auto clamp = builder.Clamp(min_vector, arg_vector, max_vector);
   1928 
   1929   ComputeAndCompareR1<uint32>(&builder, {2, 5, 5, 2, 4, ~0u - 4}, {});
   1930 }
   1931 
   1932 XLA_TEST_F(ArrayElementwiseOpTest, ClampU32ScalarVector) {
   1933   ComputationBuilder builder(client_, TestName());
   1934   auto min_scalar = builder.ConstantR0<uint32>(0);
   1935   auto min_vector = builder.ConstantR1<uint32>({1, 0, 1, 2, 0});
   1936   auto arg_vector = builder.ConstantR1<uint32>({2, 10, 0, 1, 4});
   1937   auto max_scalar = builder.ConstantR0<uint32>(3);
   1938   auto max_vector = builder.ConstantR1<uint32>({3, 1, 25, 5, 123});
   1939   // Perform clamp with broadcasted scalar and vector.
   1940   auto clamp = builder.Add(
   1941       builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar),
   1942                   builder.Clamp(min_scalar, arg_vector, max_vector)),
   1943       builder.Add(builder.Clamp(min_vector, arg_vector, max_vector),
   1944                   builder.Clamp(min_scalar, arg_vector, max_scalar)));
   1945 
   1946   ComputeAndCompareR1<uint32>(&builder, {8, 8, 2, 6, 14}, {});
   1947 }
   1948 
   1949 XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) {
   1950   ComputationBuilder builder(client_, TestName());
   1951 
   1952   std::unique_ptr<Literal> param0_literal =
   1953       Literal::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
   1954   std::unique_ptr<GlobalData> param0_data =
   1955       client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
   1956 
   1957   std::unique_ptr<Literal> param1_literal =
   1958       Literal::CreateR1<float>({7.2f, 2.3f, 3.4f, 5.6f});
   1959   std::unique_ptr<GlobalData> param1_data =
   1960       client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
   1961 
   1962   auto p0 = builder.Parameter(0, param0_literal->shape(), "param0");
   1963   auto p1 = builder.Parameter(1, param1_literal->shape(), "param1");
   1964   auto add = builder.Add(p0, p1);
   1965 
   1966   ComputeAndCompareR1<float>(&builder, {8.3f, 4.5f, 6.7f, 11.1f},
   1967                              {param0_data.get(), param1_data.get()},
   1968                              error_spec_);
   1969 }
   1970 
   1971 XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) {
   1972   ComputationBuilder builder(client_, TestName());
   1973 
   1974   std::unique_ptr<Literal> param0_literal =
   1975       Literal::CreateR3FromArray3D<float>(Array3D<float>(0, 7, 0));
   1976   std::unique_ptr<GlobalData> param0_data =
   1977       client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
   1978 
   1979   std::unique_ptr<Literal> param1_literal =
   1980       Literal::CreateR3FromArray3D<float>(Array3D<float>(0, 7, 0));
   1981   std::unique_ptr<GlobalData> param1_data =
   1982       client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
   1983 
   1984   auto p0 = builder.Parameter(0, param0_literal->shape(), "param0");
   1985   auto p1 = builder.Parameter(1, param1_literal->shape(), "param1");
   1986   auto add = builder.Add(p0, p1);
   1987 
   1988   Array3D<float> expected(0, 7, 0);
   1989   ComputeAndCompareR3<float>(
   1990       &builder, expected, {param0_data.get(), param1_data.get()}, error_spec_);
   1991 }
   1992 
   1993 XLA_TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) {
   1994   ComputationBuilder builder(client_, TestName());
   1995 
   1996   std::unique_ptr<Literal> param0_literal =
   1997       Literal::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
   1998   std::unique_ptr<GlobalData> param0_data =
   1999       client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
   2000 
   2001   auto a = builder.ConstantR1<float>({1.1f, 2.2f, 3.3f, 4.4f});
   2002   auto p = builder.Parameter(0, param0_literal->shape(), "param0");
   2003   auto add = builder.Add(a, p);
   2004 
   2005   ComputeAndCompareR1<float>(&builder, {2.2f, 4.4f, 6.6f, 9.9f},
   2006                              {param0_data.get()}, error_spec_);
   2007 }
   2008 
   2009 XLA_TEST_F(ArrayElementwiseOpTest, CosF32s) {
   2010   ComputationBuilder builder(client_, TestName());
   2011   auto a = builder.ConstantR1<float>({3.14159f, 0.0f, 1.570796f, -0.78539f});
   2012   auto result = builder.Cos(a);
   2013 
   2014   ComputeAndCompareR1<float>(&builder, {-1.0f, 1.0f, 0.0f, 0.707107f}, {},
   2015                              error_spec_);
   2016 }
   2017 
   2018 XLA_TEST_F(ArrayElementwiseOpTest, SinF32s) {
   2019   ComputationBuilder builder(client_, TestName());
   2020   auto a = builder.ConstantR1<float>({3.14159f, 0.0f, 1.570796f, -0.78539f});
   2021   auto result = builder.Sin(a);
   2022 
   2023   ComputeAndCompareR1<float>(&builder, {0.0f, 0.0f, 1.0f, -0.707107f}, {},
   2024                              error_spec_);
   2025 }
   2026 
   2027 XLA_TEST_F(ArrayElementwiseOpTest, Atan2F32s) {
   2028   ComputationBuilder builder(client_, TestName());
   2029   auto a = builder.ConstantR1<float>({0.0f, 5.0f, 0.0f, -3.0f, 2.0f, -8.0f});
   2030   auto b = builder.ConstantR1<float>({6.0f, 0.0f, -4.0f, 0.0f, 2.0f, 8.0f});
   2031   auto atan = builder.Atan2(a, b);
   2032 
   2033   ComputeAndCompareR1<float>(
   2034       &builder,
   2035       {0.0f, 1.57079633f, 3.14159265f, -1.57079633f, 0.78539816f, -0.78539816f},
   2036       {}, error_spec_);
   2037 }
   2038 
   2039 XLA_TEST_F(ArrayElementwiseOpTest, TanhF32s) {
   2040   ComputationBuilder builder(client_, TestName());
   2041   auto a = builder.ConstantR1<float>({-2.5f, 3.14f, 2.25f});
   2042   auto result = builder.Tanh(a);
   2043 
   2044   ComputeAndCompareR1<float>(&builder, {-0.986614f, 0.996260f, 0.978026}, {},
   2045                              error_spec_);
   2046 }
   2047 
   2048 XLA_TEST_F(ArrayElementwiseOpTest, TanhF32sVector) {
   2049   // This is like the test ArrayElementwiseOpTest.TanhF32s above, except that
   2050   // the input tensor is large enough to exercise the vectorized tanh
   2051   // implementation on XLA CPU.
   2052   ComputationBuilder builder(client_, TestName());
   2053   auto input_literal = Literal::CreateR1<float>(
   2054       {1.02,  -0.32, 0.85,  0.90,  1.23,  -0.91, -0.49, 0.80,  -0.67, 0.16,
   2055        -0.07, 0.39,  -0.41, 0.04,  1.36,  1.25,  0.41,  0.65,  -1.08, 0.32,
   2056        -1.45, -0.77, -1.09, 0.91,  -1.03, -0.30, -1.11, -1.17, 1.50,  -0.85,
   2057        0.04,  1.02,  0.34,  -0.61, 0.41,  0.07,  -0.02, 1.42,  -0.62, 0.81,
   2058        0.08,  0.81,  -0.30, 1.17,  -0.65, -0.44, 0.92,  1.26,  -1.29, 1.35,
   2059        0.08,  -1.24, -0.92, 0.49,  1.17,  -0.45, -1.31, -1.44, -0.13, -1.31,
   2060        -0.79, 1.41,  1.21,  1.05});
   2061   TF_ASSERT_OK_AND_ASSIGN(auto input_data,
   2062                           client_->TransferToServer(*input_literal));
   2063 
   2064   auto input = builder.Parameter(0, input_literal->shape(), "input");
   2065   builder.Tanh(input);
   2066 
   2067   ComputeAndCompareR1<float>(
   2068       &builder,
   2069       {0.77009583,  -0.30665702, 0.69070244,  0.71401149,  0.84400684,
   2070        -0.71985596, -0.45764771, 0.66664988,  -0.58278900, 0.16050975,
   2071        -0.06770509, 0.36843640,  -0.38476998, 0.04018109,  0.87562293,
   2072        0.84788644,  0.38603750,  0.57294142,  -0.79140943, 0.31032649,
   2073        -0.89590985, -0.64770776, -0.79625875, 0.72234446,  -0.77389336,
   2074        -0.28871772, -0.80428445, -0.82541436, 0.90456349,  -0.68856895,
   2075        0.03877772,  0.76877952,  0.32561871,  -0.54546672, 0.39072621,
   2076        0.07273290,  -0.01924866, 0.88924897,  -0.55283129, 0.67183107,
   2077        0.08006320,  0.66944766,  -0.29068485, 0.82573754,  -0.57170743,
   2078        -0.41581789, 0.72739530,  0.85025692,  -0.85931867, 0.87357593,
   2079        0.07782833,  -0.84597743, -0.72748238, 0.45396307,  0.82449573,
   2080        -0.42462519, -0.86363792, -0.89368379, -0.12621804, -0.86445558,
   2081        -0.65565848, 0.88789743,  0.83566397,  0.78287679},
   2082       {input_data.get()},
   2083       // The error spec is unusually high here to account for the fact that we
   2084       // use a rational interpolant to approximate tanh.
   2085       ErrorSpec(0.004, 0.004));
   2086 }
   2087 
   2088 XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) {
   2089   // The input tensor is large enough to exercise the vectorized exp
   2090   // implementation on XLA CPU.
   2091   ComputationBuilder builder(client_, TestName());
   2092 
   2093   // Just to help make sense of the scales here -- exp(89) saturates float32 and
   2094   // exp(-10) is smaller than our error spec.
   2095   std::unique_ptr<Literal> input_literal = Literal::CreateR1<float>(
   2096       {1.02,   -0.32,  0.85,   0.9,    1.23,   -0.91,  -0.49, 0.8,    -1.31,
   2097        -1.44,  -0.13,  -1.31,  -0.79,  1.41,   1.21,   1.05,  -195.6, -194.5,
   2098        -193.4, -192.3, -191.2, -190.1, -189.0, -187.9, -19.6, -18.5,  -17.4,
   2099        -16.3,  -15.2,  -14.1,  -13.0,  -11.9,  -10.8,  -9.7,  -8.6,   -7.5,
   2100        -6.4,   -5.3,   -4.2,   -3.1,   -2.0,   -0.9,   0.2,   1.3,    2.4,
   2101        3.5,    4.6,    5.7,    6.8,    7.9,    9.0,    10.1,  11.2,   12.3,
   2102        13.4,   14.5,   15.6,   16.7,   17.8,   18.9,   20.0,  21.1,   22.2,
   2103        23.3,   24.4,   25.5,   26.6,   27.7,   28.8,   29.9,  31.0,   32.1,
   2104        68.4,   69.5,   70.6,   71.7,   72.8,   73.9,   75.0,  76.1,   77.2,
   2105        78.3,   79.4,   80.5,   81.6,   82.7,   83.8,   84.9,  85.2,   86.3,
   2106        86.4,   86.5,   87.6,   87.7,   87.8,   87.9});
   2107   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> input_data,
   2108                           client_->TransferToServer(*input_literal));
   2109 
   2110   auto input = builder.Parameter(0, input_literal->shape(), "input");
   2111   builder.Exp(input);
   2112 
   2113   std::vector<float> expected_result;
   2114   int64 input_size = input_literal->shape().dimensions(0);
   2115   expected_result.reserve(input_size);
   2116   for (int64 i = 0; i < input_size; i++) {
   2117     expected_result.push_back(std::exp(input_literal->Get<float>({i})));
   2118   }
   2119 
   2120   ComputeAndCompareR1<float>(&builder, expected_result, {input_data.get()},
   2121                              error_spec_);
   2122 }
   2123 
   2124 XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) {
   2125   // The input tensor is large enough to exercise the vectorized exp
   2126   // implementation on XLA CPU.
   2127   ComputationBuilder builder(client_, TestName());
   2128 
   2129   std::unique_ptr<Literal> input_literal = Literal::CreateR1<float>(
   2130       {-1.29,    -1.41,    -1.25,    -13.5,    -11.7,    -17.9,    -198,
   2131        -167,     1.29,     1.41,     1.25,     13.5,     11.7,     17.9,
   2132        198,      167,      1.27e+03, 1.33e+03, 1.74e+03, 1.6e+04,  1.84e+04,
   2133        1.74e+04, 1.89e+05, 1.9e+05,  1.93e+06, 1.98e+06, 1.65e+06, 1.97e+07,
   2134        1.66e+07, 1e+07,    1.98e+08, 1.96e+08, 1.64e+09, 1.58e+09, 1.64e+09,
   2135        1.44e+10, 1.5e+10,  1.99e+10, 1.17e+11, 1.08e+11, 1.08e+12, 1.38e+12,
   2136        1.4e+12,  1.03e+13, 1.6e+13,  1.99e+13, 1.26e+14, 1.51e+14, 1.33e+15,
   2137        1.41e+15, 1.63e+15, 1.39e+16, 1.21e+16, 1.27e+16, 1.28e+17, 1.62e+17,
   2138        2e+18,    1.96e+18, 1.81e+18, 1.99e+19, 1.86e+19, 1.61e+19, 1.71e+20,
   2139        1.47e+20, 1.83e+21, 1.33e+21, 1.3e+21,  1.35e+22, 1.84e+22, 1.02e+22,
   2140        1.81e+23, 1.02e+23, 1.89e+24, 1.49e+24, 1.08e+24, 1.95e+25, 1.1e+25,
   2141        1.62e+25, 1.2e+26,  1.41e+26, 1.93e+27, 1.66e+27, 1.62e+27, 1.05e+28,
   2142        1.5e+28,  1.79e+28, 1.36e+29, 1.95e+29, 1.5e+30,  1.81e+30, 1.34e+30,
   2143        1.7e+31,  1.44e+31, 1.1e+31,  1.4e+32,  1.67e+32, 1.96e+33, 1.11e+33,
   2144        1.19e+33, 1.61e+34, 1.05e+34, 1.88e+34, 1.67e+35, 1.7e+35});
   2145   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> input_data,
   2146                           client_->TransferToServer(*input_literal));
   2147 
   2148   auto input = builder.Parameter(0, input_literal->shape(), "input");
   2149   builder.Log(input);
   2150 
   2151   std::vector<float> expected_result;
   2152   int64 input_size = input_literal->shape().dimensions(0);
   2153   expected_result.reserve(input_size);
   2154   for (int64 i = 0; i < input_size; i++) {
   2155     expected_result.push_back(std::log(input_literal->Get<float>({i})));
   2156   }
   2157 
   2158   ComputeAndCompareR1<float>(&builder, expected_result, {input_data.get()},
   2159                              error_spec_);
   2160 }
   2161 
   2162 XLA_TEST_F(ArrayElementwiseOpTest, AddChainFoldLeft) {
   2163   // a ------ (add) --------- (add)
   2164   //         /               /
   2165   // b -----/               /
   2166   // c---------------------/
   2167   ComputationBuilder builder(client_, TestName());
   2168 
   2169   auto a = builder.ConstantR1<float>({1.1f, 2.2f, 3.3f, 4.4f});
   2170   auto b = builder.ConstantR1<float>({2.1f, 3.2f, 4.3f, 5.4f});
   2171   auto c = builder.ConstantR1<float>({-3.3f, -15.5f, -7.7f, -29.9f});
   2172 
   2173   auto add = builder.Add(a, b);
   2174   auto add2 = builder.Add(add, c);
   2175 
   2176   ComputeAndCompareR1<float>(&builder, {-0.1f, -10.1f, -0.1f, -20.1f}, {},
   2177                              error_spec_);
   2178 }
   2179 
   2180 XLA_TEST_F(ArrayElementwiseOpTest, AddChainFoldRight) {
   2181   // b ------ (add) --------- (add)
   2182   //         /               /
   2183   // c -----/               /
   2184   // a---------------------/
   2185   ComputationBuilder builder(client_, TestName());
   2186 
   2187   auto a = builder.ConstantR1<float>({91.1f, 2.2f, 3.3f, 4.4f});
   2188   auto b = builder.ConstantR1<float>({2.1f, 3.2f, 4.3f, 5.4f});
   2189   auto c = builder.ConstantR1<float>({-3.3f, -15.5f, -7.7f, -29.9f});
   2190 
   2191   auto add = builder.Add(b, c);
   2192   auto add2 = builder.Add(a, add);
   2193 
   2194   ComputeAndCompareR1<float>(&builder, {89.9f, -10.1f, -0.1f, -20.1f}, {},
   2195                              error_spec_);
   2196 }
   2197 
   2198 XLA_TEST_F(ArrayElementwiseOpTest, AddWithNeg) {
   2199   // a ----- (neg) ----- (add)
   2200   //                    /
   2201   // b ----- (neg) ----/
   2202   ComputationBuilder builder(client_, TestName());
   2203 
   2204   auto a = builder.ConstantR1<float>({91.1f, 2.2f, 3.3f, 4.4f});
   2205   auto b = builder.ConstantR1<float>({2.1f, 3.2f, 4.3f, 5.4f});
   2206 
   2207   auto neg_a = builder.Neg(a);
   2208   auto neg_b = builder.Neg(b);
   2209   auto result = builder.Add(neg_a, neg_b);
   2210 
   2211   ComputeAndCompareR1<float>(&builder, {-93.2f, -5.4f, -7.6f, -9.8f}, {},
   2212                              error_spec_);
   2213 }
   2214 
   2215 XLA_TEST_F(ArrayElementwiseOpTest, AddChainTwoSide) {
   2216   // a ------ (add) ------------\
   2217   //         /                   \
   2218   // b -----/                    (add)
   2219   //                             /
   2220   // c ------ (add) ------------/
   2221   //         /
   2222   // d -----/
   2223   ComputationBuilder builder(client_, TestName());
   2224 
   2225   auto a = builder.ConstantR1<float>({91.1f, 2.2f, 3.3f, 4.4f});
   2226   auto b = builder.ConstantR1<float>({2.1f, 3.2f, 4.3f, 5.4f});
   2227   auto c = builder.ConstantR1<float>({-3.3f, -15.5f, -7.7f, -29.9f});
   2228   auto d = builder.ConstantR1<float>({-19.0f, 10.0f, -40.0f, 20.2f});
   2229 
   2230   auto add_ab = builder.Add(a, b);
   2231   auto add_cd = builder.Add(c, d);
   2232   auto add_all = builder.Add(add_ab, add_cd);
   2233 
   2234   ComputeAndCompareR1<float>(&builder, {70.9f, -0.1f, -40.1f, 0.1f}, {},
   2235                              error_spec_);
   2236 }
   2237 
   2238 XLA_TEST_F(ArrayElementwiseOpTest, 2DBinaryOpF32s) {
   2239   ComputationBuilder builder(client_, TestName());
   2240   auto a =
   2241       builder.ConstantR2<float>({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
   2242   auto b =
   2243       builder.ConstantR2<float>({{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}});
   2244   auto add = builder.Add(a, b);
   2245 
   2246   Array2D<float> expected_array(
   2247       {{-4.0f, 11.28f, 43.0f}, {1.25f, -14.0f, 8.88f}});
   2248   ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
   2249 }
   2250 
   2251 XLA_TEST_F(ArrayElementwiseOpTest, ScalarPlus2DF32) {
   2252   // Add a scalar + matrix.
   2253   ComputationBuilder builder(client_, TestName());
   2254   auto a =
   2255       builder.ConstantR2<float>({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
   2256   auto scalar = builder.ConstantR0<float>(3.0f);
   2257   auto add = builder.Add(scalar, a);
   2258 
   2259   Array2D<float> expected_array({{0.5f, 6.14f, 4.0f}, {5.25f, -7.0f, 6.33f}});
   2260   ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
   2261 }
   2262 
   2263 XLA_TEST_F(ArrayElementwiseOpTest, 2DPlusScalarF32) {
   2264   // Add a matrix + scalar.
   2265   ComputationBuilder builder(client_, TestName());
   2266   auto a =
   2267       builder.ConstantR2<float>({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
   2268   auto scalar = builder.ConstantR0<float>(3.0f);
   2269   auto add = builder.Add(a, scalar);
   2270 
   2271   Array2D<float> expected_array({{0.5f, 6.14f, 4.0f}, {5.25f, -7.0f, 6.33f}});
   2272   ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
   2273 }
   2274 
   2275 XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32) {
   2276   // Test simple broadcasting of a R1F32 over R2F32. The vector's size matches
   2277   // only dim 0 of the matrix.
   2278   ComputationBuilder builder(client_, TestName());
   2279   auto v = builder.ConstantR1<float>({20.0f, 40.0f, 60.0f});
   2280   // clang-format off
   2281   auto m = builder.ConstantR2<float>({
   2282     {-2.5f, 3.14f, 1.0f},
   2283     {2.25f, -10.0f, 3.33f}});
   2284   // clang-format on
   2285   auto add = builder.Add(v, m, /*broadcast_dimensions=*/{1});
   2286   Array2D<float> expected_array(
   2287       {{17.5f, 43.14f, 61.0f}, {22.25f, 30.0f, 63.33f}});
   2288   ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
   2289 }
   2290 
   2291 XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Eq) {
   2292   // Test broadcasting in Eq comparison.
   2293   ComputationBuilder builder(client_, TestName());
   2294   auto v = builder.ConstantR1<int32>({42, 73});
   2295   auto m = builder.ConstantR2<int32>({{42, 73}, {42, 52}});
   2296 
   2297   // This test exercises both possible broadcast dimensions for a vector/matrix
   2298   // comparison.
   2299   auto cmp_dim_0 = builder.Eq(v, m, /*broadcast_dimensions=*/{1});
   2300   auto cmp_dim_1 = builder.Eq(v, m, /*broadcast_dimensions=*/{0});
   2301   auto result = builder.Tuple({cmp_dim_0, cmp_dim_1});
   2302 
   2303   auto expected = Literal::MakeTuple(
   2304       {Literal::CreateR2<bool>({{true, true}, {true, false}}).get(),
   2305        Literal::CreateR2<bool>({{true, false}, {false, false}}).get()});
   2306   ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
   2307 }
   2308 
   2309 XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ne) {
   2310   // Test broadcasting in Ne comparison.
   2311   ComputationBuilder builder(client_, TestName());
   2312   auto v = builder.ConstantR1<int32>({42, 73});
   2313   auto m = builder.ConstantR2<int32>({{42, 73}, {42, 52}});
   2314   auto cmp = builder.Ne(v, m, /*broadcast_dimensions=*/{1});
   2315 
   2316   const string expected = R"(pred[2,2] {
   2317   { 00 },
   2318   { 01 }
   2319 })";
   2320   EXPECT_EQ(expected, ExecuteToString(&builder, {}));
   2321 }
   2322 
   2323 XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ge) {
   2324   // Test broadcasting in Ge comparison.
   2325   ComputationBuilder builder(client_, TestName());
   2326   auto v = builder.ConstantR1<int32>({1, 2, 3, 4});
   2327   auto m = builder.ConstantR2<int32>({{1, 0, 5, 6}, {42, 52, 10, 4}});
   2328   auto cmp = builder.Ge(v, m, /*broadcast_dimensions=*/{1});
   2329 
   2330   const string expected = R"(pred[2,4] {
   2331   { 1100 },
   2332   { 0001 }
   2333 })";
   2334   EXPECT_EQ(expected, ExecuteToString(&builder, {}));
   2335 }
   2336 
   2337 XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Gt) {
   2338   // Test broadcasting in Gt comparison.
   2339   ComputationBuilder builder(client_, TestName());
   2340   auto v = builder.ConstantR1<int32>({1, 2, 3, 4});
   2341   auto m = builder.ConstantR2<int32>({{1, 0, 5, 6}, {42, 52, 10, 4}});
   2342   auto cmp = builder.Gt(v, m, /*broadcast_dimensions=*/{1});
   2343 
   2344   const string expected = R"(pred[2,4] {
   2345   { 0100 },
   2346   { 0000 }
   2347 })";
   2348   EXPECT_EQ(expected, ExecuteToString(&builder, {}));
   2349 }
   2350 
   2351 XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Le) {
   2352   // Test broadcasting in Le comparison.
   2353   ComputationBuilder builder(client_, TestName());
   2354   auto v = builder.ConstantR1<int32>({1, 2, 3, 4});
   2355   auto m = builder.ConstantR2<int32>({{1, 0, 5, 6}, {42, 52, 10, 4}});
   2356   auto cmp = builder.Le(v, m, /*broadcast_dimensions=*/{1});
   2357 
   2358   const string expected = R"(pred[2,4] {
   2359   { 1011 },
   2360   { 1111 }
   2361 })";
   2362   EXPECT_EQ(expected, ExecuteToString(&builder, {}));
   2363 }
   2364 
   2365 XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Lt) {
   2366   // Test broadcasting in Lt comparison.
   2367   ComputationBuilder builder(client_, TestName());
   2368   auto v = builder.ConstantR1<int32>({1, 2, 3, 4});
   2369   auto m = builder.ConstantR2<int32>({{1, 0, 5, 6}, {42, 52, 10, 4}});
   2370   auto cmp = builder.Lt(v, m, /*broadcast_dimensions=*/{1});
   2371 
   2372   const string expected = R"(pred[2,4] {
   2373   { 0011 },
   2374   { 1110 }
   2375 })";
   2376   EXPECT_EQ(expected, ExecuteToString(&builder, {}));
   2377 }
   2378 
   2379 XLA_TEST_F(ArrayElementwiseOpTest, Mul2Dby1DF32) {
   2380   // Test simple broadcasting of a R1F32 over R2F32 when the order of binary op
   2381   // arguments is reversed.
   2382   ComputationBuilder builder(client_, TestName());
   2383   auto m = builder.ConstantR2<float>({{1.5f, 2.5f, 3.5f}, {4.5f, 5.5f, 6.5f}});
   2384   auto v = builder.ConstantR1<float>({2.0f, 4.0f, 6.0f});
   2385   auto add = builder.Mul(m, v, /*broadcast_dimensions=*/{1});
   2386   Array2D<float> expected_array({{3.0f, 10.0f, 21.0f}, {9.0f, 22.0f, 39.0f}});
   2387   ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
   2388 }
   2389 
   2390 XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo2DWithDegenerateDim1) {
   2391   // Tests broadcasting for arrays with degenerate (size == 1) dimensions.
   2392   ComputationBuilder builder(client_, TestName());
   2393   // m's shape in XLA notation is {3, 2}
   2394   // md's shape in XLA notation is {3, 1}
   2395   // The result has shape {3, 2}, where md is broadcast over m
   2396   auto m =
   2397       builder.ConstantR2<float>({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
   2398   auto md = builder.ConstantR2<float>({{10.0f, 20.0f, 30.0f}});
   2399   auto add = builder.Add(m, md);
   2400   Array2D<float> expected_array(
   2401       {{7.5f, 23.14f, 31.0f}, {12.25f, 10.0f, 33.33f}});
   2402   ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
   2403 }
   2404 
   2405 XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo2DWithDegenerateDim0) {
   2406   // Tests broadcasting for arrays with degenerate (size == 1) dimensions.
   2407   ComputationBuilder builder(client_, TestName());
   2408   // m's shape in XLA notation is {3, 2}
   2409   // md's shape in XLA notation is {1, 2}
   2410   // The result has shape {3, 2}, where md is broadcast over m
   2411   auto m =
   2412       builder.ConstantR2<float>({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
   2413   auto md = builder.ConstantR2<float>({{10.0f}, {20.0f}});
   2414   auto add = builder.Add(m, md);
   2415   Array2D<float> expected_array(
   2416       {{7.5f, 13.14f, 11.0f}, {22.25f, 10.0f, 23.33f}});
   2417   ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
   2418 }
   2419 
   2420 XLA_TEST_F(ArrayElementwiseOpTest, Add2DsWithDegenerateDimsOuterProduct) {
   2421   // Tests broadcasting for two degenerate arrays. This kind of broadcasting
   2422   // effectively creates an "outer product" operation.
   2423   // This is taken from the Numpy docs example at:
   2424   // http://docs.scipy.org/doc/numpy-1.10.1/user/basics.broadcasting.html
   2425   ComputationBuilder builder(client_, TestName());
   2426   // a's shape in XLA notation is {1, 4}
   2427   // b's shape in XLA notation is {3, 1}
   2428   // The result has shape {3, 4}.
   2429   auto a = builder.ConstantR2<float>({{0.0f}, {10.0f}, {20.0f}, {30.0f}});
   2430   auto b = builder.ConstantR2<float>({{1.0f, 2.0f, 3.0f}});
   2431   auto add = builder.Add(a, b);
   2432   Array2D<float> expected_array({{1.0f, 2.0f, 3.0f},
   2433                                  {11.0f, 12.0f, 13.0f},
   2434                                  {21.0f, 22.0f, 23.0f},
   2435                                  {31.0f, 32.0f, 33.0f}});
   2436   ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
   2437 }
   2438 
   2439 XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver1) {
   2440   // Add together a (2,2) array and a (2) array, using dimension 0 for
   2441   // broadcasting (though there are two ways to broadcast these shapes).
   2442   ComputationBuilder builder(client_, TestName());
   2443   auto v = builder.ConstantR1<float>({20.0f, 40.0f});
   2444   auto m = builder.ConstantR2<float>({{10.0f, 50.0f}, {77.0f, 88.0f}});
   2445   auto add = builder.Add(v, m, /*broadcast_dimensions=*/{1});
   2446   Array2D<float> expected_array({{30.0f, 90.0f}, {97.0f, 128.0f}});
   2447   ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
   2448 }
   2449 
   2450 XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver0) {
   2451   // Add together a (2,2) array and a (2) array, using dimension 1 for
   2452   // broadcasting (though there are two ways to broadcast these shapes).
   2453   ComputationBuilder builder(client_, TestName());
   2454   auto v = builder.ConstantR1<float>({20.0f, 40.0f});
   2455   auto m = builder.ConstantR2<float>({{10.0f, 50.0f}, {77.0f, 88.0f}});
   2456   auto add = builder.Add(v, m, /*broadcast_dimensions=*/{0});
   2457   Array2D<float> expected_array({{30.0f, 70.0f}, {117.0f, 128.0f}});
   2458   ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
   2459 }
   2460 
   2461 XLA_TEST_F(ArrayElementwiseOpTest, 3DBinaryOpF32s) {
   2462   // Binary add of two R3s together
   2463   ComputationBuilder builder(client_, TestName());
   2464   Array3D<float> a_3d({{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
   2465                        {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}});
   2466   auto a = builder.ConstantR3FromArray3D<float>(a_3d);
   2467 
   2468   Array3D<float> b_3d({{{2.0f, 4.0f}, {6.0f, 8.0f}, {10.0f, 12.0f}},
   2469                        {{14.0f, 16.0f}, {18.0f, 20.0f}, {22.0f, 24.0f}}});
   2470   auto b = builder.ConstantR3FromArray3D<float>(b_3d);
   2471   auto add = builder.Add(a, b);
   2472 
   2473   Array3D<float> expected_3d(
   2474       {{{3.0f, 6.0f}, {9.0f, 12.0f}, {15.0f, 18.0f}},
   2475        {{21.0f, 24.0f}, {27.0f, 30.0f}, {33.0f, 36.0f}}});
   2476   ComputeAndCompareR3<float>(&builder, expected_3d, {}, error_spec_);
   2477 }
   2478 
   2479 XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver2) {
   2480   // Add together a (2, 3, 2) array with a (2) array, using dimension 0 for
   2481   // broadcasting (though there are two ways to broadcast these shapes).
   2482   ComputationBuilder builder(client_, TestName());
   2483   // clang-format off
   2484   Array3D<float> a_3d({
   2485     {{1.0f, 2.0f},
   2486      {3.0f, 4.0f},
   2487      {5.0f, 6.0f}},
   2488     {{7.0f, 8.0f},
   2489      {9.0f, 10.0f},
   2490      {11.0f, 12.0f}},
   2491   });
   2492   // clang-format on
   2493   auto a = builder.ConstantR3FromArray3D<float>(a_3d);
   2494   auto v = builder.ConstantR1<float>({10.0f, 20.0f});
   2495   auto add = builder.Add(a, v, /*broadcast_dimensions=*/{2});
   2496 
   2497   Array3D<float> expected_3d(
   2498       {{{11.0f, 22.0f}, {13.0f, 24.0f}, {15.0f, 26.0f}},
   2499        {{17.0f, 28.0f}, {19.0f, 30.0f}, {21.0f, 32.0f}}});
   2500   ComputeAndCompareR3<float>(&builder, expected_3d, {}, error_spec_);
   2501 }
   2502 
   2503 XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver0) {
   2504   // Add together a (2, 3, 2) array with a (2) array, using dimension 2 for
   2505   // broadcasting (though there are two ways to broadcast these shapes).
   2506   ComputationBuilder builder(client_, TestName());
   2507   // clang-format off
   2508   Array3D<float> a_3d({
   2509     {{1.0f, 2.0f},
   2510      {3.0f, 4.0f},
   2511      {5.0f, 6.0f}},
   2512     {{7.0f, 8.0f},
   2513      {9.0f, 10.0f},
   2514      {11.0f, 12.0f}},
   2515   });
   2516   // clang-format on
   2517   auto a = builder.ConstantR3FromArray3D<float>(a_3d);
   2518   auto v = builder.ConstantR1<float>({10.0f, 20.0f});
   2519   auto add = builder.Add(a, v, /*broadcast_dimensions=*/{0});
   2520 
   2521   // clang-format off
   2522   Array3D<float> expected_3d({
   2523     {{11.0f, 12.0f},
   2524      {13.0f, 14.0f},
   2525      {15.0f, 16.0f}},
   2526     {{27.0f, 28.0f},
   2527      {29.0f, 30.0f},
   2528      {31.0f, 32.0f}},
   2529   });
   2530   // clang-format on
   2531   ComputeAndCompareR3<float>(&builder, expected_3d, {}, error_spec_);
   2532 }
   2533 
   2534 XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo3D) {
   2535   // Add together a (2, 3, 2) array with a (3, 2) array, using dimensions {1,2}
   2536   // for broadcasting.
   2537   ComputationBuilder builder(client_, TestName());
   2538   // clang-format off
   2539   Array3D<float> a_3d({
   2540     {{1.0f, 2.0f},
   2541      {3.0f, 4.0f},
   2542      {5.0f, 6.0f}},
   2543     {{7.0f, 8.0f},
   2544      {9.0f, 10.0f},
   2545      {11.0f, 12.0f}},
   2546   });
   2547   auto a = builder.ConstantR3FromArray3D<float>(a_3d);
   2548   auto m = builder.ConstantR2<float>({
   2549     {10.0f, 20.0f, 30.0f},
   2550     {40.0f, 50.0f, 60.0f},
   2551   });
   2552   auto add = builder.Add(a, m, /*broadcast_dimensions=*/{0, 1});
   2553 
   2554   Array3D<float> expected_3d({
   2555     {{11.0f, 12.0f},
   2556      {23.0f, 24.0f},
   2557      {35.0f, 36.0f}},
   2558     {{47.0f, 48.0f},
   2559      {59.0f, 60.0f},
   2560      {71.0f, 72.0f}},
   2561   });
   2562   // clang-format on
   2563   ComputeAndCompareR3<float>(&builder, expected_3d, {}, error_spec_);
   2564 }
   2565 
   2566 XLA_TEST_F(ArrayElementwiseOpTest, CompareGtR3F32sWithDegenerateDim2) {
   2567   // Comparison between two 3D arrays of compatible shapes:
   2568   // (2, 3, 2) and (2, 3, 1): expected to produce a (2, 3, 2) shape of PREDs.
   2569   ComputationBuilder builder(client_, TestName());
   2570   Array3D<float> a_3d({{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
   2571                        {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}});
   2572   auto a = builder.ConstantR3FromArray3D<float>(a_3d);
   2573 
   2574   Array3D<float> b_3d({{{7.0f, 1.0f}, {3.0f, 10.0f}, {15.0f, 6.0f}}});
   2575   auto b = builder.ConstantR3FromArray3D<float>(b_3d);
   2576 
   2577   auto compare = builder.Gt(a, b);
   2578 
   2579   Array3D<int> expected_3d(
   2580       {{{0, 1}, {0, 0}, {0, 0}}, {{0, 1}, {1, 0}, {0, 1}}});
   2581   const string expected = R"(pred[2,3,2] {
   2582 { { 01 },
   2583   { 00 },
   2584   { 00 } },
   2585 { { 01 },
   2586   { 10 },
   2587   { 01 } }
   2588 })";
   2589   EXPECT_EQ(expected, ExecuteToString(&builder, {}));
   2590 }
   2591 
   2592 XLA_TEST_F(ArrayElementwiseOpTest, 4DBinaryOpF32s) {
   2593   ComputationBuilder builder(client_, TestName());
   2594 
   2595   std::unique_ptr<Array4D<float>> operand_a_4d(new Array4D<float>(2, 3, 4, 5));
   2596   std::unique_ptr<Array4D<float>> operand_b_4d(new Array4D<float>(2, 3, 4, 5));
   2597   std::unique_ptr<Array4D<float>> expected_4d(new Array4D<float>(2, 3, 4, 5));
   2598   float value = 0.0;
   2599   for (int64 p = 0; p < 2; ++p) {
   2600     for (int64 z = 0; z < 3; ++z) {
   2601       for (int64 y = 0; y < 4; ++y) {
   2602         for (int64 x = 0; x < 5; ++x) {
   2603           (*operand_a_4d)(p, z, y, x) = value;
   2604           (*operand_b_4d)(p, z, y, x) = 2.0 * value;
   2605           (*expected_4d)(p, z, y, x) = 3.0 * value;
   2606           value += 0.1;
   2607         }
   2608       }
   2609     }
   2610   }
   2611 
   2612   auto a = builder.ConstantR4FromArray4D<float>(*operand_a_4d);
   2613   auto b = builder.ConstantR4FromArray4D<float>(*operand_b_4d);
   2614   auto add = builder.Add(a, b);
   2615 
   2616   ComputeAndCompareR4<float>(&builder, *expected_4d, {}, error_spec_);
   2617 }
   2618 
   2619 XLA_TEST_F(ArrayElementwiseOpTest, R4PlusR1InDim1) {
   2620   ComputationBuilder builder(client_, TestName());
   2621 
   2622   std::unique_ptr<Array4D<float>> operand_a_4d(new Array4D<float>(2, 3, 4, 5));
   2623   std::unique_ptr<Array4D<float>> expected_4d(new Array4D<float>(2, 3, 4, 5));
   2624   std::vector<float> operand_b_1d(3);
   2625   std::iota(operand_b_1d.begin(), operand_b_1d.end(), 1.0);
   2626 
   2627   float value = 0.0;
   2628   for (int64 p = 0; p < 2; ++p) {
   2629     for (int64 z = 0; z < 3; ++z) {
   2630       for (int64 y = 0; y < 4; ++y) {
   2631         for (int64 x = 0; x < 5; ++x) {
   2632           (*operand_a_4d)(p, z, y, x) = value;
   2633           (*expected_4d)(p, z, y, x) = value + operand_b_1d[z];
   2634           value += 0.1;
   2635         }
   2636       }
   2637     }
   2638   }
   2639 
   2640   auto a = builder.ConstantR4FromArray4D<float>(*operand_a_4d);
   2641   auto b = builder.ConstantR1<float>(operand_b_1d);
   2642   auto add = builder.Add(a, b, {1});
   2643 
   2644   ComputeAndCompareR4<float>(&builder, *expected_4d, {}, error_spec_);
   2645 }
   2646 
   2647 XLA_TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) {
   2648   constexpr int d0 = 16;
   2649   constexpr int d1 = 16;
   2650   constexpr int d2 = 2;
   2651   constexpr int d3 = 2;
   2652   Array4D<float> r4(d0, d1, d2, d3);
   2653   r4.Fill(1.0);
   2654   std::vector<float> r1(d1);
   2655   std::iota(r1.begin(), r1.end(), 1.0);
   2656 
   2657   ComputationBuilder builder(client_, TestName());
   2658   std::unique_ptr<Literal> a_literal = Literal::CreateR4FromArray4DWithLayout(
   2659       r4, LayoutUtil::MakeLayout({0, 1, 2, 3}));
   2660   auto a = builder.ConstantLiteral(*a_literal);
   2661   auto b = builder.ConstantR1<float>(r1);
   2662   builder.Add(a, b, {1});
   2663 
   2664   for (int i0 = 0; i0 < d0; ++i0) {
   2665     for (int i1 = 0; i1 < d1; ++i1) {
   2666       for (int i2 = 0; i2 < d2; ++i2) {
   2667         for (int i3 = 0; i3 < d3; ++i3) {
   2668           r4(i0, i1, i2, i3) += r1[i1];
   2669         }
   2670       }
   2671     }
   2672   }
   2673   ComputeAndCompareR4<float>(&builder, r4, {}, error_spec_);
   2674 }
   2675 
   2676 // Show that we can't add two opaques.
   2677 XLA_TEST_F(ArrayElementwiseOpTest, CannotAddOpaques) {
   2678   ComputationBuilder builder(client_, TestName());
   2679   auto shape = ShapeUtil::MakeOpaqueShape();
   2680   auto x = builder.Parameter(0, shape, "x");
   2681   auto concatenated = builder.Add(x, x);
   2682   StatusOr<Computation> computation_status = builder.Build();
   2683   ASSERT_FALSE(computation_status.ok());
   2684   EXPECT_THAT(computation_status.status().ToString(),
   2685               ::testing::ContainsRegex(
   2686                   "Expected non-opaque argument for lhs of binary operation"));
   2687 }
   2688 
   2689 XLA_TEST_F(ArrayElementwiseOpTest, IdentityBroadcastOfSameRankIsAllowed) {
   2690   ComputationBuilder builder(client_, TestName());
   2691   auto a =
   2692       builder.ConstantR2<float>({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
   2693   auto b =
   2694       builder.ConstantR2<float>({{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}});
   2695   auto add = builder.Add(a, b, /*broadcast_dimensions=*/{0, 1});
   2696 
   2697   Array2D<float> expected_array(
   2698       {{-4.0f, 11.28f, 43.0f}, {1.25f, -14.0f, 8.88f}});
   2699   ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
   2700 }
   2701 
   2702 XLA_TEST_F(ArrayElementwiseOpTest, NonIdentityBroadcastOfSameRankIsDisallowed) {
   2703   ComputationBuilder builder(client_, TestName());
   2704   auto a =
   2705       builder.ConstantR2<float>({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
   2706   auto b =
   2707       builder.ConstantR2<float>({{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}});
   2708   auto add = builder.Add(a, b, /*broadcast_dimensions=*/{1, 0});
   2709 
   2710   StatusOr<Computation> computation_status = builder.Build();
   2711   ASSERT_FALSE(computation_status.ok());
   2712   EXPECT_THAT(computation_status.status().error_message(),
   2713               ::testing::ContainsRegex("must.*be the identity"));
   2714 }
   2715 
   2716 // Regression test for b/31927799. "slice - y" is fused and requires implicit
   2717 // broadcast.
   2718 XLA_TEST_F(ArrayElementwiseOpTest, ImplictBroadcastInFusedExpressions) {
   2719   ComputationBuilder builder(client_, TestName());
   2720   auto x_literal = Literal::CreateR1<float>({1, 2, 3});
   2721   auto y_literal = Literal::CreateR1<float>({4, 5});
   2722   auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie();
   2723   auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie();
   2724 
   2725   auto x = builder.Parameter(0, x_literal->shape(), "x");
   2726   auto y = builder.Parameter(1, y_literal->shape(), "y");
   2727   auto slice = builder.Slice(x, {1}, {2}, {1});
   2728   builder.Sub(slice, y);
   2729 
   2730   ComputeAndCompareR1<float>(&builder, {-2, -3}, {x_data.get(), y_data.get()},
   2731                              error_spec_);
   2732 }
   2733 
   2734 INSTANTIATE_TEST_CASE_P(ArrayElementwiseOpTestParamCount,
   2735                         ArrayElementwiseOpTestParamCount,
   2736                         ::testing::Values(127, 128, 129, 17 * 4096));
   2737 
   2738 }  // namespace
   2739 }  // namespace xla
   2740