Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
     16 
     17 #include <initializer_list>
     18 #include <memory>
     19 #include <string>
     20 #include <tuple>
     21 #include <utility>
     22 #include <vector>
     23 
     24 #include "absl/memory/memory.h"
     25 #include "absl/strings/str_format.h"
     26 #include "tensorflow/compiler/xla/client/xla_builder.h"
     27 #include "tensorflow/compiler/xla/literal.h"
     28 #include "tensorflow/compiler/xla/reference_util.h"
     29 #include "tensorflow/compiler/xla/service/hlo_computation.h"
     30 #include "tensorflow/compiler/xla/service/hlo_element_type_converter.h"
     31 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     32 #include "tensorflow/compiler/xla/shape_util.h"
     33 #include "tensorflow/compiler/xla/status.h"
     34 #include "tensorflow/compiler/xla/status_macros.h"
     35 #include "tensorflow/compiler/xla/statusor.h"
     36 #include "tensorflow/compiler/xla/test.h"
     37 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
     38 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
     39 #include "tensorflow/compiler/xla/tests/test_utils.h"
     40 #include "tensorflow/compiler/xla/types.h"
     41 #include "tensorflow/compiler/xla/util.h"
     42 #include "tensorflow/compiler/xla/xla_data.pb.h"
     43 #include "tensorflow/core/lib/core/status.h"
     44 #include "tensorflow/core/lib/core/status_test_util.h"
     45 #include "tensorflow/core/platform/test.h"
     46 #include "tensorflow/core/platform/test_benchmark.h"
     47 #include "tensorflow/core/platform/types.h"
     48 
     49 namespace xla {
     50 namespace {
     51 
     52 static std::array<bool, 2> use_bf16_params{true, false};
     53 
     54 // Test fixture for the HloEvaluator.
     55 //
     56 // In bf16 mode, all f32 shapes are converted to bf16 before running.
     57 class HloEvaluatorTest : public HloTestBase {
     58  public:
     59   HloEvaluatorTest() : use_bfloat16_(false) {}
     60 
     61   StatusOr<Literal> Evaluate(
     62       absl::Span<const Literal* const> arg_literals = {}) {
     63     if (use_bfloat16_) {
     64       HloElementTypeConverter(F32, BF16).Run(m_.get()).ValueOrDie();
     65     }
     66     return evaluator_.Evaluate(*m_->entry_computation(), arg_literals);
     67   }
     68 
     69   // Evaluate function that takes in a local module instead of using m_
     70   // that is in HloTestBase. Once m_ in HloTestBase is
     71   // removed, this should be the default Evaluate function.
     72   Literal EvaluateWithModule(
     73       HloModule* module, absl::Span<const Literal* const> arg_literals = {}) {
     74     if (use_bfloat16_) {
     75       HloElementTypeConverter(F32, BF16).Run(m_.get()).ValueOrDie();
     76     }
     77     return evaluator_.Evaluate(*module->entry_computation(), arg_literals)
     78         .ConsumeValueOrDie();
     79   }
     80 
     81   void TestUnaryOp(HloOpcode opcode, Literal expected, Literal input,
     82                    float aabs = 0) {
     83     HloComputation::Builder b(TestName());
     84     auto c1 =
     85         b.AddInstruction(HloInstruction::CreateConstant(std::move(input)));
     86     b.AddInstruction(HloInstruction::CreateUnary(expected.shape(), opcode, c1));
     87     m_->AddEntryComputation(b.Build());
     88 
     89     TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
     90 
     91     auto element_type = expected.shape().element_type();
     92     if (element_type == F32 || element_type == F64) {
     93       ErrorSpec error(aabs);
     94       EXPECT_TRUE(LiteralTestUtil::Near(expected, result, error));
     95     } else {
     96       EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
     97     }
     98   }
     99 
    100   void TestBinaryOp(HloOpcode opcode, Literal expected, Literal lhs,
    101                     Literal rhs) {
    102     HloComputation::Builder b(TestName());
    103     auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs)));
    104     auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs)));
    105     b.AddInstruction(
    106         HloInstruction::CreateBinary(expected.shape(), opcode, c1, c2));
    107     m_->AddEntryComputation(b.Build());
    108 
    109     TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
    110 
    111     EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
    112   }
    113 
    114   void TestTernaryOp(HloOpcode opcode, Literal expected, Literal src0,
    115                      Literal src1, Literal src2) {
    116     HloComputation::Builder b(TestName());
    117     auto operand0 =
    118         b.AddInstruction(HloInstruction::CreateConstant(std::move(src0)));
    119     auto operand1 =
    120         b.AddInstruction(HloInstruction::CreateConstant(std::move(src1)));
    121     auto operand2 =
    122         b.AddInstruction(HloInstruction::CreateConstant(std::move(src2)));
    123     b.AddInstruction(HloInstruction::CreateTernary(
    124         expected.shape(), opcode, operand0, operand1, operand2));
    125     m_->AddEntryComputation(b.Build());
    126 
    127     TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
    128 
    129     EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
    130   }
    131 
    132  protected:
    133   explicit HloEvaluatorTest(bool use_bfloat16) : use_bfloat16_(use_bfloat16) {}
    134   HloEvaluator evaluator_;
    135 
    136   const bool use_bfloat16_;
    137   std::unique_ptr<HloModule> m_ = CreateNewVerifiedModule();
    138 };
    139 
    140 // Lets you write TEST_Ps that run twice, once with and once without bf16.
    141 class HloEvaluatorBf16Test : public ::testing::WithParamInterface<bool>,
    142                              public HloEvaluatorTest {
    143  protected:
    144   HloEvaluatorBf16Test() : HloEvaluatorTest(/*use_bfloat16=*/GetParam()) {}
    145 };
    146 
    147 INSTANTIATE_TEST_SUITE_P(HloEvaluatorTest_Instantiation, HloEvaluatorBf16Test,
    148                          ::testing::ValuesIn(use_bf16_params));
    149 
    150 // Verifies that HloEvaluator evaluates a HLO instruction that performs clamp
    151 // with 3 operands.
    152 TEST_P(HloEvaluatorBf16Test, DoesClamp) {
    153   auto low = LiteralUtil::CreateR2<float>({{0.f, 2.f}, {2.f, 4.f}});
    154   auto value = LiteralUtil::CreateR2<float>({{0.f, 5.f}, {0.f, 4.f}});
    155   auto high = LiteralUtil::CreateR2<float>({{2.f, 4.f}, {4.f, 4.f}});
    156 
    157   Shape shape = low.shape();
    158   HloComputation::Builder b(TestName());
    159   auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low)));
    160   auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value)));
    161   auto c3 = b.AddInstruction(HloInstruction::CreateConstant(std::move(high)));
    162   b.AddInstruction(
    163       HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3));
    164   m_->AddEntryComputation(b.Build());
    165 
    166   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
    167 
    168   auto expected = LiteralUtil::CreateR2<float>({{0, 4}, {2, 4}});
    169 
    170   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
    171 }
    172 
    173 // Verifies that clamping of int64 does not cause loss of precision
    174 TEST_P(HloEvaluatorBf16Test, DoesClampInt64) {
    175   auto ones = [](int bits) { return (int64{1} << bits) - 1; };
    176 
    177   auto low =
    178       LiteralUtil::CreateR2<int64>({{0, ones(54)}, {ones(54), ones(58)}});
    179   auto value = LiteralUtil::CreateR2<int64>({{0, ones(56)}, {0, ones(58)}});
    180   auto high = LiteralUtil::CreateR2<int64>(
    181       {{ones(54), ones(55)}, {ones(56), ones(58)}});
    182 
    183   Shape shape = low.shape();
    184   HloComputation::Builder b(TestName());
    185   auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low)));
    186   auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value)));
    187   auto c3 = b.AddInstruction(HloInstruction::CreateConstant(std::move(high)));
    188   b.AddInstruction(
    189       HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3));
    190   m_->AddEntryComputation(b.Build());
    191 
    192   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
    193 
    194   auto expected =
    195       LiteralUtil::CreateR2<int64>({{0, ones(55)}, {ones(54), ones(58)}});
    196 
    197   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
    198 }
    199 
    200 TEST_P(HloEvaluatorBf16Test, DISABLED_DoesClampSpecialBroadcast) {
    201   auto low = LiteralUtil::CreateR0<float>(0.f);
    202   auto value = LiteralUtil::CreateR2<float>({{-1.f, 0.f}, {1.f, 2.f}});
    203   auto high = LiteralUtil::CreateR0<float>(1.f);
    204 
    205   Shape shape = value.shape();
    206   HloComputation::Builder b(TestName());
    207   auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low)));
    208   auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value)));
    209   auto c3 = b.AddInstruction(HloInstruction::CreateConstant(std::move(high)));
    210   b.AddInstruction(
    211       HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3));
    212   m_->AddEntryComputation(b.Build());
    213 
    214   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
    215 
    216   auto expected = LiteralUtil::CreateR2<float>({{0, 0}, {1, 1}});
    217 
    218   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
    219 }
    220 
    221 // Verifies that HloEvaluator evaluates a HLO instruction that performs select
    222 // with 3 operands.
    223 TEST_P(HloEvaluatorBf16Test, DoesSelect) {
    224   auto pred = LiteralUtil::CreateR2<bool>({{true, false}, {false, true}});
    225   auto on_true = LiteralUtil::CreateR2<float>({{2.f, 4.f}, {4.f, 4.f}});
    226   auto on_false = LiteralUtil::CreateR2<float>({{0.f, 5.f}, {0.f, 4.f}});
    227 
    228   Shape shape = on_true.shape();
    229   HloComputation::Builder b(TestName());
    230   auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(pred)));
    231   auto c2 =
    232       b.AddInstruction(HloInstruction::CreateConstant(std::move(on_true)));
    233   auto c3 =
    234       b.AddInstruction(HloInstruction::CreateConstant(std::move(on_false)));
    235   b.AddInstruction(
    236       HloInstruction::CreateTernary(shape, HloOpcode::kSelect, c1, c2, c3));
    237   m_->AddEntryComputation(b.Build());
    238 
    239   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({}));
    240 
    241   auto expected = LiteralUtil::CreateR2<float>({{2, 5}, {0, 4}});
    242 
    243   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
    244 }
    245 
    246 // Verifies that HloEvaluator evaluates a HLO instruction that performs
    247 // element-wise addition with 2 operands.
    248 TEST_F(HloEvaluatorTest, DoesAdd) {
    249   auto lhs = LiteralUtil::CreateR2<int64>({{1, 0}, {-100, 4}});
    250   auto rhs = LiteralUtil::CreateR2<int64>({{2, 4}, {4, 4}});
    251   auto expected = LiteralUtil::CreateR2<int64>({{3, 4}, {-96, 8}});
    252   TestBinaryOp(HloOpcode::kAdd, std::move(expected), std::move(lhs),
    253                std::move(rhs));
    254 }
    255 // Verifies that HloEvaluator evaluates a HLO instruction that performs
    256 // element-wise and with 2 operands.
    257 TEST_P(HloEvaluatorBf16Test, DoesAnd) {
    258   auto lhs = LiteralUtil::CreateR2<int64>({{1, 0}, {-100, 4}});
    259   auto rhs = LiteralUtil::CreateR2<int64>({{2, 4}, {4, 4}});
    260   auto expected = LiteralUtil::CreateR2<int64>({{0, 0}, {4, 4}});
    261   TestBinaryOp(HloOpcode::kAnd, std::move(expected), std::move(lhs),
    262                std::move(rhs));
    263 }
    264 // Verifies that HloEvaluator evaluates a HLO instruction that performs
    265 // element-wise or with 2 operands.
    266 TEST_F(HloEvaluatorTest, DoesOr) {
    267   auto lhs = LiteralUtil::CreateR2<int64>({{1, 0}, {-100, 4}});
    268   auto rhs = LiteralUtil::CreateR2<int64>({{2, 4}, {4, 4}});
    269   auto expected = LiteralUtil::CreateR2<int64>({{3, 4}, {-100, 4}});
    270   TestBinaryOp(HloOpcode::kOr, std::move(expected), std::move(lhs),
    271                std::move(rhs));
    272 }
    273 // Verifies that HloEvaluator evaluates a HLO instruction that performs
    274 // element-wise or with 2 operands.
    275 TEST_F(HloEvaluatorTest, DoesXor) {
    276   auto lhs = LiteralUtil::CreateR2<int64>({{1, 0}, {-100, 4}});
    277   auto rhs = LiteralUtil::CreateR2<int64>({{2, 4}, {4, 4}});
    278   auto expected = LiteralUtil::CreateR2<int64>({{3, 4}, {-104, 0}});
    279   TestBinaryOp(HloOpcode::kXor, std::move(expected), std::move(lhs),
    280                std::move(rhs));
    281 }
    282 // Verifies that HloEvaluator evaluates a HLO instruction that performs
    283 // element-wise multiply with 2 operands.
    284 TEST_F(HloEvaluatorTest, DoesMultiply) {
    285   auto lhs = LiteralUtil::CreateR2<int32>({{-1, 0}, {-100, 4}});
    286   auto rhs = LiteralUtil::CreateR2<int32>(
    287       {{std::numeric_limits<int32>::min(), 4}, {4, 4}});
    288   auto expected = LiteralUtil::CreateR2<int32>(
    289       {{std::numeric_limits<int32>::min(), 0}, {-400, 16}});
    290   TestBinaryOp(HloOpcode::kMultiply, std::move(expected), std::move(lhs),
    291                std::move(rhs));
    292 }
    293 // Verifies that HloEvaluator evaluates a HLO instruction that performs
    294 // element-wise divide with 2 operands.
    295 TEST_F(HloEvaluatorTest, DoesDivideInt64) {
    296   auto lhs = LiteralUtil::CreateR2<int64>({{1, 0}, {-100, 4}});
    297   auto rhs = LiteralUtil::CreateR2<int64>({{2, 4}, {4, 4}});
    298   auto expected = LiteralUtil::CreateR2<int64>({{0, 0}, {-25, 1}});
    299   TestBinaryOp(HloOpcode::kDivide, std::move(expected), std::move(lhs),
    300                std::move(rhs));
    301 }
    302 
    303 TEST_F(HloEvaluatorTest, DoesClampS64) {
    304   auto low = LiteralUtil::CreateR1<int64>(
    305       {-8616761059752331528LL, 6780561065411491190LL, -8616761059752331528LL});
    306   auto value = LiteralUtil::CreateR1<int64>(
    307       {-6780561065411491190LL, 6780561065411491180LL, 4241131823772864090LL});
    308   auto high = LiteralUtil::CreateR1<int64>(
    309       {-6780561065411491180LL, 8616761059752331528LL, 3832151243857508051LL});
    310   auto expected = LiteralUtil::CreateR1<int64>(
    311       {-6780561065411491190LL, 6780561065411491190LL, 3832151243857508051LL});
    312   TestTernaryOp(HloOpcode::kClamp, std::move(expected), std::move(low),
    313                 std::move(value), std::move(high));
    314 }
    315 
    316 TEST_P(HloEvaluatorBf16Test, DoesDivideDouble) {
    317   auto lhs = LiteralUtil::CreateR2<double>({{1.0, 0.0}, {-100.0, 4.0}});
    318   auto rhs = LiteralUtil::CreateR2<double>({{2.2, 4.0}, {4.0, 4.0}});
    319   auto expected =
    320       LiteralUtil::CreateR2<double>({{0.45454545454545453, 0}, {-25, 1}});
    321   TestBinaryOp(HloOpcode::kDivide, std::move(expected), std::move(lhs),
    322                std::move(rhs));
    323 }
    324 
    325 // Verifies that HloEvaluator evaluates a HLO instruction that performs
    326 // element-wise abs op with 1 operand.
    327 TEST_F(HloEvaluatorTest, DoesAbsR2) {
    328   auto operand = LiteralUtil::CreateR2<int64>({{1, -20}, {-100, 4}});
    329   auto expected = LiteralUtil::CreateR2<int64>({{1, 20}, {100, 4}});
    330   TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand));
    331 }
    332 TEST_P(HloEvaluatorBf16Test, DoesAbsR0) {
    333   auto operand = LiteralUtil::CreateR0<float>(-1.0f);
    334   auto expected = LiteralUtil::CreateR0<float>(1.0f);
    335   TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand));
    336 }
    337 TEST_P(HloEvaluatorBf16Test, DoesAbsR1WithZeroSize) {
    338   auto operand = LiteralUtil::CreateR1<float>({});
    339   auto expected = LiteralUtil::CreateR1<float>({});
    340   TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand));
    341 }
    342 TEST_F(HloEvaluatorTest, DoesNegateR2) {
    343   auto operand = LiteralUtil::CreateR2<int32>(
    344       {{0, std::numeric_limits<int32>::min()}, {-1, 4}});
    345   auto expected = LiteralUtil::CreateR2<int32>(
    346       {{0, std::numeric_limits<int>::min()}, {1, -4}});
    347   TestUnaryOp(HloOpcode::kNegate, std::move(expected), std::move(operand));
    348 }
    349 TEST_P(HloEvaluatorBf16Test, DoesCosR2) {
    350   auto operand = LiteralUtil::CreateR2<float>({{0, M_PI}, {-M_PI, 2 * M_PI}});
    351   auto expected = LiteralUtil::CreateR2<float>({{1, -1}, {-1, 1}});
    352   TestUnaryOp(HloOpcode::kCos, std::move(expected), std::move(operand),
    353               use_bfloat16_ ? 0.031250 : 9.5367431640625E-7);
    354 }
    355 TEST_P(HloEvaluatorBf16Test, DoesSinR2) {
    356   auto operand = LiteralUtil::CreateR2<float>({{0, M_PI}, {-M_PI, 2 * M_PI}});
    357   auto expected = LiteralUtil::CreateR2<float>({{0, 0}, {0, 0}});
    358   TestUnaryOp(HloOpcode::kSin, std::move(expected), std::move(operand),
    359               use_bfloat16_ ? 0.031250 : 9.5367431640625E-7);
    360 }
    361 TEST_F(HloEvaluatorTest, DoesNotR2) {
    362   auto operand =
    363       LiteralUtil::CreateR2<int32>({{0, std::numeric_limits<int>::min()},
    364                                     {-1, std::numeric_limits<int>::max()}});
    365   auto expected =
    366       LiteralUtil::CreateR2<int32>({{-1, std::numeric_limits<int>::max()},
    367                                     {0, std::numeric_limits<int>::min()}});
    368   TestUnaryOp(HloOpcode::kNot, std::move(expected), std::move(operand));
    369 }
    370 
    371 TEST_F(HloEvaluatorTest, DoesRealC128) {
    372   auto x = LiteralUtil::CreateR1<complex128>({{1, 0}, {-100, 4}});
    373   auto expected_real = LiteralUtil::CreateR1<double>({1, -100});
    374   TestUnaryOp(HloOpcode::kReal, std::move(expected_real), std::move(x));
    375 }
    376 
    377 TEST_F(HloEvaluatorTest, DoesImagC128) {
    378   auto x = LiteralUtil::CreateR1<complex128>({{1, 0}, {-100, 4}});
    379   auto expected_imag = LiteralUtil::CreateR1<double>({0, 4});
    380   TestUnaryOp(HloOpcode::kImag, std::move(expected_imag), std::move(x));
    381 }
    382 
    383 // Verifies that HloEvaluator evaluates a HLO Computation with non-parameter nor
    384 // constant operands.
    385 TEST_F(HloEvaluatorTest, DoesTraverseInstructions) {
    386   auto lhs = LiteralUtil::CreateR2<int64>({{1, 0}, {-100, 4}});
    387   auto rhs = LiteralUtil::CreateR2<int64>({{2, 4}, {4, 4}});
    388   auto rhs2 = LiteralUtil::CreateR2<int64>({{1, -20}, {-100, 4}});
    389   std::vector<const Literal*> args = {&lhs, &rhs, &rhs2};
    390 
    391   Shape shape = ShapeUtil::MakeShape(S64, {2, 2});
    392 
    393   HloComputation::Builder b(TestName());
    394   auto param_lhs =
    395       b.AddInstruction(HloInstruction::CreateParameter(0, shape, "lhs"));
    396   auto param_rhs =
    397       b.AddInstruction(HloInstruction::CreateParameter(1, shape, "rhs"));
    398   auto lhs_instruction = b.AddInstruction(HloInstruction::CreateBinary(
    399       shape, HloOpcode::kAdd, param_lhs, param_rhs));
    400 
    401   auto param_rhs2 =
    402       b.AddInstruction(HloInstruction::CreateParameter(2, shape, "rhs2"));
    403   b.AddInstruction(HloInstruction::CreateBinary(shape, HloOpcode::kAdd,
    404                                                 lhs_instruction, param_rhs2));
    405   m_->AddEntryComputation(b.Build());
    406 
    407   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate(args));
    408 
    409   auto expected = LiteralUtil::CreateR2<int64>({{4, -16}, {-196, 12}});
    410 
    411   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
    412 }
    413 
    414 // Verifies Reshape operation is correctly evaluated.
    415 TEST_F(HloEvaluatorTest, DoesReshape) {
    416   HloComputation::Builder b(TestName());
    417   const int64 dimensions[] = {11, 8, 7, 5, 9};
    418   TF_ASSERT_OK_AND_ASSIGN(auto literal,
    419                           LiteralUtil::CreateRandomLiteral<F32>(
    420                               ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
    421   auto literal_clone = literal.Clone();
    422   HloInstruction* literal_instruction =
    423       b.AddInstruction(HloInstruction::CreateConstant(std::move(literal)));
    424 
    425   Shape shape = ShapeUtil::MakeShape(F32, {8, 7, 11, 9, 5});
    426   const int64 permutation[] = {1, 2, 0, 4, 3};
    427   b.AddInstruction(
    428       HloInstruction::CreateTranspose(shape, literal_instruction, permutation));
    429   m_->AddEntryComputation(b.Build());
    430 
    431   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({}));
    432 
    433   using NativeT = typename primitive_util::PrimitiveTypeToNative<F32>::type;
    434   result.EachCell<NativeT>([&](absl::Span<const int64> indices, NativeT value) {
    435     std::vector<int64> rindexes = Permute(permutation, indices);
    436     EXPECT_NEAR(value, literal_clone.Get<NativeT>(rindexes), 0.031250);
    437   });
    438 }
    439 
    440 // Verifies Broadcast operation is correctly evaluated.
    441 TEST_F(HloEvaluatorTest, DoesBroadcast) {
    442   HloComputation::Builder b(TestName());
    443   auto input_literal = LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}});
    444   auto output_literal = LiteralUtil::CreateR3<int32>(
    445       {{{1, 2}, {3, 4}, {5, 6}}, {{1, 2}, {3, 4}, {5, 6}}});
    446   HloInstruction* literal_instruction = b.AddInstruction(
    447       HloInstruction::CreateConstant(std::move(input_literal)));
    448   b.AddInstruction(HloInstruction::CreateBroadcast(
    449       output_literal.shape(), literal_instruction, {1, 2}));
    450   m_->AddEntryComputation(b.Build());
    451 
    452   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({}));
    453 
    454   EXPECT_TRUE(LiteralTestUtil::Equal(result, output_literal));
    455 }
    456 
    457 TEST_F(HloEvaluatorTest, DoesBroadcastScalar) {
    458   HloComputation::Builder b(TestName());
    459   auto input_literal = LiteralUtil::CreateR0<int32>(111);
    460   auto output_literal = LiteralUtil::CreateR2<int32>(
    461       {{111, 111}, {111, 111}, {111, 111}, {111, 111}, {111, 111}, {111, 111}});
    462 
    463   HloInstruction* literal_instruction = b.AddInstruction(
    464       HloInstruction::CreateConstant(std::move(input_literal)));
    465   // Broadcast dimension should be empty in the case of scalars.
    466   b.AddInstruction(HloInstruction::CreateBroadcast(
    467       output_literal.shape(), literal_instruction,
    468       /*broadcast_dimensions=*/{}));
    469   m_->AddEntryComputation(b.Build());
    470 
    471   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({}));
    472 
    473   EXPECT_TRUE(LiteralTestUtil::Equal(result, output_literal));
    474 }
    475 
    476 TEST_F(HloEvaluatorTest, DoesConcatenateSimple) {
    477   HloComputation::Builder b(TestName());
    478 
    479   HloInstruction* operand1 = b.AddInstruction(HloInstruction::CreateConstant(
    480       LiteralUtil::CreateR2<int64>({{-1, -2}, {100, 200}})));
    481   HloInstruction* operand2 = b.AddInstruction(HloInstruction::CreateConstant(
    482       LiteralUtil::CreateR2<int64>({{-2, -3}, {-100, -200}})));
    483 
    484   std::vector<HloInstruction*> operands = {operand1, operand2};
    485 
    486   Shape shape = ShapeUtil::MakeShape(S64, {4, 2});
    487   b.AddInstruction(HloInstruction::CreateConcatenate(shape, operands, 0));
    488 
    489   m_->AddEntryComputation(b.Build());
    490 
    491   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
    492 
    493   auto expected = LiteralUtil::CreateR2<int64>(
    494       {{-1, -2}, {100, 200}, {-2, -3}, {-100, -200}});
    495   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
    496 }
    497 
    498 TEST_F(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) {
    499   HloComputation::Builder b(TestName());
    500 
    501   HloInstruction* operand1 = b.AddInstruction(
    502       HloInstruction::CreateConstant(LiteralUtil::CreateR1<int64>({100, 200})));
    503   HloInstruction* operand2 = b.AddInstruction(
    504       HloInstruction::CreateConstant(LiteralUtil::CreateR1<int64>({})));
    505 
    506   std::vector<HloInstruction*> operands = {operand1, operand2};
    507 
    508   Shape shape = ShapeUtil::MakeShape(S64, {2});
    509   b.AddInstruction(HloInstruction::CreateConcatenate(shape, operands, 0));
    510 
    511   m_->AddEntryComputation(b.Build());
    512 
    513   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
    514 
    515   auto expected = LiteralUtil::CreateR1<int64>({100, 200});
    516   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
    517 }
    518 
    519 TEST_P(HloEvaluatorBf16Test, ConvertWithSameLayout) {
    520   HloComputation::Builder b(TestName());
    521 
    522   auto input_literal = LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}});
    523   auto expected =
    524       LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}});
    525   ASSERT_TRUE(LayoutUtil::LayoutsInShapesEqual(input_literal.shape(),
    526                                                expected.shape()));
    527 
    528   HloInstruction* constant = b.AddInstruction(
    529       HloInstruction::CreateConstant(std::move(input_literal)));
    530   b.AddInstruction(HloInstruction::CreateConvert(expected.shape(), constant));
    531   m_->AddEntryComputation(b.Build());
    532 
    533   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
    534 
    535   EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
    536 }
    537 
    538 TEST_P(HloEvaluatorBf16Test, ConvertWithDifferentLayout) {
    539   HloComputation::Builder b(TestName());
    540 
    541   auto input_literal = LiteralUtil::CreateR2WithLayout<int32>(
    542       {{1, 2}, {3, 4}, {5, 6}}, LayoutUtil::MakeLayout({0, 1}));
    543   auto expected = LiteralUtil::CreateR2WithLayout<float>(
    544       {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}, LayoutUtil::MakeLayout({1, 0}));
    545   ASSERT_FALSE(LayoutUtil::LayoutsInShapesEqual(input_literal.shape(),
    546                                                 expected.shape()));
    547 
    548   HloInstruction* constant = b.AddInstruction(
    549       HloInstruction::CreateConstant(std::move(input_literal)));
    550   b.AddInstruction(HloInstruction::CreateConvert(expected.shape(), constant));
    551   m_->AddEntryComputation(b.Build());
    552 
    553   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
    554 
    555   EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
    556 }
    557 
    558 PaddingConfig CreatePaddingConfig(
    559     std::initializer_list<std::array<int64, 3>> padding_dimensions) {
    560   PaddingConfig padding_config;
    561 
    562   for (auto& paddings_per_dim : padding_dimensions) {
    563     auto dimension = padding_config.add_dimensions();
    564     dimension->set_edge_padding_low(paddings_per_dim[0]);
    565     dimension->set_edge_padding_high(paddings_per_dim[1]);
    566     dimension->set_interior_padding(paddings_per_dim[2]);
    567   }
    568   return padding_config;
    569 }
    570 
    571 TEST_F(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) {
    572   auto operand = LiteralUtil::CreateR2<int32>({{}, {}});
    573   HloComputation::Builder b(TestName());
    574   auto operand_instruction =
    575       b.AddInstruction(HloInstruction::CreateConstant(std::move(operand)));
    576 
    577   constexpr int32 kPadValue = 10;
    578   auto pad_value = LiteralUtil::CreateR0<int32>(kPadValue);
    579   auto padding_value_instruction =
    580       b.AddInstruction(HloInstruction::CreateConstant(std::move(pad_value)));
    581 
    582   auto padding_config = CreatePaddingConfig({{{1, 0, 2}}, {{0, 2, 1}}});
    583   Shape shape = ShapeUtil::MakeShape(S32, {5, 2});
    584   b.AddInstruction(HloInstruction::CreatePad(
    585       shape, operand_instruction, padding_value_instruction, padding_config));
    586   m_->AddEntryComputation(b.Build());
    587 
    588   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
    589 
    590   auto expected = LiteralUtil::CreateR2<int32>(
    591       {{10, 10}, {10, 10}, {10, 10}, {10, 10}, {10, 10}});
    592 
    593   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
    594 }
    595 
    596 TEST_P(HloEvaluatorBf16Test, Pad4DFloatArrayWithInteriorPadding) {
    597   HloComputation::Builder b(TestName());
    598 
    599   Array4D<float> input_array(3, 2, 1, 1, {1, 2, 3, 4, 5, 6});
    600   auto input = LiteralUtil::CreateR4FromArray4D<float>(input_array);
    601   HloInstruction* input_instruction =
    602       b.AddInstruction(HloInstruction::CreateConstant(std::move(input)));
    603   constexpr float kPadValue = 1.5;
    604   auto pad_value = LiteralUtil::CreateR0<float>(kPadValue);
    605   HloInstruction* pad_instruction =
    606       b.AddInstruction(HloInstruction::CreateConstant(std::move(pad_value)));
    607 
    608   Shape shape = ShapeUtil::MakeShape(F32, {8, 5, 1, 1});
    609   auto r4_padding_on_dim0_dim1 =
    610       CreatePaddingConfig({{{1, 0, 2}}, {{0, 2, 1}}, {{0, 0, 0}}, {{0, 0, 0}}});
    611   b.AddInstruction(HloInstruction::CreatePad(
    612       shape, input_instruction, pad_instruction, r4_padding_on_dim0_dim1));
    613   m_->AddEntryComputation(b.Build());
    614 
    615   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
    616 
    617   auto expected_array = absl::make_unique<Array4D<float>>(8, 5, 1, 1);
    618   expected_array->Fill(kPadValue);
    619   (*expected_array)(1, 0, 0, 0) = 1.0f;
    620   (*expected_array)(1, 2, 0, 0) = 2.0f;
    621   (*expected_array)(4, 0, 0, 0) = 3.0f;
    622   (*expected_array)(4, 2, 0, 0) = 4.0f;
    623   (*expected_array)(7, 0, 0, 0) = 5.0f;
    624   (*expected_array)(7, 2, 0, 0) = 6.0f;
    625 
    626   auto expected = LiteralUtil::CreateR4FromArray4D<float>(*expected_array);
    627 
    628   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
    629 }
    630 
    631 TEST_P(HloEvaluatorBf16Test, NegativePadding2D) {
    632   HloComputation::Builder b(TestName());
    633 
    634   // input_array:
    635   // f32[4,3] {
    636   //  { 1, 2, 3 },
    637   //  { 5, 6, 7 },
    638   //  { 9, 10, 11 },
    639   //  { 13, 14, 15 },
    640   // }
    641   auto input_array = absl::make_unique<Array2D<float>>(4, 3);
    642   input_array->FillUnique(1.0f);
    643   auto input = LiteralUtil::CreateR2FromArray2D<float>(*input_array);
    644   HloInstruction* input_instruction =
    645       b.AddInstruction(HloInstruction::CreateConstant(std::move(input)));
    646 
    647   auto pad_value_instruction = b.AddInstruction(
    648       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.718f)));
    649 
    650   auto r2_padding_on_dim0_dim1 =
    651       CreatePaddingConfig({{{-1, -2, 0}}, {{-2, 4, 0}}});
    652   Shape shape = ShapeUtil::MakeShape(F32, {1, 5});
    653   b.AddInstruction(HloInstruction::CreatePad(shape, input_instruction,
    654                                              pad_value_instruction,
    655                                              r2_padding_on_dim0_dim1));
    656 
    657   m_->AddEntryComputation(b.Build());
    658 
    659   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
    660 
    661   // f32[1,5] { 7.0, 2.718, 2.718, 2.718, 2.718 }
    662   auto expected_array = absl::make_unique<Array2D<float>>(1, 5);
    663   (*expected_array)(0, 0) = 7.0f;
    664   (*expected_array)(0, 1) = 2.718f;
    665   (*expected_array)(0, 2) = 2.718f;
    666   (*expected_array)(0, 3) = 2.718f;
    667   (*expected_array)(0, 4) = 2.718f;
    668   auto expected = LiteralUtil::CreateR2FromArray2D<float>(*expected_array);
    669 
    670   EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(0.031250)));
    671 }
    672 
    673 TEST_P(HloEvaluatorBf16Test, NegativeAndInteriorPadding2D) {
    674   HloComputation::Builder b(TestName());
    675 
    676   // f32[4,3] {
    677   //  { 1, 2, 3 },
    678   //  { 5, 6, 7 },
    679   //  { 9, 10, 11 },
    680   //  { 13, 14, 15 },
    681   // }
    682   auto input_array = absl::make_unique<Array2D<float>>(4, 3);
    683   input_array->FillUnique(1.0f);
    684   auto input = LiteralUtil::CreateR2FromArray2D<float>(*input_array);
    685   HloInstruction* input_instruction =
    686       b.AddInstruction(HloInstruction::CreateConstant(std::move(input)));
    687 
    688   auto pad_value_instruction = b.AddInstruction(
    689       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.718f)));
    690 
    691   PaddingConfig padding_config = MakeNoPaddingConfig(2);
    692 
    693   // Negative padding that results in zero dimensions.
    694   auto r2_padding_on_dim0_dim1 =
    695       CreatePaddingConfig({{{-2, -5, 1}}, {{-2, 4, 2}}});
    696 
    697   Shape shape = ShapeUtil::MakeShape(F32, {0, 9});
    698   b.AddInstruction(HloInstruction::CreatePad(shape, input_instruction,
    699                                              pad_value_instruction,
    700                                              r2_padding_on_dim0_dim1));
    701 
    702   m_->AddEntryComputation(b.Build());
    703 
    704   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
    705 
    706   auto expected_array = absl::make_unique<Array2D<float>>(0, 9);
    707   auto expected = LiteralUtil::CreateR2FromArray2D<float>(*expected_array);
    708 
    709   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
    710 }
    711 
    712 TEST_P(HloEvaluatorBf16Test, DotRank2AndRank1) {
    713   HloComputation::Builder b(TestName());
    714 
    715   // lhs:
    716   // f32[4,1] {
    717   //  { 1 },
    718   //  { 2 },
    719   //  { 3 },
    720   //  { 4 },
    721   // }
    722   auto lhs_array = absl::make_unique<Array2D<float>>(4, 1);
    723   lhs_array->FillUnique(1.0f);
    724   auto lhs_literal = LiteralUtil::CreateR2FromArray2D<float>(*lhs_array);
    725   HloInstruction* lhs_instruction =
    726       b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
    727 
    728   // rhs:
    729   // f32[2] { 1, 2 },
    730   auto rhs_literal = LiteralUtil::CreateR2<float>({{1, 2}});
    731   HloInstruction* rhs_instruction =
    732       b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
    733 
    734   Shape shape = ShapeUtil::MakeShape(F32, {4, 2});
    735   DotDimensionNumbers dot_dnums;
    736   dot_dnums.add_lhs_contracting_dimensions(1);
    737   dot_dnums.add_rhs_contracting_dimensions(0);
    738   b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction,
    739                                              rhs_instruction, dot_dnums,
    740                                              DefaultPrecisionConfig(2)));
    741   m_->AddEntryComputation(b.Build());
    742 
    743   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
    744 
    745   // clang-format off
    746   auto expected_array = Array2D<float>({
    747       {1.f, 2.f},
    748       {2.f, 4.f},
    749       {3.f, 6.f},
    750       {4.f, 8.f},
    751   });
    752   // clang-format on
    753   auto expected = LiteralUtil::CreateR2FromArray2D<float>(expected_array);
    754 
    755   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
    756 }
    757 
    758 TEST_P(HloEvaluatorBf16Test, DotRank1AndRank2) {
    759   HloComputation::Builder b(TestName());
    760 
    761   // lhs:
    762   // f32[3]
    763   //  { 1, 2, 3 },
    764   auto lhs_literal = LiteralUtil::CreateR1<float>({1, 2, 3});
    765   HloInstruction* lhs_instruction =
    766       b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
    767 
    768   // rhs:
    769   // f32[3,2] {
    770   //  { 1, 2 },
    771   //  { 3, 4 },
    772   //  { 5, 6 },
    773   // }
    774   auto rhs_array = absl::make_unique<Array2D<float>>(3, 2);
    775   rhs_array->FillUnique(1.0f);
    776   auto rhs_literal = LiteralUtil::CreateR2FromArray2D<float>(*rhs_array);
    777   HloInstruction* rhs_instruction =
    778       b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
    779 
    780   Shape shape = ShapeUtil::MakeShape(F32, {2});
    781   DotDimensionNumbers dot_dnums;
    782   dot_dnums.add_lhs_contracting_dimensions(0);
    783   dot_dnums.add_rhs_contracting_dimensions(0);
    784   b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction,
    785                                              rhs_instruction, dot_dnums,
    786                                              DefaultPrecisionConfig(2)));
    787   m_->AddEntryComputation(b.Build());
    788 
    789   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
    790 
    791   auto expected = LiteralUtil::CreateR1<float>({22.f, 28.f});
    792 
    793   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
    794 }
    795 
    796 TEST_P(HloEvaluatorBf16Test, DotRank2AndRank2) {
    797   HloComputation::Builder b(TestName());
    798 
    799   // lhs:
    800   // f32[4,3] {
    801   //  { 1, 2, 3 },
    802   //  { 5, 6, 7 },
    803   //  { 9, 10, 11 },
    804   //  { 13, 14, 15 },
    805   // }
    806   auto lhs_array = absl::make_unique<Array2D<float>>(4, 3);
    807   lhs_array->FillUnique(1.0f);
    808   auto lhs_literal = LiteralUtil::CreateR2FromArray2D<float>(*lhs_array);
    809   HloInstruction* lhs_instruction =
    810       b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
    811 
    812   // rhs:
    813   // f32[3,2] {
    814   //  { 1, 2 },
    815   //  { 3, 4 },
    816   //  { 5, 6 },
    817   // }
    818   auto rhs_array = absl::make_unique<Array2D<float>>(3, 2);
    819   rhs_array->FillUnique(1.0f);
    820   auto rhs_literal = LiteralUtil::CreateR2FromArray2D<float>(*rhs_array);
    821   HloInstruction* rhs_instruction =
    822       b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
    823 
    824   Shape shape = ShapeUtil::MakeShape(F32, {4, 2});
    825   DotDimensionNumbers dot_dnums;
    826   dot_dnums.add_lhs_contracting_dimensions(1);
    827   dot_dnums.add_rhs_contracting_dimensions(0);
    828   b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction,
    829                                              rhs_instruction, dot_dnums,
    830                                              DefaultPrecisionConfig(2)));
    831   m_->AddEntryComputation(b.Build());
    832 
    833   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
    834 
    835   auto expected_array = Array2D<float>({
    836       {22.f, 28.f},
    837       {58.f, 76.f},
    838       {94.f, 124.f},
    839       {130.f, 172.f},
    840   });
    841   auto expected = LiteralUtil::CreateR2FromArray2D<float>(expected_array);
    842 
    843   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
    844 }
    845 
    846 TEST_P(HloEvaluatorBf16Test, DotRank4AndRank4) {
    847   HloComputation::Builder b(TestName());
    848 
    849   auto lhs_array = absl::make_unique<Array4D<float>>(2, 2, 3, 1);
    850   lhs_array->FillIota(1.0f);
    851   auto lhs_literal = LiteralUtil::CreateR4FromArray4D<float>(*lhs_array);
    852   HloInstruction* lhs_instruction =
    853       b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
    854 
    855   auto rhs_array = absl::make_unique<Array4D<float>>(2, 2, 3, 1);
    856   rhs_array->FillIota(2.0f);
    857   auto rhs_literal = LiteralUtil::CreateR4FromArray4D<float>(*rhs_array);
    858   HloInstruction* rhs_instruction =
    859       b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
    860 
    861   Shape shape = ShapeUtil::MakeShape(F32, {2, 1, 1});
    862   DotDimensionNumbers dot_dnums;
    863 
    864   dot_dnums.add_lhs_batch_dimensions(0);
    865   dot_dnums.add_rhs_batch_dimensions(0);
    866   dot_dnums.add_lhs_contracting_dimensions(1);
    867   dot_dnums.add_lhs_contracting_dimensions(2);
    868   dot_dnums.add_rhs_contracting_dimensions(1);
    869   dot_dnums.add_rhs_contracting_dimensions(2);
    870   b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction,
    871                                              rhs_instruction, dot_dnums,
    872                                              DefaultPrecisionConfig(2)));
    873   m_->AddEntryComputation(b.Build());
    874 
    875   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
    876 
    877   float expected_1 = 0;
    878   for (float i = 1.0f; i < 7.0f; ++i) {
    879     expected_1 += i * i + i;
    880   }
    881   float expected_2 = 0;
    882   for (float i = 7.0f; i < 13.0f; ++i) {
    883     expected_2 += i * i + i;
    884   }
    885   auto expected_array = Array3D<float>({{{expected_1}}, {{expected_2}}});
    886   auto expected = LiteralUtil::CreateR3FromArray3D<float>(expected_array);
    887 
    888   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
    889 }
    890 
    891 TEST_P(HloEvaluatorBf16Test, SimpleConv1D) {
    892   HloComputation::Builder b(TestName());
    893 
    894   Array3D<float> lhs_array = {{{1, 2, 3}}};
    895   auto lhs_literal = LiteralUtil::CreateR3FromArray3D<float>(lhs_array);
    896   HloInstruction* lhs_instruction =
    897       b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
    898 
    899   Array3D<float> rhs_array = {{{3.f, 4.f}}};
    900   auto rhs_literal = LiteralUtil::CreateR3FromArray3D<float>(rhs_array);
    901   HloInstruction* rhs_instruction =
    902       b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
    903 
    904   Window window;
    905   WindowDimension dim;
    906   dim.set_size(2);
    907   dim.set_stride(1);
    908   dim.set_padding_low(0);
    909   dim.set_padding_high(1);
    910   dim.set_window_dilation(1);
    911   dim.set_base_dilation(1);
    912   *window.add_dimensions() = dim;
    913 
    914   ConvolutionDimensionNumbers dnums;
    915   dnums.set_input_batch_dimension(0);
    916   dnums.set_output_batch_dimension(0);
    917   dnums.set_input_feature_dimension(1);
    918   dnums.set_output_feature_dimension(1);
    919   dnums.add_input_spatial_dimensions(2);
    920   dnums.add_output_spatial_dimensions(2);
    921 
    922   dnums.set_kernel_output_feature_dimension(0);
    923   dnums.set_kernel_input_feature_dimension(1);
    924   dnums.add_kernel_spatial_dimensions(2);
    925 
    926   Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 3});
    927   b.AddInstruction(HloInstruction::CreateConvolve(
    928       shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
    929       /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
    930   m_->AddEntryComputation(b.Build());
    931 
    932   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
    933 
    934   Array3D<float> expected_array = {{{11.f, 18.f, 9.f}}};
    935   auto expected = LiteralUtil::CreateR3FromArray3D<float>(expected_array);
    936 
    937   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
    938 }
    939 
    940 TEST_P(HloEvaluatorBf16Test, Simple4x4Conv2DWith2x2Kernel) {
    941   HloComputation::Builder b(TestName());
    942 
    943   Array4D<float> lhs_array(1, 1, 4, 4);
    944   // clang-format off
    945   lhs_array.FillWithYX(Array2D<float>({
    946     {1,  2,  3,  4 },
    947     {5,  6,  7,  8 },
    948     {9,  10, 11, 12},
    949     {13, 14, 15, 16},
    950   }));
    951   // clang-format on
    952   auto lhs_literal = LiteralUtil::CreateR4FromArray4D<float>(lhs_array);
    953   HloInstruction* lhs_instruction =
    954       b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
    955 
    956   Array4D<float> rhs_array(1, 1, 2, 2);
    957   // clang-format off
    958   rhs_array.FillWithYX(Array2D<float>({
    959     {5, 6},
    960     {7, 8},
    961   }));
    962   // clang-format on
    963   auto rhs_literal = LiteralUtil::CreateR4FromArray4D<float>(rhs_array);
    964   HloInstruction* rhs_instruction =
    965       b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
    966 
    967   Window window;
    968   WindowDimension dim;
    969   dim.set_size(2);
    970   dim.set_stride(1);
    971   dim.set_padding_low(0);
    972   dim.set_padding_high(1);
    973   dim.set_window_dilation(1);
    974   dim.set_base_dilation(1);
    975   *window.add_dimensions() = dim;
    976   *window.add_dimensions() = dim;
    977 
    978   ConvolutionDimensionNumbers dnums =
    979       XlaBuilder::CreateDefaultConvDimensionNumbers(2);
    980 
    981   Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
    982   b.AddInstruction(HloInstruction::CreateConvolve(
    983       shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
    984       /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
    985   m_->AddEntryComputation(b.Build());
    986 
    987   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
    988 
    989   Array4D<float> expected_array(1, 1, 4, 4);
    990   // clang-format off
    991   expected_array.FillWithYX(Array2D<float>({
    992     {100, 126, 152,  76},
    993     {204, 230, 256, 124},
    994     {308, 334, 360, 172},
    995     {149, 160, 171,  80},
    996   }));
    997   // clang-format on
    998   auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
    999 
   1000   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
   1001 }
   1002 
   1003 TEST_P(HloEvaluatorBf16Test, Conv2DGeneralDimensionsReversed) {
   1004   HloComputation::Builder b(TestName());
   1005 
   1006   // clang-format off
   1007   // Input dimensions: [feature=2, height=3, batch=1, width=4]
   1008   Array4D<float> input({
   1009     {{{1, 2, 3, 4}},
   1010      {{5, 6, 7, 8}},
   1011      {{9, 10, 11, 12}}},
   1012     {{{13, 14, 15, 16}},
   1013      {{17, 18, 19, 20}},
   1014      {{21, 22, 23, 24}}}
   1015   });
   1016   // Weight dimensions:
   1017   // [kernel_output_feature=1, width=3, kernel_input_feature=2, height=3]
   1018   Array4D<float> weight({{
   1019     {{1, 7, 13},
   1020      {4, 10, 16}},
   1021     {{2, 8, 14},
   1022      {5, 11, 17}},
   1023     {{3, 9, 15},
   1024      {6, 12, 18}}
   1025   }});
   1026   // clang-format on
   1027 
   1028   auto lhs_literal = LiteralUtil::CreateR4FromArray4D<float>(input);
   1029   HloInstruction* lhs_instruction =
   1030       b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
   1031 
   1032   auto rhs_literal = LiteralUtil::CreateR4FromArray4D<float>(weight);
   1033   HloInstruction* rhs_instruction =
   1034       b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
   1035   rhs_instruction = b.AddInstruction(HloInstruction::CreateReverse(
   1036       rhs_instruction->shape(), rhs_instruction, {3, 1}));
   1037 
   1038   Window window;
   1039   WindowDimension dim;
   1040   dim.set_size(3);
   1041   dim.set_stride(1);
   1042   dim.set_padding_low(0);
   1043   dim.set_padding_high(0);
   1044   dim.set_window_dilation(1);
   1045   dim.set_base_dilation(1);
   1046   dim.set_window_reversal(true);
   1047   *window.add_dimensions() = dim;
   1048   *window.add_dimensions() = dim;
   1049 
   1050   ConvolutionDimensionNumbers dnums;
   1051   dnums.set_input_batch_dimension(2);
   1052   dnums.set_output_batch_dimension(2);
   1053   dnums.set_input_feature_dimension(0);
   1054   dnums.set_output_feature_dimension(0);
   1055   dnums.add_input_spatial_dimensions(1);
   1056   dnums.add_output_spatial_dimensions(1);
   1057   dnums.add_input_spatial_dimensions(3);
   1058   dnums.add_output_spatial_dimensions(3);
   1059 
   1060   dnums.set_kernel_output_feature_dimension(0);
   1061   dnums.set_kernel_input_feature_dimension(2);
   1062   dnums.add_kernel_spatial_dimensions(3);
   1063   dnums.add_kernel_spatial_dimensions(1);
   1064 
   1065   Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
   1066   b.AddInstruction(HloInstruction::CreateConvolve(
   1067       shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
   1068       /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
   1069   m_->AddEntryComputation(b.Build());
   1070 
   1071   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
   1072 
   1073   // clang-format off
   1074   // Result dimensions: [feature=1, height=1, batch=1, width=2]
   1075   Array4D<float> expected_array({{{{2514, 2685}}}});
   1076   Array4D<float> expected_array_bf16({{{{2512, 2688}}}});
   1077   // clang-format on
   1078   auto expected = LiteralUtil::CreateR4FromArray4D<float>(
   1079       use_bfloat16_ ? expected_array_bf16 : expected_array);
   1080 
   1081   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
   1082 }
   1083 
   1084 TEST_P(HloEvaluatorBf16Test, Conv2DGeneralDimensions) {
   1085   HloComputation::Builder b(TestName());
   1086 
   1087   // clang-format off
   1088   // Input dimensions: [feature=2, height=3, batch=1, width=4]
   1089   Array4D<float> input({
   1090     {{{1, 2, 3, 4}},
   1091      {{5, 6, 7, 8}},
   1092      {{9, 10, 11, 12}}},
   1093     {{{13, 14, 15, 16}},
   1094      {{17, 18, 19, 20}},
   1095      {{21, 22, 23, 24}}}
   1096   });
   1097   // Weight dimensions:
   1098   // [kernel_output_feature=1, width=3, kernel_input_feature=2, height=3]
   1099   Array4D<float> weight({{
   1100     {{1, 7, 13},
   1101      {4, 10, 16}},
   1102     {{2, 8, 14},
   1103      {5, 11, 17}},
   1104     {{3, 9, 15},
   1105      {6, 12, 18}}
   1106   }});
   1107   // clang-format on
   1108 
   1109   auto lhs_literal = LiteralUtil::CreateR4FromArray4D<float>(input);
   1110   HloInstruction* lhs_instruction =
   1111       b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
   1112 
   1113   auto rhs_literal = LiteralUtil::CreateR4FromArray4D<float>(weight);
   1114   HloInstruction* rhs_instruction =
   1115       b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
   1116 
   1117   Window window;
   1118   WindowDimension dim;
   1119   dim.set_size(3);
   1120   dim.set_stride(1);
   1121   dim.set_padding_low(0);
   1122   dim.set_padding_high(0);
   1123   dim.set_window_dilation(1);
   1124   dim.set_base_dilation(1);
   1125   *window.add_dimensions() = dim;
   1126   *window.add_dimensions() = dim;
   1127 
   1128   ConvolutionDimensionNumbers dnums;
   1129   dnums.set_input_batch_dimension(2);
   1130   dnums.set_output_batch_dimension(2);
   1131   dnums.set_input_feature_dimension(0);
   1132   dnums.set_output_feature_dimension(0);
   1133   dnums.add_input_spatial_dimensions(1);
   1134   dnums.add_output_spatial_dimensions(1);
   1135   dnums.add_input_spatial_dimensions(3);
   1136   dnums.add_output_spatial_dimensions(3);
   1137 
   1138   dnums.set_kernel_output_feature_dimension(0);
   1139   dnums.set_kernel_input_feature_dimension(2);
   1140   dnums.add_kernel_spatial_dimensions(3);
   1141   dnums.add_kernel_spatial_dimensions(1);
   1142 
   1143   Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
   1144   b.AddInstruction(HloInstruction::CreateConvolve(
   1145       shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
   1146       /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
   1147   m_->AddEntryComputation(b.Build());
   1148 
   1149   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
   1150 
   1151   // clang-format off
   1152   // Result dimensions: [feature=1, height=1, batch=1, width=2]
   1153   Array4D<float> expected_array({{{{2514, 2685}}}});
   1154   Array4D<float> expected_array_bf16({{{{2512, 2688}}}});
   1155   // clang-format on
   1156   auto expected = LiteralUtil::CreateR4FromArray4D<float>(
   1157       use_bfloat16_ ? expected_array_bf16 : expected_array);
   1158 
   1159   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
   1160 }
   1161 
   1162 TEST_P(HloEvaluatorBf16Test, DilatedBaseConv2DWithHighPadding) {
   1163   HloComputation::Builder b(TestName());
   1164 
   1165   Array4D<float> lhs_array(1, 1, 4, 4);
   1166   // clang-format off
   1167   lhs_array.FillWithYX(Array2D<float>({
   1168     {1,  2,  3,  4 },
   1169     {5,  6,  7,  8 },
   1170     {9,  10, 11, 12},
   1171     {13, 14, 15, 16},
   1172   }));
   1173   // clang-format on
   1174   auto lhs_literal = LiteralUtil::CreateR4FromArray4D<float>(lhs_array);
   1175   HloInstruction* lhs_instruction =
   1176       b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
   1177 
   1178   Array4D<float> rhs_array(1, 1, 2, 2);
   1179   // clang-format off
   1180   rhs_array.FillWithYX(Array2D<float>({
   1181     {5, 6},
   1182     {7, 8},
   1183   }));
   1184   // clang-format on
   1185   auto rhs_literal = LiteralUtil::CreateR4FromArray4D<float>(rhs_array);
   1186   HloInstruction* rhs_instruction =
   1187       b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
   1188 
   1189   Window window;
   1190   WindowDimension dim;
   1191   dim.set_size(2);
   1192   dim.set_stride(1);
   1193   dim.set_padding_low(0);
   1194   dim.set_padding_high(1);
   1195   dim.set_window_dilation(1);
   1196   dim.set_base_dilation(2);
   1197   *window.add_dimensions() = dim;
   1198   *window.add_dimensions() = dim;
   1199 
   1200   ConvolutionDimensionNumbers dnums =
   1201       XlaBuilder::CreateDefaultConvDimensionNumbers(2);
   1202 
   1203   Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 7, 7});
   1204   b.AddInstruction(HloInstruction::CreateConvolve(
   1205       shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
   1206       /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
   1207   m_->AddEntryComputation(b.Build());
   1208 
   1209   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
   1210 
   1211   Array4D<float> expected_array(1, 1, 7, 7);
   1212   expected_array.FillWithYX(Array2D<float>({
   1213       {5, 12, 10, 18, 15, 24, 20},
   1214       {35, 48, 42, 56, 49, 64, 56},
   1215       {25, 36, 30, 42, 35, 48, 40},
   1216       {63, 80, 70, 88, 77, 96, 84},
   1217       {45, 60, 50, 66, 55, 72, 60},
   1218       {91, 112, 98, 120, 105, 128, 112},
   1219       {65, 84, 70, 90, 75, 96, 80},
   1220   }));
   1221   auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
   1222 
   1223   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
   1224 }
   1225 
   1226 TEST_P(HloEvaluatorBf16Test, DilatedBaseConv2DWithLowAndHighPadding) {
   1227   HloComputation::Builder b(TestName());
   1228 
   1229   Array4D<float> lhs_array(1, 1, 4, 4);
   1230   // clang-format off
   1231   lhs_array.FillWithYX(Array2D<float>({
   1232     {1,  2,  3,  4 },
   1233     {5,  6,  7,  8 },
   1234     {9,  10, 11, 12},
   1235     {13, 14, 15, 16},
   1236   }));
   1237   // clang-format on
   1238   auto lhs_literal = LiteralUtil::CreateR4FromArray4D<float>(lhs_array);
   1239   HloInstruction* lhs_instruction =
   1240       b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
   1241 
   1242   Array4D<float> rhs_array(1, 1, 2, 2);
   1243   // clang-format off
   1244   rhs_array.FillWithYX(Array2D<float>({
   1245     {5, 6},
   1246     {7, 8},
   1247   }));
   1248   // clang-format on
   1249   auto rhs_literal = LiteralUtil::CreateR4FromArray4D<float>(rhs_array);
   1250   HloInstruction* rhs_instruction =
   1251       b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
   1252 
   1253   Window window;
   1254   WindowDimension dim;
   1255   dim.set_size(2);
   1256   dim.set_stride(1);
   1257   dim.set_padding_low(1);
   1258   dim.set_padding_high(1);
   1259   dim.set_window_dilation(1);
   1260   dim.set_base_dilation(2);
   1261   *window.add_dimensions() = dim;
   1262   *window.add_dimensions() = dim;
   1263 
   1264   ConvolutionDimensionNumbers dnums =
   1265       XlaBuilder::CreateDefaultConvDimensionNumbers(2);
   1266 
   1267   Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 8, 8});
   1268   b.AddInstruction(HloInstruction::CreateConvolve(
   1269       shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
   1270       /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
   1271   m_->AddEntryComputation(b.Build());
   1272 
   1273   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
   1274 
   1275   Array4D<float> expected_array(1, 1, 8, 8);
   1276   expected_array.FillWithYX(Array2D<float>({
   1277       {8, 7, 16, 14, 24, 21, 32, 28},
   1278       {6, 5, 12, 10, 18, 15, 24, 20},
   1279       {40, 35, 48, 42, 56, 49, 64, 56},
   1280       {30, 25, 36, 30, 42, 35, 48, 40},
   1281       {72, 63, 80, 70, 88, 77, 96, 84},
   1282       {54, 45, 60, 50, 66, 55, 72, 60},
   1283       {104, 91, 112, 98, 120, 105, 128, 112},
   1284       {78, 65, 84, 70, 90, 75, 96, 80},
   1285   }));
   1286   auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
   1287 
   1288   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
   1289 }
   1290 
   1291 TEST_P(HloEvaluatorBf16Test,
   1292        DilatedWindowAndBaseConv2DWithDifferentLowAndHighPaddingAndStrides) {
   1293   HloComputation::Builder b(TestName());
   1294 
   1295   Array4D<float> lhs_array(1, 1, 4, 4);
   1296   // clang-format off
   1297   lhs_array.FillWithYX(Array2D<float>({
   1298     {1,  2,  3,  4 },
   1299     {5,  6,  7,  8 },
   1300     {9,  10, 11, 12},
   1301     {13, 14, 15, 16},
   1302   }));
   1303   // clang-format on
   1304   auto lhs_literal = LiteralUtil::CreateR4FromArray4D<float>(lhs_array);
   1305   HloInstruction* lhs_instruction =
   1306       b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
   1307 
   1308   Array4D<float> rhs_array(1, 1, 2, 3);
   1309   // clang-format off
   1310   rhs_array.FillWithYX(Array2D<float>({
   1311     {5, 6, 7},
   1312     {8, 9, 10},
   1313   }));
   1314   // clang-format on
   1315   auto rhs_literal = LiteralUtil::CreateR4FromArray4D<float>(rhs_array);
   1316   HloInstruction* rhs_instruction =
   1317       b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
   1318 
   1319   Window window;
   1320   WindowDimension dim;
   1321   dim.set_size(2);
   1322   dim.set_stride(1);
   1323   dim.set_padding_low(2);
   1324   dim.set_padding_high(2);
   1325   dim.set_window_dilation(2);
   1326   dim.set_base_dilation(2);
   1327   *window.add_dimensions() = dim;
   1328   dim.set_size(3);
   1329   dim.set_stride(3);
   1330   dim.set_padding_low(2);
   1331   dim.set_padding_high(-1);
   1332   dim.set_window_dilation(1);
   1333   dim.set_base_dilation(3);
   1334   *window.add_dimensions() = dim;
   1335 
   1336   ConvolutionDimensionNumbers dnums =
   1337       XlaBuilder::CreateDefaultConvDimensionNumbers(2);
   1338 
   1339   Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 9, 3});
   1340   b.AddInstruction(HloInstruction::CreateConvolve(
   1341       shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
   1342       /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
   1343   m_->AddEntryComputation(b.Build());
   1344 
   1345   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
   1346 
   1347   Array4D<float> expected_array(1, 1, 9, 3);
   1348   expected_array.FillWithYX(Array2D<float>({
   1349       {10, 20, 30},
   1350       {0, 0, 0},
   1351       {57, 74, 91},
   1352       {0, 0, 0},
   1353       {125, 142, 159},
   1354       {0, 0, 0},
   1355       {193, 210, 227},
   1356       {0, 0, 0},
   1357       {91, 98, 105},
   1358   }));
   1359   auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
   1360 
   1361   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
   1362 }
   1363 
   1364 TEST_P(HloEvaluatorBf16Test, Conv2DGroupedConvolution) {
   1365   HloComputation::Builder b(TestName());
   1366   std::vector<int64> input_dims = {1, 2, 2, 4};
   1367   std::vector<int64> filter_dims = {2, 2, 2, 8};
   1368   Shape input_shape = ShapeUtil::MakeShapeWithType<float>(input_dims);
   1369   Shape filter_shape = ShapeUtil::MakeShapeWithType<float>(filter_dims);
   1370   // Tensorflow dimension numbers for 2D convolution.
   1371   ConvolutionDimensionNumbers dnums;
   1372   dnums.set_input_batch_dimension(0);
   1373   dnums.set_output_batch_dimension(0);
   1374   dnums.add_input_spatial_dimensions(1);
   1375   dnums.add_output_spatial_dimensions(1);
   1376   dnums.add_input_spatial_dimensions(2);
   1377   dnums.add_output_spatial_dimensions(2);
   1378   dnums.set_input_feature_dimension(3);
   1379   dnums.set_output_feature_dimension(3);
   1380   dnums.add_kernel_spatial_dimensions(0);
   1381   dnums.add_kernel_spatial_dimensions(1);
   1382   dnums.set_kernel_input_feature_dimension(2);
   1383   dnums.set_kernel_output_feature_dimension(3);
   1384 
   1385   Window window;
   1386   WindowDimension dim;
   1387   dim.set_size(2);
   1388   dim.set_stride(1);
   1389   dim.set_padding_low(0);
   1390   dim.set_padding_high(0);
   1391   dim.set_window_dilation(1);
   1392   dim.set_base_dilation(1);
   1393   *window.add_dimensions() = dim;
   1394   *window.add_dimensions() = dim;
   1395 
   1396   std::vector<float> input_elems(ShapeUtil::ElementsIn(input_shape));
   1397   std::iota(input_elems.begin(), input_elems.end(), -7);
   1398   auto input_r1 = LiteralUtil::CreateR1<float>(input_elems);
   1399   auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
   1400   HloInstruction* lhs_instruction =
   1401       b.AddInstruction(HloInstruction::CreateConstant(std::move(input_r4)));
   1402 
   1403   std::vector<float> filter_elems(ShapeUtil::ElementsIn(filter_shape));
   1404   std::iota(filter_elems.begin(), filter_elems.end(), -31);
   1405   auto filter_r1 = LiteralUtil::CreateR1<float>(filter_elems);
   1406   auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
   1407   HloInstruction* rhs_instruction =
   1408       b.AddInstruction(HloInstruction::CreateConstant(std::move(filter_r4)));
   1409 
   1410   Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 8});
   1411   b.AddInstruction(HloInstruction::CreateConvolve(
   1412       shape, lhs_instruction, rhs_instruction,
   1413       /*feature_group_count=*/2, /*batch_group_count=*/1, window, dnums,
   1414       DefaultPrecisionConfig(2)));
   1415   m_->AddEntryComputation(b.Build());
   1416 
   1417   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
   1418 
   1419   Array4D<float> expected_array(1, 1, 1, 8);
   1420   expected_array.FillWithYX(
   1421       Array2D<float>({{668, 664, 660, 656, 668, 680, 692, 704}}));
   1422   auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
   1423   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
   1424 }
   1425 
   1426 class HloEvaluatorPreciseReduceTest : public HloTestBase {};
   1427 
   1428 // Tests that Reduce doesn't lose precision when adding many numbers (because
   1429 // it accumulates its result in a double).
   1430 TEST_F(HloEvaluatorPreciseReduceTest, AddReductionPrecisionTest) {
   1431   auto m = CreateNewVerifiedModule();
   1432   HloComputation::Builder b(TestName());
   1433 
   1434   constexpr int kNumElements = 1 << 25;  // float += 1 saturates at 1<<24
   1435   std::vector<float> v(kNumElements, 1.0f);
   1436   HloInstruction* arg_instruction = b.AddInstruction(
   1437       HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(v)));
   1438   HloInstruction* init_value = b.AddInstruction(
   1439       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));
   1440 
   1441   HloComputation::Builder add_computation("add");
   1442   Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
   1443   auto param_lhs = add_computation.AddInstruction(
   1444       HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
   1445   auto param_rhs = add_computation.AddInstruction(
   1446       HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
   1447   add_computation.AddInstruction(HloInstruction::CreateBinary(
   1448       scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs));
   1449   auto add_func = m->AddEmbeddedComputation(add_computation.Build());
   1450 
   1451   HloInstruction* reduce_instruction = b.AddInstruction(
   1452       HloInstruction::CreateReduce(scalar_shape, arg_instruction, init_value,
   1453                                    /*dimensions_to_reduce=*/{0}, add_func));
   1454   m->AddEntryComputation(b.Build());
   1455 
   1456   HloEvaluator hlo_eval;
   1457   Literal result = hlo_eval.Evaluate(reduce_instruction).ConsumeValueOrDie();
   1458   LiteralTestUtil::ExpectR0Equal<float>(kNumElements, result);
   1459 }
   1460 
   1461 // Reducing many numbers should be fast because it doesn't create
   1462 // intermediate Literals; the microbenchmark should finish in < 1 msec.
   1463 void BM_ReducePrecisely(int num_iters) {
   1464   tensorflow::testing::StopTiming();
   1465   HloComputation::Builder b("BM_ReducePrecisely");
   1466   HloModuleConfig config;
   1467   config.set_debug_options(GetDebugOptionsFromFlags());
   1468   HloModule module("BM_ReducePrecisely", config);
   1469 
   1470   constexpr int kNumElements = 1 << 25;  // float += 1 saturates at 1<<24
   1471   std::vector<float> v(kNumElements, 1.0f);
   1472   HloInstruction* arg_instruction = b.AddInstruction(
   1473       HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(v)));
   1474   auto init_value = b.AddInstruction(
   1475       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));
   1476 
   1477   HloComputation::Builder add_computation("add");
   1478   Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
   1479   auto param_lhs = add_computation.AddInstruction(
   1480       HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
   1481   auto param_rhs = add_computation.AddInstruction(
   1482       HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
   1483   add_computation.AddInstruction(HloInstruction::CreateBinary(
   1484       scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs));
   1485   auto add_func = module.AddEmbeddedComputation(add_computation.Build());
   1486 
   1487   HloInstruction* reduce_instruction = b.AddInstruction(
   1488       HloInstruction::CreateReduce(scalar_shape, arg_instruction, init_value,
   1489                                    /*dimensions_to_reduce=*/{0}, add_func));
   1490   module.AddEntryComputation(b.Build());
   1491 
   1492   HloEvaluator hlo_eval;
   1493   tensorflow::testing::StartTiming();
   1494   hlo_eval.Evaluate(reduce_instruction).ConsumeValueOrDie();
   1495   tensorflow::testing::StopTiming();
   1496 }
   1497 
   1498 BENCHMARK(BM_ReducePrecisely);
   1499 
   1500 TEST_P(HloEvaluatorBf16Test, ReduceAdd) {
   1501   HloComputation::Builder b(TestName());
   1502 
   1503   // arg:
   1504   // f32[2,3] {
   1505   //  { 1, 2, 3 },
   1506   //  { 5, 6, 7 },
   1507   // }
   1508   auto arg_array = absl::make_unique<Array2D<float>>(2, 3);
   1509   arg_array->FillUnique(1.0f);
   1510   auto arg_literal = LiteralUtil::CreateR2FromArray2D<float>(*arg_array);
   1511 
   1512   HloInstruction* arg_instruction =
   1513       b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal)));
   1514 
   1515   auto init_value = b.AddInstruction(
   1516       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));
   1517 
   1518   HloComputation::Builder add_computation("add");
   1519   Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
   1520   auto param_lhs = add_computation.AddInstruction(
   1521       HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
   1522   auto param_rhs = add_computation.AddInstruction(
   1523       HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
   1524   add_computation.AddInstruction(HloInstruction::CreateBinary(
   1525       scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs));
   1526   auto add_func = m_->AddEmbeddedComputation(add_computation.Build());
   1527 
   1528   Shape shape = ShapeUtil::MakeShape(F32, {2});
   1529   b.AddInstruction(
   1530       HloInstruction::CreateReduce(shape, arg_instruction, init_value,
   1531                                    /*dimensions_to_reduce=*/{1}, add_func));
   1532 
   1533   m_->AddEntryComputation(b.Build());
   1534 
   1535   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
   1536 
   1537   auto expected = LiteralUtil::CreateR1<float>({6, 18});
   1538 
   1539   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
   1540 }
   1541 
   1542 TEST_P(HloEvaluatorBf16Test, ReduceWindowMax) {
   1543   HloComputation::Builder b(TestName());
   1544 
   1545   // arg:
   1546   // f32[2,3] {
   1547   //  { 1, 2, 3 },
   1548   //  { 5, 6, 7 },
   1549   // }
   1550   auto arg_array = absl::make_unique<Array2D<float>>(2, 3);
   1551   arg_array->FillUnique(1.0f);
   1552   auto arg_literal = LiteralUtil::CreateR2FromArray2D<float>(*arg_array);
   1553 
   1554   HloInstruction* arg_instruction =
   1555       b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal)));
   1556 
   1557   auto init_value = b.AddInstruction(
   1558       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));
   1559 
   1560   HloComputation::Builder max_computation("max");
   1561   Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
   1562   auto param_lhs = max_computation.AddInstruction(
   1563       HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
   1564   auto param_rhs = max_computation.AddInstruction(
   1565       HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
   1566   max_computation.AddInstruction(HloInstruction::CreateBinary(
   1567       scalar_shape, HloOpcode::kMaximum, param_lhs, param_rhs));
   1568   auto max_func = m_->AddEmbeddedComputation(max_computation.Build());
   1569 
   1570   Window window;
   1571   WindowDimension dim;
   1572   dim.set_size(2);
   1573   dim.set_stride(1);
   1574   dim.set_padding_low(0);
   1575   dim.set_padding_high(0);
   1576   dim.set_window_dilation(1);
   1577   dim.set_base_dilation(1);
   1578   *window.add_dimensions() = dim;
   1579   *window.add_dimensions() = dim;
   1580 
   1581   Shape shape = ShapeUtil::MakeShape(F32, {1, 2});
   1582   b.AddInstruction(HloInstruction::CreateReduceWindow(
   1583       shape, arg_instruction, init_value, window, max_func));
   1584 
   1585   m_->AddEntryComputation(b.Build());
   1586 
   1587   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
   1588 
   1589   auto expected = LiteralUtil::CreateR2<float>({{6, 7}});
   1590   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
   1591 }
   1592 
   1593 TEST_P(HloEvaluatorBf16Test, ReduceWindowMaxWindowDilation) {
   1594   HloComputation::Builder b(TestName());
   1595 
   1596   // arg:
   1597   // f32[3,3] {
   1598   //  { 1, 2, 3 },
   1599   //  { 5, 6, 7 },
   1600   //  { 9, 10, 11 },
   1601   // }
   1602   auto arg_array = absl::make_unique<Array2D<float>>(3, 3);
   1603   arg_array->FillUnique(1.0f);
   1604   auto arg_literal = LiteralUtil::CreateR2FromArray2D<float>(*arg_array);
   1605 
   1606   HloInstruction* arg_instruction =
   1607       b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal)));
   1608 
   1609   auto init_value = b.AddInstruction(
   1610       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));
   1611 
   1612   HloComputation::Builder max_computation("max");
   1613   Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
   1614   auto param_lhs = max_computation.AddInstruction(
   1615       HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
   1616   auto param_rhs = max_computation.AddInstruction(
   1617       HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
   1618   max_computation.AddInstruction(HloInstruction::CreateBinary(
   1619       scalar_shape, HloOpcode::kMaximum, param_lhs, param_rhs));
   1620   auto max_func = m_->AddEmbeddedComputation(max_computation.Build());
   1621 
   1622   Window window;
   1623   WindowDimension dim;
   1624   dim.set_size(2);
   1625   dim.set_stride(1);
   1626   dim.set_padding_low(0);
   1627   dim.set_padding_high(0);
   1628   dim.set_window_dilation(2);
   1629   dim.set_base_dilation(1);
   1630   *window.add_dimensions() = dim;
   1631   *window.add_dimensions() = dim;
   1632 
   1633   Shape shape = ShapeUtil::MakeShape(F32, {1, 1});
   1634   b.AddInstruction(HloInstruction::CreateReduceWindow(
   1635       shape, arg_instruction, init_value, window, max_func));
   1636 
   1637   m_->AddEntryComputation(b.Build());
   1638 
   1639   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
   1640 
   1641   auto expected = LiteralUtil::CreateR2<float>({{11}});
   1642   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
   1643 }
   1644 
   1645 TEST_P(HloEvaluatorBf16Test, ReduceWindowAdd) {
   1646   HloComputation::Builder b(TestName());
   1647 
   1648   // arg:
   1649   // f32[2,3] {
   1650   //  { 1, 2, 3 },
   1651   //  { 5, 6, 7 },
   1652   // }
   1653   auto arg_array = absl::make_unique<Array2D<float>>(2, 3);
   1654   arg_array->FillUnique(1.0f);
   1655   auto arg_literal = LiteralUtil::CreateR2FromArray2D<float>(*arg_array);
   1656 
   1657   HloInstruction* arg_instruction =
   1658       b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal)));
   1659 
   1660   auto init_value = b.AddInstruction(
   1661       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));
   1662 
   1663   HloComputation::Builder add_computation("add");
   1664   Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
   1665   auto param_lhs = add_computation.AddInstruction(
   1666       HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
   1667   auto param_rhs = add_computation.AddInstruction(
   1668       HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
   1669   add_computation.AddInstruction(HloInstruction::CreateBinary(
   1670       scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs));
   1671   auto add_func = m_->AddEmbeddedComputation(add_computation.Build());
   1672 
   1673   Window window;
   1674   WindowDimension dim;
   1675   dim.set_size(1);
   1676   dim.set_stride(1);
   1677   dim.set_padding_low(0);
   1678   dim.set_padding_high(0);
   1679   dim.set_window_dilation(1);
   1680   dim.set_base_dilation(1);
   1681   *window.add_dimensions() = dim;
   1682   dim.set_size(2);
   1683   dim.set_stride(1);
   1684   dim.set_padding_low(1);
   1685   dim.set_padding_high(0);
   1686   dim.set_window_dilation(1);
   1687   dim.set_base_dilation(1);
   1688   *window.add_dimensions() = dim;
   1689 
   1690   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
   1691   b.AddInstruction(HloInstruction::CreateReduceWindow(
   1692       shape, arg_instruction, init_value, window, add_func));
   1693 
   1694   m_->AddEntryComputation(b.Build());
   1695 
   1696   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
   1697 
   1698   auto expected = LiteralUtil::CreateR2<float>({{1, 3, 5}, {5, 11, 13}});
   1699   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
   1700 }
   1701 
   1702 TEST_P(HloEvaluatorBf16Test, ReduceWindowAdd6D) {
   1703   HloComputation::Builder b(TestName());
   1704 
   1705   // arg: f32[4,4,4,4,4,4] full of ones. Using small dims to limit run-time.
   1706   std::vector<int64> input_dims(6, 4);
   1707   Literal arg_literal =
   1708       LiteralUtil::CreateFullWithDescendingLayout<float>(input_dims, 1.0f);
   1709 
   1710   HloInstruction* arg_instruction =
   1711       b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal)));
   1712 
   1713   auto init_value = b.AddInstruction(
   1714       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));
   1715 
   1716   HloComputation::Builder add_computation("add");
   1717   Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
   1718   auto param_lhs = add_computation.AddInstruction(
   1719       HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
   1720   auto param_rhs = add_computation.AddInstruction(
   1721       HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
   1722   add_computation.AddInstruction(HloInstruction::CreateBinary(
   1723       scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs));
   1724   auto add_func = m_->AddEmbeddedComputation(add_computation.Build());
   1725 
   1726   Window window;
   1727 
   1728   WindowDimension trivial_dim;
   1729   trivial_dim.set_size(1);
   1730   trivial_dim.set_stride(1);
   1731   trivial_dim.set_padding_low(0);
   1732   trivial_dim.set_padding_high(0);
   1733   trivial_dim.set_window_dilation(1);
   1734   trivial_dim.set_base_dilation(1);
   1735 
   1736   WindowDimension active_dim;
   1737   active_dim.set_size(2);
   1738   active_dim.set_stride(1);
   1739   active_dim.set_padding_low(0);
   1740   active_dim.set_padding_high(0);
   1741   active_dim.set_window_dilation(1);
   1742   active_dim.set_base_dilation(1);
   1743 
   1744   *window.add_dimensions() = trivial_dim;
   1745   *window.add_dimensions() = active_dim;
   1746   *window.add_dimensions() = active_dim;
   1747   *window.add_dimensions() = active_dim;
   1748   *window.add_dimensions() = trivial_dim;
   1749   *window.add_dimensions() = trivial_dim;
   1750 
   1751   Shape shape = ShapeUtil::MakeShape(F32, {4, 3, 3, 3, 4, 4});
   1752   b.AddInstruction(HloInstruction::CreateReduceWindow(
   1753       shape, arg_instruction, init_value, window, add_func));
   1754 
   1755   m_->AddEntryComputation(b.Build());
   1756 
   1757   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
   1758 
   1759   std::vector<int64> output_dims = {4, 3, 3, 3, 4, 4};
   1760   Literal result_literal =
   1761       LiteralUtil::CreateFullWithDescendingLayout<float>(output_dims, 8.0f);
   1762   EXPECT_TRUE(LiteralTestUtil::Equal(result_literal, result));
   1763 }
   1764 
   1765 TEST_P(HloEvaluatorBf16Test, StridedSlice) {
   1766   HloComputation::Builder b(TestName());
   1767 
   1768   // arg:
   1769   // f32[3,5] {
   1770   //  { 1, 2, 3, 4, 5 },
   1771   //  { 9, 10, 11, 12, 13 },
   1772   //  { 17, 18, 19, 20, 21 },
   1773   // }
   1774   auto operand_array = absl::make_unique<Array2D<float>>(3, 5);
   1775   operand_array->FillUnique(1.0f);
   1776   auto operand_literal =
   1777       LiteralUtil::CreateR2FromArray2D<float>(*operand_array);
   1778 
   1779   HloInstruction* operand = b.AddInstruction(
   1780       HloInstruction::CreateConstant(std::move(operand_literal)));
   1781 
   1782   Shape shape = ShapeUtil::MakeShape(F32, {2, 1});
   1783   b.AddInstruction(HloInstruction::CreateSlice(shape, operand,
   1784                                                /*start_indices=*/{0, 2},
   1785                                                /*limit_indices=*/{3, 5},
   1786                                                /*strides=*/{2, 3}));
   1787   m_->AddEntryComputation(b.Build());
   1788 
   1789   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
   1790 
   1791   auto expected = LiteralUtil::CreateR2<float>({
   1792       {3},
   1793       {19},
   1794   });
   1795 
   1796   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
   1797 }
   1798 
   1799 TEST_P(HloEvaluatorBf16Test, DynamicSlice) {
   1800   HloComputation::Builder b(TestName());
   1801 
   1802   // arg:
   1803   // f32[2,4] {
   1804   //  { 1, 2, 3, 4 },
   1805   //  { 5, 6, 7, 8 },
   1806   // }
   1807   auto operand_array = absl::make_unique<Array2D<float>>(2, 4);
   1808   operand_array->FillUnique(1.0f);
   1809   auto operand_literal =
   1810       LiteralUtil::CreateR2FromArray2D<float>(*operand_array);
   1811 
   1812   HloInstruction* operand = b.AddInstruction(
   1813       HloInstruction::CreateConstant(std::move(operand_literal)));
   1814 
   1815   auto zero = b.AddInstruction(
   1816       HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
   1817   auto one = b.AddInstruction(
   1818       HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
   1819 
   1820   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
   1821   b.AddInstruction(
   1822       HloInstruction::CreateDynamicSlice(shape, operand, {zero, one}, {2, 3}));
   1823   m_->AddEntryComputation(b.Build());
   1824 
   1825   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
   1826 
   1827   auto expected = LiteralUtil::CreateR2<float>({
   1828       {2, 3, 4},
   1829       {6, 7, 8},
   1830   });
   1831 
   1832   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
   1833 }
   1834 
   1835 // Verifies that the HloEvaluator's implementation goes along with existing
   1836 // backends' behavior, although this is not required by the spec.
   1837 TEST_P(HloEvaluatorBf16Test, DynamicSliceModSlice) {
   1838   HloComputation::Builder b(TestName());
   1839 
   1840   // arg:
   1841   // f32[2,4] {
   1842   //  { 1, 2, 3, 4 },
   1843   //  { 5, 6, 7, 8 },
   1844   // }
   1845   auto operand_array = absl::make_unique<Array2D<float>>(2, 4);
   1846   operand_array->FillUnique(1.0f);
   1847   auto operand_literal =
   1848       LiteralUtil::CreateR2FromArray2D<float>(*operand_array);
   1849 
   1850   HloInstruction* operand = b.AddInstruction(
   1851       HloInstruction::CreateConstant(std::move(operand_literal)));
   1852 
   1853   auto two = b.AddInstruction(
   1854       HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(2)));
   1855   auto one = b.AddInstruction(
   1856       HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
   1857 
   1858   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
   1859   b.AddInstruction(
   1860       HloInstruction::CreateDynamicSlice(shape, operand, {two, one}, {2, 3}));
   1861   m_->AddEntryComputation(b.Build());
   1862 
   1863   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
   1864 
   1865   auto expected = LiteralUtil::CreateR2<float>({
   1866       {2, 3, 4},
   1867       {6, 7, 8},
   1868   });
   1869 
   1870   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
   1871 }
   1872 
   1873 TEST_P(HloEvaluatorBf16Test, DynamicSliceUpdate) {
   1874   HloComputation::Builder b(TestName());
   1875 
   1876   // arg:
   1877   // f32[2,3] {
   1878   //  { 1, 2, 3 },
   1879   //  { 5, 6, 7 },
   1880   // }
   1881   auto operand_array = absl::make_unique<Array2D<double>>(2, 3);
   1882   operand_array->FillUnique(1.0);
   1883   auto operand_literal =
   1884       LiteralUtil::CreateR2FromArray2D<double>(*operand_array);
   1885 
   1886   HloInstruction* operand = b.AddInstruction(
   1887       HloInstruction::CreateConstant(std::move(operand_literal)));
   1888 
   1889   auto zero = b.AddInstruction(
   1890       HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
   1891   auto one = b.AddInstruction(
   1892       HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
   1893 
   1894   auto update = b.AddInstruction(HloInstruction::CreateConstant(
   1895       LiteralUtil::CreateR2<double>({{-2.0, -3.0}, {-6.0, -7.0}})));
   1896 
   1897   Shape shape = ShapeUtil::MakeShape(F64, {2, 3});
   1898   b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
   1899       shape, operand, update, {zero, one}));
   1900   m_->AddEntryComputation(b.Build());
   1901 
   1902   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
   1903 
   1904   auto expected = LiteralUtil::CreateR2<double>({
   1905       {1, -2, -3},
   1906       {5, -6, -7},
   1907   });
   1908 
   1909   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
   1910 }
   1911 
   1912 TEST_P(HloEvaluatorBf16Test, SetAndGetTuples) {
   1913   HloComputation::Builder b(TestName());
   1914 
   1915   // arg:
   1916   // f32[2,3] {
   1917   //  { 1, 2, 3 },
   1918   //  { 5, 6, 7 },
   1919   // }
   1920   auto operand_array = absl::make_unique<Array2D<double>>(2, 3);
   1921   operand_array->FillUnique(1.0);
   1922   auto operand_literal2 =
   1923       LiteralUtil::CreateR2FromArray2D<double>(*operand_array);
   1924 
   1925   HloInstruction* operand2 = b.AddInstruction(
   1926       HloInstruction::CreateConstant(std::move(operand_literal2)));
   1927   HloInstruction* operand1 = b.AddInstruction(
   1928       HloInstruction::CreateConstant(LiteralUtil::CreateR1<int64>({0, 1})));
   1929 
   1930   auto tuple =
   1931       b.AddInstruction(HloInstruction::CreateTuple({operand1, operand2}));
   1932 
   1933   Shape shape = ShapeUtil::MakeShape(F64, {2, 3});
   1934   b.AddInstruction(HloInstruction::CreateGetTupleElement(shape, tuple, 1));
   1935 
   1936   m_->AddEntryComputation(b.Build());
   1937 
   1938   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
   1939 
   1940   auto expected = LiteralUtil::CreateR2<double>({
   1941       {1, 2, 3},
   1942       {5, 6, 7},
   1943   });
   1944 
   1945   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
   1946 }
   1947 
   1948 TEST_P(HloEvaluatorBf16Test, SetAndGetNestedTuples) {
   1949   HloComputation::Builder b(TestName());
   1950 
   1951   // arg:
   1952   // f32[2,3] {
   1953   //  { 1, 2, 3 },
   1954   //  { 5, 6, 7 },
   1955   // }
   1956   auto operand_array = absl::make_unique<Array2D<double>>(2, 3);
   1957   operand_array->FillUnique(1.0);
   1958 
   1959   HloInstruction* operand2 = b.AddInstruction(HloInstruction::CreateConstant(
   1960       LiteralUtil::CreateR2FromArray2D<double>(*operand_array)));
   1961   HloInstruction* operand1 = b.AddInstruction(
   1962       HloInstruction::CreateConstant(LiteralUtil::CreateR1<int64>({0, 1})));
   1963 
   1964   auto tuple1 =
   1965       b.AddInstruction(HloInstruction::CreateTuple({operand1, operand2}));
   1966   auto tuple2 =
   1967       b.AddInstruction(HloInstruction::CreateTuple({operand2, operand2}));
   1968 
   1969   auto outer_tuple =
   1970       b.AddInstruction(HloInstruction::CreateTuple({tuple1, tuple2}));
   1971 
   1972   b.AddInstruction(
   1973       HloInstruction::CreateGetTupleElement(tuple2->shape(), outer_tuple, 1));
   1974 
   1975   m_->AddEntryComputation(b.Build());
   1976 
   1977   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
   1978 
   1979   auto result_inner_literal =
   1980       LiteralUtil::CreateR2FromArray2D<double>(*operand_array);
   1981   auto expected =
   1982       LiteralUtil::MakeTuple({&result_inner_literal, &result_inner_literal});
   1983 
   1984   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
   1985 }
   1986 
   1987 TEST_P(HloEvaluatorBf16Test, Reverse) {
   1988   HloComputation::Builder b(TestName());
   1989 
   1990   // Input shape is float[4x3x2x1].
   1991   // clang-format off
   1992   Array4D<float> input({
   1993     {{{1.0f}, {2.0f}},
   1994      {{3.0f}, {4.0f}},
   1995      {{5.0f}, {6.0f}}},
   1996     {{{7.0f}, {8.0f}},
   1997      {{9.0f}, {10.0f}},
   1998      {{11.0f}, {12.0f}}},
   1999     {{{13.0f}, {14.0f}},
   2000      {{15.0f}, {16.0f}},
   2001      {{17.0f}, {18.0f}}},
   2002     {{{19.0f}, {20.0f}},
   2003      {{21.0f}, {22.0f}},
   2004      {{23.0f}, {24.0f}}},
   2005   });
   2006   // clang-format on
   2007   auto operand_literal = LiteralUtil::CreateR4FromArray4D<float>(input);
   2008   HloInstruction* operand = b.AddInstruction(
   2009       HloInstruction::CreateConstant(std::move(operand_literal)));
   2010 
   2011   const Shape shape = ShapeUtil::MakeShape(F32, {4, 3, 2, 1});
   2012   b.AddInstruction(HloInstruction::CreateReverse(shape, operand, {0, 1}));
   2013   m_->AddEntryComputation(b.Build());
   2014 
   2015   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
   2016 
   2017   // clang-format off
   2018   auto expected = LiteralUtil::CreateR4FromArray4D<float>({
   2019     {{{23.0f}, {24.0f}},
   2020      {{21.0f}, {22.0f}},
   2021      {{19.0f}, {20.0f}}},
   2022 
   2023     {{{17.0f}, {18.0f}},
   2024      {{15.0f}, {16.0f}},
   2025      {{13.0f}, {14.0f}}},
   2026 
   2027     {{{11.0f}, {12.0f}},
   2028      {{9.0f}, {10.0f}},
   2029      {{7.0f}, {8.0f}}},
   2030 
   2031     {{{5.0f}, {6.0f}},
   2032      {{3.0f}, {4.0f}},
   2033      {{1.0f}, {2.0f}}},
   2034   });
   2035   // clang-format on
   2036 
   2037   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
   2038 }
   2039 
   2040 TEST_P(HloEvaluatorBf16Test, EvaluateWithSubstitutions) {
   2041   HloComputation::Builder b(TestName());
   2042   Shape shape = ShapeUtil::MakeShape(F32, {4});
   2043 
   2044   HloInstruction* param0 =
   2045       b.AddInstruction(HloInstruction::CreateParameter(0, shape, "param0"));
   2046   HloInstruction* square = b.AddInstruction(HloInstruction::CreateBinary(
   2047       shape, HloOpcode::kMultiply, param0, param0));
   2048   HloInstruction* add = b.AddInstruction(
   2049       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, square));
   2050 
   2051   // Evaluate add with param0 = {1, 2, 3, 4}, square = {10, 20, 30, 40}.
   2052   HloEvaluator evaluator;
   2053   Literal param0_literal = LiteralUtil::CreateR1<float>({1, 2, 3, 4});
   2054   Literal square_literal = LiteralUtil::CreateR1<float>({10, 20, 30, 40});
   2055   TF_ASSERT_OK_AND_ASSIGN(
   2056       Literal result,
   2057       evaluator.EvaluateWithSubstitutions(
   2058           add, {{param0, &param0_literal}, {square, &square_literal}}));
   2059   EXPECT_TRUE(LiteralTestUtil::Equal(
   2060       LiteralUtil::CreateR1<float>({11, 22, 33, 44}), result));
   2061 }
   2062 
   2063 // Check that EvaluateWithSubstitutions works if one of the operands to the op
   2064 // we're evaluating is a constant.
   2065 TEST_P(HloEvaluatorBf16Test, EvaluateWithSubstitutionsWithConstantOperand) {
   2066   HloComputation::Builder b(TestName());
   2067   Shape shape = ShapeUtil::MakeShape(F32, {4});
   2068 
   2069   HloInstruction* param0 =
   2070       b.AddInstruction(HloInstruction::CreateParameter(0, shape, "param0"));
   2071   HloInstruction* square = b.AddInstruction(HloInstruction::CreateBinary(
   2072       shape, HloOpcode::kMultiply, param0, param0));
   2073   HloInstruction* constant = b.AddInstruction(HloInstruction::CreateConstant(
   2074       LiteralUtil::CreateR1<float>({1, 2, 3, 4})));
   2075   HloInstruction* add = b.AddInstruction(
   2076       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, constant, square));
   2077 
   2078   // Evaluate add with square = {10, 20, 30, 40}.
   2079   HloEvaluator evaluator;
   2080   Literal square_literal = LiteralUtil::CreateR1<float>({10, 20, 30, 40});
   2081   TF_ASSERT_OK_AND_ASSIGN(
   2082       Literal result,
   2083       evaluator.EvaluateWithSubstitutions(add, {{square, &square_literal}}));
   2084   EXPECT_TRUE(LiteralTestUtil::Equal(
   2085       LiteralUtil::CreateR1<float>({11, 22, 33, 44}), result));
   2086 }
   2087 
   2088 TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV1) {
   2089   const char* hlo_text = R"(
   2090 HloModule TensorFlowGatherV1
   2091 
   2092 ENTRY main {
   2093   operand = s32[3,3] parameter(0)
   2094   indices = s32[2] parameter(1)
   2095   ROOT gather = s32[2,3] gather(operand, indices),
   2096       offset_dims={1},
   2097       collapsed_slice_dims={0},
   2098       start_index_map={0},
   2099       index_vector_dim=1,
   2100       slice_sizes={1, 3}
   2101 }
   2102 )";
   2103   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
   2104   Literal operand =
   2105       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
   2106   Literal start_indices = LiteralUtil::CreateR1<int32>({0, 2});
   2107   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&operand, &start_indices}));
   2108   EXPECT_TRUE(LiteralTestUtil::Equal(
   2109       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {7, 8, 9}}), result));
   2110 }
   2111 
   2112 TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV2) {
   2113   const char* hlo_text = R"(
   2114 HloModule TensorFlowGatherV2
   2115 
   2116 ENTRY main {
   2117   operand = s32[3,3] parameter(0)
   2118   indices = s32[2] parameter(1)
   2119   ROOT gather = s32[3,2] gather(operand, indices),
   2120       offset_dims={0},
   2121       collapsed_slice_dims={1},
   2122       start_index_map={1},
   2123       index_vector_dim=1,
   2124       slice_sizes={3, 1}
   2125 }
   2126 )";
   2127   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
   2128   Literal operand =
   2129       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
   2130   Literal start_indices = LiteralUtil::CreateR1<int32>({0, 2});
   2131   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&operand, &start_indices}));
   2132   EXPECT_TRUE(LiteralTestUtil::Equal(
   2133       LiteralUtil::CreateR2<int32>({{1, 3}, {4, 6}, {7, 9}}), result));
   2134 }
   2135 
   2136 TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherMultipleBatchDims) {
   2137   const char* hlo_text = R"(
   2138 HloModule TensorFlowGatherMultipleBatchDims
   2139 
   2140 ENTRY main {
   2141   operand = s32[3,3] parameter(0)
   2142   indices = s32[2,2] parameter(1)
   2143   ROOT gather = s32[2,3,2] gather(operand, indices),
   2144       offset_dims={1},
   2145       collapsed_slice_dims={1},
   2146       start_index_map={1},
   2147       index_vector_dim=2,
   2148       slice_sizes={3, 1}
   2149 }
   2150 )";
   2151   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
   2152   Literal operand =
   2153       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
   2154   Literal start_indices = LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
   2155   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&operand, &start_indices}));
   2156   EXPECT_TRUE(LiteralTestUtil::Equal(
   2157       LiteralUtil::CreateR3<int32>(
   2158           {{{1, 3}, {4, 6}, {7, 9}}, {{3, 2}, {6, 5}, {9, 8}}}),
   2159       result));
   2160 }
   2161 
   2162 TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherNd) {
   2163   const char* hlo_text = R"(
   2164 HloModule TensorFlowGatherNd
   2165 
   2166 ENTRY main {
   2167   operand = s32[3,3,2] parameter(0)
   2168   indices = s32[2,2] parameter(1)
   2169   ROOT gather = s32[2,2] gather(operand, indices),
   2170       offset_dims={1},
   2171       collapsed_slice_dims={0,1},
   2172       start_index_map={0,1},
   2173       index_vector_dim=1,
   2174       slice_sizes={1,1,2}
   2175 }
   2176 )";
   2177   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
   2178   Literal operand =
   2179       LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
   2180                                     {{-4, 4}, {-5, 5}, {-6, 6}},  //
   2181                                     {{-7, 7}, {-8, 8}, {-9, 9}}});
   2182   Literal start_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
   2183   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&operand, &start_indices}));
   2184   EXPECT_TRUE(LiteralTestUtil::Equal(
   2185       LiteralUtil::CreateR2<int32>({{-1, 1}, {-4, 4}}), result));
   2186 }
   2187 
   2188 TEST_F(HloEvaluatorTest,
   2189        EvaluateGather_TensorFlowGatherNdNonDefaultIndexVectorDim) {
   2190   const char* hlo_text = R"(
   2191 HloModule TensorFlowGatherNd
   2192 
   2193 ENTRY main {
   2194   operand = s32[3,3,2] parameter(0)
   2195   indices = s32[2,2] parameter(1)
   2196   ROOT gather = s32[2,2] gather(operand, indices),
   2197       offset_dims={1},
   2198       collapsed_slice_dims={0,1},
   2199       start_index_map={0,1},
   2200       index_vector_dim=0,
   2201       slice_sizes={1,1,2}
   2202 }
   2203 )";
   2204   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
   2205   Literal operand =
   2206       LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
   2207                                     {{-4, 4}, {-5, 5}, {-6, 6}},  //
   2208                                     {{-7, 7}, {-8, 8}, {-9, 9}}});
   2209   Literal start_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
   2210   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&operand, &start_indices}));
   2211   EXPECT_TRUE(LiteralTestUtil::Equal(
   2212       LiteralUtil::CreateR2<int32>({{-2, 2}, {-1, 1}}), result));
   2213 }
   2214 
   2215 TEST_F(HloEvaluatorTest, EvaluateGather_DynamicSlice) {
   2216   const char* hlo_text = R"(
   2217 HloModule DynamicSlice
   2218 
   2219 ENTRY main {
   2220   operand = s32[3,3] parameter(0)
   2221   indices = s32[2] parameter(1)
   2222   ROOT gather = s32[1,1] gather(operand, indices),
   2223       offset_dims={0,1},
   2224       collapsed_slice_dims={},
   2225       start_index_map={0,1},
   2226       index_vector_dim=0,
   2227       slice_sizes={1,1}
   2228 }
   2229 )";
   2230   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
   2231   Literal operand =
   2232       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
   2233   Literal start_indices = LiteralUtil::CreateR1<int32>({1, 1});
   2234   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&operand, &start_indices}));
   2235   EXPECT_TRUE(
   2236       LiteralTestUtil::Equal(LiteralUtil::CreateR2<int32>({{5}}), result));
   2237 }
   2238 
   2239 TEST_F(HloEvaluatorTest, EvaluateGather_BatchDynamicSlice) {
   2240   const char* hlo_text = R"(
   2241 HloModule BatchDynamicSlice
   2242 
   2243 ENTRY main {
   2244   operand = s32[3,3] parameter(0)
   2245   indices = s32[2,2] parameter(1)
   2246   ROOT gather = s32[2,1,1] gather(operand, indices),
   2247       offset_dims={1,2},
   2248       collapsed_slice_dims={},
   2249       start_index_map={0,1},
   2250       index_vector_dim=0,
   2251       slice_sizes={1,1}
   2252 }
   2253 )";
   2254   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
   2255   Literal operand =
   2256       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
   2257   Literal start_indices = LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
   2258   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&operand, &start_indices}));
   2259   EXPECT_TRUE(LiteralTestUtil::Equal(
   2260       LiteralUtil::CreateR3<int32>({{{8}}, {{5}}}), result));
   2261 }
   2262 
   2263 TEST_F(HloEvaluatorTest, EvaluateGather_ZeroDimBounds) {
   2264   const char* hlo_text = R"(
   2265 HloModule TensorFlowGatherV1
   2266 
   2267 ENTRY main {
   2268   operand = s32[3,0] parameter(0)
   2269   indices = s32[2] parameter(1)
   2270   ROOT gather = s32[2,0] gather(operand, indices),
   2271       offset_dims={1},
   2272       collapsed_slice_dims={0},
   2273       start_index_map={0},
   2274       index_vector_dim=1,
   2275       slice_sizes={1, 0}
   2276 }
   2277 )";
   2278   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
   2279   Literal operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
   2280   Literal start_indices = LiteralUtil::CreateR1<int32>({0, 2});
   2281   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&operand, &start_indices}));
   2282   EXPECT_TRUE(
   2283       LiteralTestUtil::Equal(LiteralUtil::CreateR2<int32>({{}, {}}), result));
   2284 }
   2285 
   2286 TEST_F(HloEvaluatorTest, EvaluateGather_NoOutputWindowDims) {
   2287   const string hlo_text = R"(
   2288 HloModule GatherXd
   2289 
   2290 ENTRY main {
   2291   operand = s32[3] parameter(0)
   2292   indices = s32[2,2,1] parameter(1)
   2293   ROOT gather = s32[2,2] gather(operand, indices),
   2294       offset_dims={},
   2295       collapsed_slice_dims={0},
   2296       start_index_map={0},
   2297       index_vector_dim=2,
   2298       slice_sizes={1}
   2299 }
   2300 )";
   2301   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
   2302 
   2303   Literal operand = LiteralUtil::CreateR1<int32>({0, 1, 2});
   2304   Literal start_indices =
   2305       LiteralUtil::CreateR3<int32>({{{0}, {1}}, {{2}, {1}}});
   2306   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&operand, &start_indices}));
   2307   EXPECT_TRUE(LiteralTestUtil::Equal(
   2308       LiteralUtil::CreateR2<int32>({{0, 1}, {2, 1}}), result));
   2309 }
   2310 
   2311 TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV1_Update) {
   2312   const char* hlo_text = R"(
   2313 HloModule TensorFlowScatterV1
   2314 
   2315 update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
   2316   lhs = s32[] parameter(0)
   2317   ROOT rhs = s32[] parameter(1)
   2318 }
   2319 
   2320 ENTRY main {
   2321   operand = s32[3,3] parameter(0)
   2322   indices = s32[2] parameter(1)
   2323   updates = s32[2,3] parameter(2)
   2324   ROOT scatter = s32[3,3] scatter(operand, indices, updates),
   2325       to_apply=update_s32,
   2326       update_window_dims={1},
   2327       inserted_window_dims={0},
   2328       scatter_dims_to_operand_dims={0},
   2329       index_vector_dim=1
   2330 }
   2331 )";
   2332   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
   2333   Literal operand =
   2334       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
   2335   Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
   2336   Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
   2337   TF_ASSERT_OK_AND_ASSIGN(Literal result,
   2338                           Evaluate({&operand, &scatter_indices, &updates}));
   2339   EXPECT_TRUE(LiteralTestUtil::Equal(
   2340       LiteralUtil::CreateR2<int32>({{10, 20, 30}, {4, 5, 6}, {70, 80, 90}}),
   2341       result));
   2342 }
   2343 
   2344 TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV2_Update) {
   2345   const char* hlo_text = R"(
   2346 HloModule TensorFlowScatterV2
   2347 
   2348 update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
   2349   lhs = s32[] parameter(0)
   2350   ROOT rhs = s32[] parameter(1)
   2351 }
   2352 
   2353 ENTRY main {
   2354   operand = s32[3,3] parameter(0)
   2355   indices = s32[2] parameter(1)
   2356   updates = s32[3,2] parameter(2)
   2357   ROOT scatter = s32[3,3] scatter(operand, indices, updates),
   2358       to_apply=update_s32,
   2359       update_window_dims={0},
   2360       inserted_window_dims={1},
   2361       scatter_dims_to_operand_dims={1},
   2362       index_vector_dim=1
   2363 }
   2364 )";
   2365   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
   2366   Literal operand =
   2367       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
   2368   Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
   2369   Literal updates =
   2370       LiteralUtil::CreateR2<int32>({{10, 30}, {40, 60}, {70, 90}});
   2371   TF_ASSERT_OK_AND_ASSIGN(Literal result,
   2372                           Evaluate({&operand, &scatter_indices, &updates}));
   2373   EXPECT_TRUE(LiteralTestUtil::Equal(
   2374       LiteralUtil::CreateR2<int32>({{10, 2, 30}, {40, 5, 60}, {70, 8, 90}}),
   2375       result));
   2376 }
   2377 
   2378 TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Add) {
   2379   const char* hlo_text = R"(
   2380 HloModule TensorFlowScatter
   2381 
   2382 add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
   2383   lhs = s32[] parameter(0)
   2384   rhs = s32[] parameter(1)
   2385   ROOT add = s32[] add(s32[] lhs, s32[] rhs)
   2386 }
   2387 
   2388 ENTRY main {
   2389   operand = s32[3,3] parameter(0)
   2390   indices = s32[2] parameter(1)
   2391   updates = s32[2,3] parameter(2)
   2392   ROOT scatter = s32[3,3] scatter(operand, indices, updates),
   2393       to_apply=add_s32,
   2394       update_window_dims={1},
   2395       inserted_window_dims={0},
   2396       scatter_dims_to_operand_dims={0},
   2397       index_vector_dim=1
   2398 }
   2399 )";
   2400   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
   2401   Literal operand =
   2402       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
   2403   Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
   2404   Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
   2405   TF_ASSERT_OK_AND_ASSIGN(Literal result,
   2406                           Evaluate({&operand, &scatter_indices, &updates}));
   2407   EXPECT_TRUE(LiteralTestUtil::Equal(
   2408       LiteralUtil::CreateR2<int32>({{11, 22, 33}, {4, 5, 6}, {77, 88, 99}}),
   2409       result));
   2410 }
   2411 
   2412 TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Mul) {
   2413   const char* hlo_text = R"(
   2414 HloModule TensorFlowScatter
   2415 
   2416 mul_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
   2417   lhs = s32[] parameter(0)
   2418   rhs = s32[] parameter(1)
   2419   ROOT mul = s32[] multiply(s32[] lhs, s32[] rhs)
   2420 }
   2421 
   2422 ENTRY main {
   2423   operand = s32[3,3] parameter(0)
   2424   indices = s32[2] parameter(1)
   2425   updates = s32[2,3] parameter(2)
   2426   ROOT scatter = s32[3,3] scatter(operand, indices, updates),
   2427       to_apply=mul_s32,
   2428       update_window_dims={1},
   2429       inserted_window_dims={0},
   2430       scatter_dims_to_operand_dims={0},
   2431       index_vector_dim=1
   2432 }
   2433 )";
   2434   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
   2435   Literal operand =
   2436       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
   2437   Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
   2438   Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
   2439   TF_ASSERT_OK_AND_ASSIGN(Literal result,
   2440                           Evaluate({&operand, &scatter_indices, &updates}));
   2441   EXPECT_TRUE(LiteralTestUtil::Equal(
   2442       LiteralUtil::CreateR2<int32>({{10, 40, 90}, {4, 5, 6}, {490, 640, 810}}),
   2443       result));
   2444 }
   2445 
   2446 TEST_P(HloEvaluatorBf16Test, EvaluateScatter_TensorFlowScatter_F32) {
   2447   const char* hlo_text = R"(
   2448 HloModule TensorFlowScatter
   2449 
   2450 add_f32 (lhs: f32[], rhs: f32[]) -> f32[] {
   2451   lhs = f32[] parameter(0)
   2452   rhs = f32[] parameter(1)
   2453   ROOT add = f32[] add(f32[] lhs, f32[] rhs)
   2454 }
   2455 
   2456 ENTRY main {
   2457   operand = f32[3,3] parameter(0)
   2458   indices = s32[2] parameter(1)
   2459   updates = f32[2,3] parameter(2)
   2460   ROOT scatter = f32[3,3] scatter(operand, indices, updates),
   2461       to_apply=add_f32,
   2462       update_window_dims={1},
   2463       inserted_window_dims={0},
   2464       scatter_dims_to_operand_dims={0},
   2465       index_vector_dim=1
   2466 }
   2467 )";
   2468   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
   2469   Literal operand = LiteralUtil::CreateR2<float>(
   2470       {{1.1, 2.2, 3.3}, {4.4, 5.5, 6.6}, {7.7, 8.8, 9.9}});
   2471   Literal scatter_indices = LiteralUtil::CreateR1<int32>({2, 1});
   2472   Literal updates =
   2473       LiteralUtil::CreateR2<float>({{0.4, 1.1, 0.7}, {2.3, 3.1, 1.6}});
   2474   TF_ASSERT_OK_AND_ASSIGN(Literal result,
   2475                           Evaluate({&operand, &scatter_indices, &updates}));
   2476   EXPECT_TRUE(LiteralTestUtil::Near(
   2477       LiteralUtil::CreateR2<float>(
   2478           {{1.1, 2.2, 3.3}, {6.7, 8.6, 8.2}, {8.1, 9.9, 10.6}}),
   2479       result, ErrorSpec{0.1, 0.01}));
   2480 }
   2481 
   2482 TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_RepeatedIndices) {
   2483   const char* hlo_text = R"(
   2484 HloModule TensorFlowScatter
   2485 
   2486 add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
   2487   lhs = s32[] parameter(0)
   2488   rhs = s32[] parameter(1)
   2489   ROOT add = s32[] add(s32[] lhs, s32[] rhs)
   2490 }
   2491 
   2492 ENTRY main {
   2493   operand = s32[3,3] parameter(0)
   2494   indices = s32[2] parameter(1)
   2495   updates = s32[2,3] parameter(2)
   2496   ROOT scatter = s32[3,3] scatter(operand, indices, updates),
   2497       to_apply=add_s32,
   2498       update_window_dims={1},
   2499       inserted_window_dims={0},
   2500       scatter_dims_to_operand_dims={0},
   2501       index_vector_dim=1
   2502 }
   2503 )";
   2504   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
   2505   Literal operand =
   2506       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
   2507   Literal scatter_indices = LiteralUtil::CreateR1<int32>({1, 1});
   2508   Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
   2509   TF_ASSERT_OK_AND_ASSIGN(Literal result,
   2510                           Evaluate({&operand, &scatter_indices, &updates}));
   2511   EXPECT_TRUE(LiteralTestUtil::Equal(
   2512       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {84, 105, 126}, {7, 8, 9}}),
   2513       result));
   2514 }
   2515 
   2516 TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_MultipleBatchDims) {
   2517   const char* hlo_text = R"(
   2518 HloModule TensorFlowScatterMultipleBatchDims
   2519 
   2520 add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
   2521   lhs = s32[] parameter(0)
   2522   rhs = s32[] parameter(1)
   2523   ROOT add = s32[] add(s32[] lhs, s32[] rhs)
   2524 }
   2525 
   2526 ENTRY main {
   2527   operand = s32[3,3] parameter(0)
   2528   indices = s32[2,2] parameter(1)
   2529   updates = s32[2,3,2] parameter(2)
   2530   ROOT scatter = s32[3,3] scatter(operand, indices, updates),
   2531       to_apply=add_s32,
   2532       update_window_dims={1},
   2533       inserted_window_dims={1},
   2534       scatter_dims_to_operand_dims={1},
   2535       index_vector_dim=2
   2536 }
   2537 )";
   2538   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
   2539   Literal operand =
   2540       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
   2541   Literal scatter_indices = LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
   2542   Literal updates = LiteralUtil::CreateR3<int32>(
   2543       {{{10, 30}, {40, 60}, {70, 90}}, {{5, 5}, {5, 5}, {5, 5}}});
   2544   TF_ASSERT_OK_AND_ASSIGN(Literal result,
   2545                           Evaluate({&operand, &scatter_indices, &updates}));
   2546   EXPECT_TRUE(LiteralTestUtil::Equal(
   2547       LiteralUtil::CreateR2<int32>({{11, 7, 38}, {44, 10, 71}, {77, 13, 104}}),
   2548       result));
   2549 }
   2550 
   2551 TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterNd) {
   2552   const char* hlo_text = R"(
   2553 HloModule TensorFlowScatterNd
   2554 
   2555 update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
   2556   lhs = s32[] parameter(0)
   2557   ROOT rhs = s32[] parameter(1)
   2558 }
   2559 
   2560 ENTRY main {
   2561   operand = s32[3,3,2] parameter(0)
   2562   indices = s32[2,2] parameter(1)
   2563   updates = s32[2,2] parameter(2)
   2564   ROOT scatter = s32[3,3,2] scatter(operand, indices, updates),
   2565       to_apply=update_s32,
   2566       update_window_dims={1},
   2567       inserted_window_dims={0,1},
   2568       scatter_dims_to_operand_dims={0,1},
   2569       index_vector_dim=1
   2570 }
   2571 )";
   2572   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
   2573   Literal operand =
   2574       LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
   2575                                     {{-4, 4}, {-5, 5}, {-6, 6}},  //
   2576                                     {{-7, 7}, {-8, 8}, {-9, 9}}});
   2577   Literal scatter_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
   2578   Literal updates = LiteralUtil::CreateR2<int32>({{-10, 10}, {-40, 40}});
   2579   Literal expected =
   2580       LiteralUtil::CreateR3<int32>({{{-10, 10}, {-2, 2}, {-3, 3}},  //
   2581                                     {{-40, 40}, {-5, 5}, {-6, 6}},  //
   2582                                     {{-7, 7}, {-8, 8}, {-9, 9}}});
   2583   TF_ASSERT_OK_AND_ASSIGN(Literal result,
   2584                           Evaluate({&operand, &scatter_indices, &updates}));
   2585   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
   2586 }
   2587 
   2588 TEST_F(HloEvaluatorTest,
   2589        EvaluateScatter_TensorFlowScatterNd_NonDefaultIndexVectorDim) {
   2590   const char* hlo_text = R"(
   2591 HloModule TensorFlowScatterNdNonDefaultIndexVectorDim
   2592 
   2593 update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
   2594   lhs = s32[] parameter(0)
   2595   ROOT rhs = s32[] parameter(1)
   2596 }
   2597 
   2598 ENTRY main {
   2599   operand = s32[3,3,2] parameter(0)
   2600   indices = s32[2,2] parameter(1)
   2601   updates = s32[2,2] parameter(2)
   2602   ROOT scatter = s32[3,3,2] scatter(operand, indices, updates),
   2603       to_apply=update_s32,
   2604       update_window_dims={1},
   2605       inserted_window_dims={0,1},
   2606       scatter_dims_to_operand_dims={0,1},
   2607       index_vector_dim=0
   2608 }
   2609 )";
   2610   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
   2611   Literal operand =
   2612       LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
   2613                                     {{-4, 4}, {-5, 5}, {-6, 6}},  //
   2614                                     {{-7, 7}, {-8, 8}, {-9, 9}}});
   2615   Literal scatter_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
   2616   Literal updates = LiteralUtil::CreateR2<int32>({{-10, 10}, {-20, 20}});
   2617   Literal expected =
   2618       LiteralUtil::CreateR3<int32>({{{-20, 20}, {-10, 10}, {-3, 3}},  //
   2619                                     {{-4, 4}, {-5, 5}, {-6, 6}},      //
   2620                                     {{-7, 7}, {-8, 8}, {-9, 9}}});
   2621   TF_ASSERT_OK_AND_ASSIGN(Literal result,
   2622                           Evaluate({&operand, &scatter_indices, &updates}));
   2623   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
   2624 }
   2625 
   2626 TEST_F(HloEvaluatorTest, EvaluateScatter_DynamicUpdateSlice) {
   2627   const char* hlo_text = R"(
   2628 HloModule DynamicUpdateSlice
   2629 
   2630 update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
   2631   lhs = s32[] parameter(0)
   2632   ROOT rhs = s32[] parameter(1)
   2633 }
   2634 
   2635 ENTRY main {
   2636   operand = s32[3,3] parameter(0)
   2637   indices = s32[2] parameter(1)
   2638   updates = s32[1,1] parameter(2)
   2639   ROOT scatter = s32[3,3] scatter(operand, indices, updates),
   2640       to_apply=update_s32,
   2641       update_window_dims={0,1},
   2642       inserted_window_dims={},
   2643       scatter_dims_to_operand_dims={0,1},
   2644       index_vector_dim=0
   2645 }
   2646 )";
   2647   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
   2648   Literal operand =
   2649       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
   2650   Literal scatter_indices = LiteralUtil::CreateR1<int32>({1, 1});
   2651   Literal updates = LiteralUtil::CreateR2<int32>({{10}});
   2652   Literal expected =
   2653       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 10, 6}, {7, 8, 9}});
   2654   TF_ASSERT_OK_AND_ASSIGN(Literal result,
   2655                           Evaluate({&operand, &scatter_indices, &updates}));
   2656   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
   2657 }
   2658 
   2659 TEST_F(HloEvaluatorTest, EvaluateScatter_BatchDynamicUpdateSlice) {
   2660   const char* hlo_text = R"(
   2661 HloModule BatchDynamicUpdateSlice
   2662 
   2663 update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
   2664   lhs = s32[] parameter(0)
   2665   ROOT rhs = s32[] parameter(1)
   2666 }
   2667 
   2668 ENTRY main {
   2669   operand = s32[3,3] parameter(0)
   2670   indices = s32[2,2] parameter(1)
   2671   updates = s32[2,1,1] parameter(2)
   2672   ROOT scatter = s32[3,3] scatter(operand, indices, updates),
   2673       to_apply=update_s32,
   2674       update_window_dims={1,2},
   2675       inserted_window_dims={},
   2676       scatter_dims_to_operand_dims={0,1},
   2677       index_vector_dim=0
   2678 }
   2679 )";
   2680   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
   2681   Literal operand =
   2682       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
   2683   Literal scatter_indices = LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
   2684   Literal updates = LiteralUtil::CreateR3<int32>({{{10}}, {{20}}});
   2685   Literal expected =
   2686       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 20, 6}, {7, 10, 9}});
   2687   TF_ASSERT_OK_AND_ASSIGN(Literal result,
   2688                           Evaluate({&operand, &scatter_indices, &updates}));
   2689   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
   2690 }
   2691 
   2692 TEST_F(HloEvaluatorTest, EvaluateScatter_ZeroDimBounds) {
   2693   const char* hlo_text = R"(
   2694 HloModule TensorFlowScatter_ZeroDimBounds
   2695 
   2696 update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
   2697   lhs = s32[] parameter(0)
   2698   ROOT rhs = s32[] parameter(1)
   2699 }
   2700 
   2701 ENTRY main {
   2702   operand = s32[3,0] parameter(0)
   2703   indices = s32[2] parameter(1)
   2704   updates = s32[2,0] parameter(2)
   2705   ROOT scatter = s32[3,0] scatter(operand, indices, updates),
   2706       to_apply=update_s32,
   2707       update_window_dims={1},
   2708       inserted_window_dims={0},
   2709       scatter_dims_to_operand_dims={0},
   2710       index_vector_dim=1
   2711 }
   2712 )";
   2713   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
   2714   Literal operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
   2715   Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
   2716   Literal updates = LiteralUtil::CreateR2<int32>({{}, {}});
   2717   TF_ASSERT_OK_AND_ASSIGN(Literal result,
   2718                           Evaluate({&operand, &scatter_indices, &updates}));
   2719   EXPECT_TRUE(LiteralTestUtil::Equal(operand, result));
   2720 }
   2721 
   2722 TEST_F(HloEvaluatorTest, EvaluateScatter_NoUpdateWindowDims) {
   2723   const string hlo_text = R"(
   2724 HloModule Scatter_NoUpdateWindowDims
   2725 
   2726 add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
   2727   lhs = s32[] parameter(0)
   2728   rhs = s32[] parameter(1)
   2729   ROOT add = s32[] add(s32[] lhs, s32[] rhs)
   2730 }
   2731 
   2732 ENTRY main {
   2733   operand = s32[3] parameter(0)
   2734   indices = s32[2,2,1] parameter(1)
   2735   updates = s32[2,2] parameter(2)
   2736   ROOT scatter = s32[3] scatter(operand, indices, updates),
   2737       to_apply=add_s32,
   2738       update_window_dims={},
   2739       inserted_window_dims={0},
   2740       scatter_dims_to_operand_dims={0},
   2741       index_vector_dim=2
   2742 }
   2743 )";
   2744   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
   2745 
   2746   Literal operand = LiteralUtil::CreateR1<int32>({0, 1, 2});
   2747   Literal scatter_indices =
   2748       LiteralUtil::CreateR3<int32>({{{0}, {1}}, {{2}, {1}}});
   2749   Literal updates = LiteralUtil::CreateR2<int32>({{10, 20}, {30, 40}});
   2750   Literal expected = LiteralUtil::CreateR1<int32>({10, 61, 32});
   2751   TF_ASSERT_OK_AND_ASSIGN(Literal result,
   2752                           Evaluate({&operand, &scatter_indices, &updates}));
   2753   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
   2754 }
   2755 
   2756 TEST_F(HloEvaluatorTest, EvaluateScatter_NegativeIndices) {
   2757   const char* hlo_text = R"(
   2758 HloModule TensorFlowScatter_NegativeIndices
   2759 
   2760 add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
   2761   lhs = s32[] parameter(0)
   2762   rhs = s32[] parameter(1)
   2763   ROOT add = s32[] add(s32[] lhs, s32[] rhs)
   2764 }
   2765 
   2766 ENTRY main {
   2767   operand = s32[3,3] parameter(0)
   2768   indices = s32[2] parameter(1)
   2769   updates = s32[2,3] parameter(2)
   2770   ROOT scatter = s32[3,3] scatter(operand, indices, updates),
   2771       to_apply=add_s32,
   2772       update_window_dims={1},
   2773       inserted_window_dims={0},
   2774       scatter_dims_to_operand_dims={0},
   2775       index_vector_dim=1
   2776 }
   2777 )";
   2778   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
   2779                           ParseAndReturnVerifiedModule(hlo_text));
   2780   Literal operand =
   2781       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
   2782   // No updates should happen for the negative indices.
   2783   Literal scatter_indices = LiteralUtil::CreateR1<int32>({-1, 2});
   2784   Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
   2785   EXPECT_TRUE(LiteralTestUtil::Equal(
   2786       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {77, 88, 99}}),
   2787       EvaluateWithModule(module.get(),
   2788                          {&operand, &scatter_indices, &updates})));
   2789 }
   2790 
   2791 TEST_F(HloEvaluatorTest, EvaluateScatter_OobIndices) {
   2792   const string hlo_text = R"(
   2793 HloModule BatchDynamicUpdateSlice
   2794 
   2795 update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
   2796   lhs = s32[] parameter(0)
   2797   ROOT rhs = s32[] parameter(1)
   2798 }
   2799 
   2800 ENTRY main {
   2801   operand = s32[3,3]{1,0} parameter(0)
   2802   indices = s32[6,2]{1,0} parameter(1)
   2803   updates = s32[6,1,1]{2,1,0} parameter(2)
   2804   ROOT scatter = s32[3,3]{1,0} scatter(operand, indices, updates),
   2805       to_apply=update_s32,
   2806       update_window_dims={1,2},
   2807       inserted_window_dims={},
   2808       scatter_dims_to_operand_dims={0,1},
   2809       index_vector_dim=1
   2810 }
   2811 )";
   2812   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
   2813                           ParseAndReturnVerifiedModule(hlo_text));
   2814   Literal operand =
   2815       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
   2816   // No updates should happen for the OOB indices.
   2817   Literal scatter_indices = LiteralUtil::CreateR2<int32>(
   2818       {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}});
   2819   Literal updates = LiteralUtil::CreateR3<int32>(
   2820       {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}});
   2821   EXPECT_TRUE(LiteralTestUtil::Equal(
   2822       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 30, 60}, {7, 20, 9}}),
   2823       EvaluateWithModule(module.get(),
   2824                          {&operand, &scatter_indices, &updates})));
   2825 }
   2826 
   2827 TEST_F(HloEvaluatorTest, EvaluateScatter_OobUpdateWindow) {
   2828   const char* hlo_text = R"(
   2829 HloModule TensorFlowScatterNd_OobUpdateWindow
   2830 
   2831 update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
   2832   lhs = s32[] parameter(0)
   2833   ROOT rhs = s32[] parameter(1)
   2834 }
   2835 
   2836 ENTRY main {
   2837   operand = s32[3,3,2] parameter(0)
   2838   indices = s32[1,2] parameter(1)
   2839   updates = s32[1,2,2] parameter(2)
   2840   ROOT scatter = s32[3,3,2] scatter(operand, indices, updates),
   2841       to_apply=update_s32,
   2842       update_window_dims={1,2},
   2843       inserted_window_dims={0},
   2844       scatter_dims_to_operand_dims={0,1},
   2845       index_vector_dim=1
   2846 }
   2847 )";
   2848   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
   2849                           ParseAndReturnVerifiedModule(hlo_text));
   2850   Literal operand =
   2851       LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
   2852                                     {{-4, 4}, {-5, 5}, {-6, 6}},  //
   2853                                     {{-7, 7}, {-8, 8}, {-9, 9}}});
   2854   Literal scatter_indices = LiteralUtil::CreateR2<int32>({{0, 2}});
   2855   Literal updates = LiteralUtil::CreateR3<int32>({{{-10, 10}, {-40, 40}}});
   2856   // Given the update window size of 2,2 and the index of 0,2, the update window
   2857   // will be OOB. So, nothing should be updated.
   2858   Literal expected = operand.Clone();
   2859   EXPECT_TRUE(LiteralTestUtil::Equal(
   2860       expected, EvaluateWithModule(module.get(),
   2861                                    {&operand, &scatter_indices, &updates})));
   2862 }
   2863 
   2864 // Verifies that HloEvaluator evaluates a HLO instruction that performs
   2865 // element-wise comparison with 2 bfloat16 operands.
   2866 TEST_F(HloEvaluatorTest, DoesCompareBF16) {
   2867   // lhs >= rhs
   2868   auto lhs = LiteralUtil::CreateR2<bfloat16>(
   2869       {{bfloat16(0.25), bfloat16(0.35), bfloat16(0.125)},
   2870        {bfloat16(-0.25), bfloat16(-0.35), bfloat16(-0.125)}});
   2871   auto rhs = LiteralUtil::CreateR2<bfloat16>(
   2872       {{bfloat16(0.5), bfloat16(0.125), bfloat16(0.125)},
   2873        {bfloat16(0.25), bfloat16(-0.375), bfloat16(-0.127)}});
   2874   auto expected =
   2875       LiteralUtil::CreateR2<bool>({{false, true, true}, {false, true, true}});
   2876 
   2877   HloComputation::Builder b(TestName());
   2878   auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs)));
   2879   auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs)));
   2880   b.AddInstruction(HloInstruction::CreateCompare(expected.shape(), c1, c2,
   2881                                                  ComparisonDirection::kGe));
   2882   m_->AddEntryComputation(b.Build());
   2883 
   2884   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
   2885   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
   2886 }
   2887 
   2888 TEST_P(HloEvaluatorBf16Test, Bf16Reduction) {
   2889   const string hlo_text = R"(
   2890 HloModule Bf16Reduction
   2891 
   2892 add_bf16 (lhs: bf16[], rhs: bf16[]) -> bf16[] {
   2893   lhs = bf16[] parameter(0)
   2894   rhs = bf16[] parameter(1)
   2895   ROOT add = bf16[] add(bf16[] lhs, bf16[] rhs)
   2896 }
   2897 
   2898 ENTRY main {
   2899   arg0 = bf16[4]{0} parameter(0)
   2900   init = bf16[] constant(0)
   2901   ROOT %reduce = bf16[] reduce(arg0, init), dimensions={0}, to_apply=add_bf16
   2902 }
   2903 )";
   2904   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
   2905 
   2906   Literal arg = LiteralUtil::CreateR1<bfloat16>(
   2907       {bfloat16(1.0f), bfloat16(3.0f), bfloat16(-2.0f), bfloat16(42.0f)});
   2908   Literal expected = LiteralUtil::CreateR0<bfloat16>(bfloat16(44.0f));
   2909   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&arg}));
   2910   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
   2911 }
   2912 
   2913 TEST_F(HloEvaluatorTest, DontFailOnCallUnimplementedOps) {
   2914   // Infeed triggers unimplemented error within HandleCall, and we verify that
   2915   // the Evaluator does fail in such case.
   2916   const string hlo_text = R"(
   2917 HloModule DontFailOnCall
   2918 
   2919 call {
   2920   token0 = token[] after-all()
   2921   ROOT infeed = ((u32[3]{0}, pred[]), token[]) infeed(token0)
   2922 }
   2923 
   2924 ENTRY main {
   2925   ROOT result = ((u32[3]{0}, pred[]), token[]) call(), to_apply=call
   2926 }
   2927 )";
   2928   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
   2929   auto statusor = Evaluate();
   2930   EXPECT_FALSE(statusor.status().ok());
   2931 }
   2932 
   2933 TEST_F(HloEvaluatorTest, DontFailOnFusionWithUnimplementedOps) {
   2934   // Infeed triggers unimplemented error within HandleFusion, and we verify that
   2935   // the Evaluator does fail in such case.
   2936   const string hlo_text = R"(
   2937 HloModule DontFailOnFusion
   2938 
   2939 fused_computation {
   2940   token0 = token[] after-all()
   2941   ROOT infeed = ((u32[3]{0}, pred[]), token[]) infeed(token0)
   2942 }
   2943 
   2944 ENTRY main {
   2945   ROOT result = ((u32[3]{0}, pred[]), token[]) fusion(), kind=kLoop, calls=fused_computation
   2946 }
   2947 )";
   2948   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
   2949   auto statusor = Evaluate();
   2950   EXPECT_FALSE(statusor.status().ok());
   2951 }
   2952 
   2953 TEST_P(HloEvaluatorBf16Test, SliceWithDifferentLayout) {
   2954   // Regression test for b/114735354.
   2955   const string hlo_text = R"(
   2956 HloModule SliceWithDifferentLayout
   2957 
   2958 ENTRY main {
   2959   arg = f32[2,2,2]{0,1,2} parameter(0)
   2960   ROOT %slice = f32[2,2,2]{1,0,2} slice(f32[2,2,2]{0,1,2} %arg), slice={[0:2], [0:2], [0:2]}
   2961 }
   2962 )";
   2963   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
   2964 
   2965   Literal arg = LiteralUtil::CreateR3WithLayout<float>(
   2966       {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
   2967       LayoutUtil::MakeLayout({0, 1, 2}));
   2968   TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate({&arg}));
   2969   EXPECT_TRUE(LiteralTestUtil::Equal(arg, actual));
   2970 }
   2971 
   2972 TEST_P(HloEvaluatorBf16Test, Bitcast) {
   2973   // Regression test for b/114735354.
   2974   constexpr absl::string_view hlo_text_base = R"(
   2975 HloModule Bitcast
   2976 
   2977 ENTRY main {
   2978   param = %s[32,121]{1,0} parameter(0)
   2979   ROOT bitcast = %s[121,32,1]{0,1,2} bitcast(%s[32,121]{1,0} param)
   2980 }
   2981 )";
   2982   string hlo_text;
   2983   if (use_bfloat16_) {
   2984     hlo_text = absl::StrFormat(hlo_text_base, "bf16", "bf16", "bf16");
   2985   } else {
   2986     hlo_text = absl::StrFormat(hlo_text_base, "f32", "f32", "f32");
   2987   }
   2988   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
   2989   auto args = MakeFakeArguments(m_.get()).ConsumeValueOrDie();
   2990   TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate({&args[0]}));
   2991   if (use_bfloat16_) {
   2992     EXPECT_TRUE(
   2993         absl::c_equal(args[0].data<bfloat16>(), actual.data<bfloat16>()));
   2994   } else {
   2995     EXPECT_TRUE(absl::c_equal(args[0].data<float>(), actual.data<float>()));
   2996   }
   2997 }
   2998 
   2999 // Check that s32 under/overflow doesn't trigger a ubsan failure.
   3000 TEST_F(HloEvaluatorTest, Int32Overflow) {
   3001   constexpr absl::string_view hlo_text = R"(
   3002 HloModule Test
   3003 
   3004 ENTRY main {
   3005   c1 = s32[] constant(1073741824)  // 2^30
   3006   sum = s32[] add(c1, c1)  // 2^31, i.e. INT_MIN
   3007 
   3008   c2 = s32[] constant(-2147483648)  // -2^31
   3009   sub = s32[] subtract(c2, c1)  // -2^31 - 2^30, underflows
   3010 
   3011   mul = s32[] multiply(c1, c1)
   3012   ROOT tuple = (s32[], s32[], s32[]) tuple(sum, sub, mul)
   3013 }
   3014 )";
   3015   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
   3016   TF_ASSERT_OK_AND_ASSIGN(auto literal, Evaluate({}));
   3017   std::vector<Literal> actual = literal.DecomposeTuple();
   3018   ASSERT_EQ(actual.size(), 3);
   3019 
   3020   uint32 pow30 = uint32{1} << 30;
   3021   uint32 pow31 = uint32{1} << 31;
   3022   EXPECT_EQ(actual[0].GetFirstElement<int32>(), static_cast<int32>(pow31));
   3023   EXPECT_EQ(actual[1].GetFirstElement<int32>(),
   3024             static_cast<int32>(-(pow31 + pow30)));
   3025   EXPECT_EQ(actual[2].GetFirstElement<int32>(),
   3026             static_cast<int32>(pow31 * pow31));
   3027 }
   3028 
   3029 TEST_F(HloEvaluatorTest, GetDimensionSize) {
   3030   constexpr absl::string_view hlo_text = R"(
   3031 HloModule Test
   3032 
   3033 ENTRY main {
   3034   size = u32[] parameter(0)
   3035 
   3036   data = s32[4] parameter(1)
   3037 
   3038   sum = s32[4] add(data, data)
   3039 
   3040   ROOT dynamic_size = u32[] get-dimension-size(sum), dimensions={0}
   3041 }
   3042 )";
   3043   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
   3044 
   3045   // Set up dynamic parameter binding.
   3046   TF_CHECK_OK(m_->dynamic_parameter_binding().Bind(
   3047       DynamicParameterBinding::DynamicParameter{0, {}},
   3048       DynamicParameterBinding::DynamicDimension{1, {}, 0}));
   3049 
   3050   TF_ASSERT_OK_AND_ASSIGN(DynamicDimensionInference dynamic_dimension_inference,
   3051                           DynamicDimensionInference::Run(m_.get()));
   3052 
   3053   evaluator_.set_dynamic_dimension_inference(&dynamic_dimension_inference);
   3054   Literal size_arg = LiteralUtil::CreateR0<uint32>(3);
   3055   Literal data_arg = LiteralUtil::CreateR1<int32>({1, 2, 3, 4});
   3056 
   3057   TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate({&size_arg, &data_arg}));
   3058 
   3059   EXPECT_EQ(actual.GetFirstElement<uint32>(), static_cast<uint32>(3));
   3060 }
   3061 
   3062 // Check that we get a useful error if we pass inputs of the wrong shape.
   3063 TEST_F(HloEvaluatorTest, EvaluateWithWrongInputShapes) {
   3064   constexpr absl::string_view hlo_text = R"(
   3065 HloModule Test
   3066 
   3067 ENTRY main {
   3068   p0 = s32[1] parameter(0)
   3069   ROOT sum = s32[1] add(p0, p0)
   3070 }
   3071 )";
   3072   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
   3073   Literal input_wrong_shape = LiteralUtil::CreateR1<int32>({0, 1});
   3074 
   3075   EXPECT_EQ(HloEvaluator()
   3076                 .Evaluate(*m_, {&input_wrong_shape})
   3077                 .status()
   3078                 .error_message(),
   3079             "Shape mismatch at parameter 0. Computation expected s32[1]{0}, "
   3080             "but arg was s32[2].");
   3081   EXPECT_EQ(HloEvaluator()
   3082                 .Evaluate(*m_->entry_computation(), {&input_wrong_shape})
   3083                 .status()
   3084                 .error_message(),
   3085             "Shape mismatch at parameter 0. Computation expected s32[1]{0}, "
   3086             "but arg was s32[2].");
   3087 }
   3088 
   3089 // Check that we get a useful error if we pass too many or too few inputs.
   3090 TEST_F(HloEvaluatorTest, EvaluateWithWrongNumberOfInputs) {
   3091   constexpr absl::string_view hlo_text = R"(
   3092 HloModule Test
   3093 
   3094 ENTRY main {
   3095   p0 = s32[1] parameter(0)
   3096   ROOT sum = s32[1] add(p0, p0)
   3097 }
   3098 )";
   3099   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
   3100   Literal input = LiteralUtil::CreateR1<int32>({0});
   3101 
   3102   EXPECT_EQ(
   3103       HloEvaluator().Evaluate(*m_, {&input, &input}).status().error_message(),
   3104       "Expected 1 argument, but got 2.");
   3105   EXPECT_EQ(HloEvaluator()
   3106                 .Evaluate(*m_->entry_computation(), {&input, &input})
   3107                 .status()
   3108                 .error_message(),
   3109             "Expected 1 argument, but got 2.");
   3110 }
   3111 
   3112 TEST_F(HloEvaluatorTest, PreserveFusionInputLayout) {
   3113   constexpr absl::string_view hlo_text = R"(
   3114     HloModule FusionInputLayout
   3115 
   3116     fused_computation {
   3117       param_0 = f32[20,20]{0,1} parameter(0)
   3118       ROOT bitcast = f32[20,20]{1,0} bitcast(param_0)
   3119     }
   3120 
   3121     ENTRY kernel_entry {
   3122       parameter.0 = f32[20,20]{0,1} parameter(0)
   3123       ROOT fusion = f32[20,20]{1,0} fusion(parameter.0),
   3124         kind=kLoop, calls=fused_computation
   3125     })";
   3126 
   3127   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
   3128   auto args = MakeFakeArguments(m_.get()).ConsumeValueOrDie();
   3129 
   3130   TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate({&args[0]}));
   3131   EXPECT_TRUE(absl::c_equal(args[0].data<float>(), actual.data<float>()));
   3132 }
   3133 
   3134 TEST_F(HloEvaluatorTest, PreserveFusionOutputLayout) {
   3135   constexpr absl::string_view hlo_text = R"(
   3136     HloModule FusionOutputLayout
   3137 
   3138     fused_computation {
   3139       param_0 = f32[20,20]{1,0} parameter(0)
   3140       ROOT bitcast = f32[20,20]{0,1} bitcast(param_0)
   3141     }
   3142 
   3143     ENTRY kernel_entry {
   3144       parameter.0 = f32[20,20]{1,0} parameter(0)
   3145       ROOT fusion = f32[20,20]{0,1} fusion(parameter.0),
   3146         kind=kLoop, calls=fused_computation
   3147     })";
   3148 
   3149   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
   3150   auto args = MakeFakeArguments(m_.get()).ConsumeValueOrDie();
   3151   TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate({&args[0]}));
   3152   EXPECT_TRUE(absl::c_equal(args[0].data<float>(), actual.data<float>()));
   3153 }
   3154 
   3155 TEST_F(HloEvaluatorTest, PreserveMOFusionOutputLayout) {
   3156   constexpr absl::string_view hlo_text = R"(
   3157     HloModule MOFusionOutputLayout
   3158 
   3159     fused_computation {
   3160       param_0 = f32[20,20]{1,0} parameter(0)
   3161       bitcast = f32[20,20]{0,1} bitcast(param_0)
   3162       ROOT tuple = (f32[20,20]{0,1}) tuple(bitcast)
   3163     }
   3164 
   3165     ENTRY kernel_entry {
   3166       parameter.0 = f32[20,20]{1,0} parameter(0)
   3167       ROOT fusion = (f32[20,20]{0,1}) fusion(parameter.0),
   3168         kind=kLoop, calls=fused_computation
   3169     })";
   3170 
   3171   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
   3172   auto args = MakeFakeArguments(m_.get()).ConsumeValueOrDie();
   3173   TF_ASSERT_OK_AND_ASSIGN(Literal actual_tuple, Evaluate({&args[0]}));
   3174   std::vector<Literal> actual_literals = actual_tuple.DecomposeTuple();
   3175   EXPECT_TRUE(
   3176       absl::c_equal(args[0].data<float>(), actual_literals[0].data<float>()));
   3177 }
   3178 
   3179 // Tests that custom_calls fail to evaluate when no handler is specified.
   3180 TEST_F(HloEvaluatorTest, EvaluateCustomCall_NoHandler) {
   3181   constexpr absl::string_view hlo_text = R"(
   3182     HloModule EvaluateCustomCall_NoHandler
   3183     ENTRY kernel_entry {
   3184       parameter.0 = u32[2,2]{1,0} parameter(0)
   3185       ROOT test_root = (u32[2,2]{1,0}) custom-call(parameter.0),
   3186           custom_call_target="_my_custom_call"
   3187     }
   3188   )";
   3189 
   3190   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
   3191   auto args = MakeFakeArguments(m_.get()).ConsumeValueOrDie();
   3192   EXPECT_EQ(HloEvaluator().Evaluate(*m_, {&args[0]}).status().code(),
   3193             ::tensorflow::error::UNIMPLEMENTED);
   3194 }
   3195 
   3196 // Tests when a custom_call handler returns an error.
   3197 TEST_F(HloEvaluatorTest, EvaluateCustomCall_HandlerError) {
   3198   constexpr absl::string_view hlo_text = R"(
   3199     HloModule EvaluateCustomCall_HandlerError
   3200     ENTRY kernel_entry {
   3201       parameter.0 = u32[2,2]{1,0} parameter(0)
   3202       ROOT test_root = (u32[2,2]{1,0}) custom-call(parameter.0),
   3203           custom_call_target="_my_custom_call"
   3204     }
   3205   )";
   3206 
   3207   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
   3208   auto args = MakeFakeArguments(m_.get()).ConsumeValueOrDie();
   3209   HloEvaluator evaluator;
   3210   evaluator.set_custom_call_handler(
   3211       [](HloInstruction* custom_call, absl::Span<const Literal*> operands) {
   3212         return InternalError("Test error");
   3213       });
   3214   EXPECT_EQ(evaluator.Evaluate(*m_, {&args[0]}).status().code(),
   3215             ::tensorflow::error::INTERNAL);
   3216 }
   3217 
   3218 // Tests the custom_call handler on calls with many inputs.
   3219 // We sum the operands so that we can verify the operand and output literals
   3220 // are properly mapped for access.
   3221 TEST_F(HloEvaluatorTest, EvaluateCustomCall_ManyInputs) {
   3222   constexpr absl::string_view hlo_text = R"(
   3223     HloModule EvaluateCustomCall_ManyInputs
   3224     ENTRY kernel_entry {
   3225       parameter.0 = u32[1]{0} parameter(0)
   3226       parameter.1 = u32[1]{0} parameter(1)
   3227       ROOT test_root = u32[1]{0} custom-call(parameter.0, parameter.1),
   3228           custom_call_target="_my_custom_call"
   3229     }
   3230   )";
   3231 
   3232   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
   3233   auto args = MakeFakeArguments(m_.get()).ConsumeValueOrDie();
   3234   HloEvaluator evaluator;
   3235   evaluator.set_custom_call_handler(
   3236       [](HloInstruction* custom_call, absl::Span<const Literal*> operands) {
   3237         EXPECT_EQ(HloOpcode::kCustomCall, custom_call->opcode());
   3238         EXPECT_EQ("_my_custom_call", custom_call->custom_call_target());
   3239         EXPECT_EQ(2, custom_call->operand_count());
   3240         EXPECT_EQ(2, operands.size());
   3241         auto output = Literal::CreateFromShape(custom_call->shape());
   3242         auto operand0_data = operands[0]->data<uint32>();
   3243         auto operand1_data = operands[1]->data<uint32>();
   3244         auto output_data = output.data<uint32>();
   3245         output_data[0] = operand0_data[0] + operand1_data[0];
   3246         return output;
   3247       });
   3248   TF_ASSERT_OK_AND_ASSIGN(
   3249       Literal actual_literal,
   3250       evaluator.Evaluate(*m_->entry_computation(), {&args[0], &args[1]}));
   3251   auto arg0_data = args[0].data<uint32>();
   3252   auto arg1_data = args[1].data<uint32>();
   3253   std::vector<uint32> expected_data = {arg0_data[0] + arg1_data[0]};
   3254   EXPECT_TRUE(absl::c_equal(expected_data, actual_literal.data<uint32>()));
   3255 }
   3256 
   3257 TEST_F(HloEvaluatorTest, IsFiniteF16) {
   3258   constexpr absl::string_view hlo_text = R"(
   3259   HloModule test
   3260 
   3261   ENTRY IsFiniteTest {
   3262     c = f16[6] constant({nan, 7, nan, -1, inf, -inf})
   3263     ROOT is-finite = pred[6] is-finite(c)
   3264   })";
   3265 
   3266   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
   3267   TF_ASSERT_OK_AND_ASSIGN(
   3268       Literal actual_literal,
   3269       HloEvaluator().Evaluate(*m_->entry_computation(), {}));
   3270   EXPECT_THAT(actual_literal.data<bool>(),
   3271               ::testing::ElementsAre(false, true, false, true, false, false));
   3272 }
   3273 
   3274 TEST_F(HloEvaluatorTest, IsFiniteBf16) {
   3275   constexpr absl::string_view hlo_text = R"(
   3276   HloModule test
   3277 
   3278   ENTRY IsFiniteTest {
   3279     c = bf16[6] constant({nan, 7, nan, -1, inf, -inf})
   3280     ROOT is-finite = pred[6] is-finite(c)
   3281   })";
   3282 
   3283   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
   3284   TF_ASSERT_OK_AND_ASSIGN(
   3285       Literal actual_literal,
   3286       HloEvaluator().Evaluate(*m_->entry_computation(), {}));
   3287   EXPECT_THAT(actual_literal.data<bool>(),
   3288               ::testing::ElementsAre(false, true, false, true, false, false));
   3289 }
   3290 
   3291 }  // namespace
   3292 }  // namespace xla
   3293