Home | History | Annotate | Download | only in client
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/client/xla_builder.h"
     17 
     18 #include <string>
     19 
     20 #include "tensorflow/compiler/xla/client/xla_computation.h"
     21 #include "tensorflow/compiler/xla/debug_options_flags.h"
     22 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
     23 #include "tensorflow/compiler/xla/service/hlo_module.h"
     24 #include "tensorflow/compiler/xla/shape_util.h"
     25 #include "tensorflow/compiler/xla/status_macros.h"
     26 #include "tensorflow/compiler/xla/test.h"
     27 #include "tensorflow/compiler/xla/test_helpers.h"
     28 #include "tensorflow/compiler/xla/util.h"
     29 #include "tensorflow/compiler/xla/xla_data.pb.h"
     30 
     31 namespace xla {
     32 
     33 namespace {
     34 
     35 namespace op = xla::testing::opcode_matchers;
     36 
     37 using ::testing::HasSubstr;
     38 
     39 // TODO(b/74197823): Move the tests to service/.
     40 class XlaBuilderTest : public ::testing::Test {
     41  protected:
     42   StatusOr<std::unique_ptr<HloModule>> BuildHloModule(XlaBuilder* b) {
     43     TF_ASSIGN_OR_RETURN(XlaComputation computation,
     44                         b->Build(/*remove_dynamic_dimensions=*/false));
     45     const HloModuleProto& proto = computation.proto();
     46     TF_ASSIGN_OR_RETURN(const auto& config,
     47                         HloModule::CreateModuleConfigFromProto(
     48                             proto, GetDebugOptionsFromFlags()));
     49     return HloModule::CreateFromProto(proto, config);
     50   }
     51 
     52   // Overload which explicitly specifies the root instruction.
     53   StatusOr<std::unique_ptr<HloModule>> BuildHloModule(XlaBuilder* b,
     54                                                       XlaOp root) {
     55     TF_ASSIGN_OR_RETURN(XlaComputation computation,
     56                         b->Build(root, /*remove_dynamic_dimensions=*/false));
     57     const HloModuleProto& proto = computation.proto();
     58     TF_ASSIGN_OR_RETURN(const auto& config,
     59                         HloModule::CreateModuleConfigFromProto(
     60                             proto, GetDebugOptionsFromFlags()));
     61     return HloModule::CreateFromProto(proto, config);
     62   }
     63 
     64   // Returns the name of the test currently being run.
     65   string TestName() const {
     66     return ::testing::UnitTest::GetInstance()->current_test_info()->name();
     67   }
     68 };
     69 
     70 TEST_F(XlaBuilderTest, OnePlusTwo) {
     71   XlaBuilder b(TestName());
     72   Add(ConstantR0<float>(&b, 1.0), ConstantR0<float>(&b, 2.0));
     73   TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
     74   auto root = module->entry_computation()->root_instruction();
     75   EXPECT_THAT(root, op::Add(op::Constant(), op::Constant()));
     76 }
     77 
     78 TEST_F(XlaBuilderTest, UnaryOperatorsBuildExpectedHLO) {
     79   auto test_unary_operator =
     80       [&](std::function<XlaOp(XlaOp)> op,
     81           ::testing::Matcher<const ::xla::HloInstruction*> matches_pattern) {
     82         XlaBuilder b(TestName());
     83         op(ConstantR0<int32>(&b, 1));
     84         TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
     85         auto root = module->entry_computation()->root_instruction();
     86         EXPECT_THAT(root, matches_pattern);
     87       };
     88   test_unary_operator([](XlaOp x) { return -x; }, op::Negate(op::Constant()));
     89   test_unary_operator([](XlaOp x) { return ~x; }, op::Not(op::Constant()));
     90 }
     91 
     92 TEST_F(XlaBuilderTest, BinaryOperatorsBuildExpectedHLO) {
     93   auto test_binary_operator =
     94       [&](std::function<XlaOp(XlaOp, XlaOp)> op,
     95           ::testing::Matcher<const ::xla::HloInstruction*> matches_pattern) {
     96         XlaBuilder b(TestName());
     97         op(ConstantR0<int32>(&b, 1), ConstantR0<int32>(&b, 2));
     98         TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
     99         auto root = module->entry_computation()->root_instruction();
    100         EXPECT_THAT(root, matches_pattern);
    101       };
    102 
    103   test_binary_operator([](XlaOp x, XlaOp y) { return x + y; },
    104                        op::Add(op::Constant(), op::Constant()));
    105   test_binary_operator([](XlaOp x, XlaOp y) { return x - y; },
    106                        op::Subtract(op::Constant(), op::Constant()));
    107   test_binary_operator([](XlaOp x, XlaOp y) { return x * y; },
    108                        op::Multiply(op::Constant(), op::Constant()));
    109   test_binary_operator([](XlaOp x, XlaOp y) { return x / y; },
    110                        op::Divide(op::Constant(), op::Constant()));
    111 
    112   test_binary_operator([](XlaOp x, XlaOp y) { return x & y; },
    113                        op::And(op::Constant(), op::Constant()));
    114   test_binary_operator([](XlaOp x, XlaOp y) { return x | y; },
    115                        op::Or(op::Constant(), op::Constant()));
    116   test_binary_operator([](XlaOp x, XlaOp y) { return x ^ y; },
    117                        op::Xor(op::Constant(), op::Constant()));
    118   test_binary_operator([](XlaOp x, XlaOp y) { return x << y; },
    119                        op::ShiftLeft(op::Constant(), op::Constant()));
    120   test_binary_operator(
    121       [](XlaOp x, XlaOp y) { return x >> y; },
    122       op::ShiftRightArithmetic(op::Constant(), op::Constant()));
    123 
    124   auto test_unsigned_binary_operator =
    125       [&](std::function<XlaOp(XlaOp, XlaOp)> op,
    126           ::testing::Matcher<const ::xla::HloInstruction*> matches_pattern) {
    127         XlaBuilder b(TestName());
    128         op(ConstantR0<uint32>(&b, 1), ConstantR0<uint32>(&b, 2));
    129         TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
    130         auto root = module->entry_computation()->root_instruction();
    131         EXPECT_THAT(root, matches_pattern);
    132       };
    133   test_unsigned_binary_operator(
    134       [](XlaOp x, XlaOp y) { return x >> y; },
    135       op::ShiftRightLogical(op::Constant(), op::Constant()));
    136 }
    137 
    138 TEST_F(XlaBuilderTest, VariadicAnd) {
    139   XlaBuilder b(TestName());
    140   Shape s = ShapeUtil::MakeShape(PRED, {});
    141   And(Parameter(&b, 0, s, "p0"), Parameter(&b, 1, s, "p1"),
    142       Parameter(&b, 2, s, "p2"));
    143   TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
    144   // Don't specify in the test whether And(x, y, z) is right- or
    145   // left-associative; accept either one.
    146   EXPECT_THAT(
    147       module->entry_computation()->root_instruction(),
    148       ::testing::AnyOf(op::And(op::Parameter(0),
    149                                op::And(op::Parameter(1), op::Parameter(2))),
    150                        op::And(op::And(op::Parameter(0), op::Parameter(1)),
    151                                op::Parameter(2))));
    152 }
    153 
    154 TEST_F(XlaBuilderTest, VariadicOr) {
    155   XlaBuilder b(TestName());
    156   Shape s = ShapeUtil::MakeShape(PRED, {});
    157   Or(Parameter(&b, 0, s, "p0"), Parameter(&b, 1, s, "p1"),
    158      Parameter(&b, 2, s, "p2"));
    159   TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
    160   // Don't specify in the test whether Or(x, y, z) is right- or
    161   // left-associative; accept either one.
    162   EXPECT_THAT(
    163       module->entry_computation()->root_instruction(),
    164       ::testing::AnyOf(
    165           op::Or(op::Parameter(0), op::Or(op::Parameter(1), op::Parameter(2))),
    166           op::Or(op::Or(op::Parameter(0), op::Parameter(1)),
    167                  op::Parameter(2))));
    168 }
    169 
    170 TEST_F(XlaBuilderTest, ShiftRightOperatorOnNonIntegerProducesError) {
    171   XlaBuilder b(TestName());
    172   ConstantR0<float>(&b, 1) >> ConstantR0<float>(&b, 2);
    173   auto statusor = b.Build();
    174   ASSERT_FALSE(statusor.ok());
    175   EXPECT_THAT(
    176       statusor.status().error_message(),
    177       HasSubstr("Argument to >> operator does not have an integral type"));
    178 }
    179 
    180 TEST_F(XlaBuilderTest, ParamPlusConstantHasScalarBroadcast) {
    181   XlaBuilder b(TestName());
    182   auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {3, 5}), "x");
    183   Add(x, ConstantR0<float>(&b, 1.0));
    184   TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
    185   auto root = module->entry_computation()->root_instruction();
    186   EXPECT_THAT(root, op::Add(op::Parameter(), op::Broadcast(op::Constant())));
    187 }
    188 
    189 TEST_F(XlaBuilderTest, ParamPlusParamHasBroadcast) {
    190   XlaBuilder b(TestName());
    191   const auto& x_shape = ShapeUtil::MakeShape(S32, {2, 4, 6});
    192   const auto& y_shape = ShapeUtil::MakeShape(S32, {2, 4});
    193   auto x = Parameter(&b, 0, x_shape, "x");
    194   auto y = Parameter(&b, 1, y_shape, "y");
    195   auto add = Add(x, y, /*broadcast_dimensions=*/{0, 1});
    196 
    197   TF_ASSERT_OK_AND_ASSIGN(auto add_shape, b.GetShape(add));
    198   EXPECT_TRUE(ShapeUtil::Equal(add_shape, x_shape));
    199 
    200   TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
    201   auto root = module->entry_computation()->root_instruction();
    202   EXPECT_THAT(root, op::Add(op::Parameter(0), op::Broadcast(op::Parameter(1))));
    203 }
    204 
    205 TEST_F(XlaBuilderTest, XPlusX) {
    206   XlaBuilder b(TestName());
    207   auto x = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {1, 3, 5, 7}), "x");
    208   Add(x, x);
    209   TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
    210   auto root = module->entry_computation()->root_instruction();
    211   EXPECT_THAT(root, op::Add(op::Parameter(0), op::Parameter(0)));
    212 }
    213 
    214 TEST_F(XlaBuilderTest, ShapeInferenceError) {
    215   XlaBuilder b(TestName());
    216   auto x = Parameter(&b, 0, ShapeUtil::MakeShape(U32, {2, 4, 6}), "x");
    217   auto y = Parameter(&b, 1, ShapeUtil::MakeShape(U32, {2, 4}), "y");
    218   Add(x, y);
    219   auto statusor = BuildHloModule(&b);
    220   ASSERT_FALSE(statusor.ok());
    221   EXPECT_THAT(statusor.status().error_message(), HasSubstr("shape inference"));
    222 }
    223 
    224 TEST_F(XlaBuilderTest, ParameterAlreadyRegistered) {
    225   XlaBuilder b_call("add");
    226   Parameter(&b_call, 0, ShapeUtil::MakeShape(PRED, {}), "x");
    227 
    228   XlaBuilder b(TestName());
    229   auto x = Parameter(&b, 0, ShapeUtil::MakeShape(PRED, {}), "x");
    230   auto y = Parameter(&b, 0, ShapeUtil::MakeShape(PRED, {}), "y");
    231   Add(x, y);
    232   auto statusor = BuildHloModule(&b);
    233   ASSERT_FALSE(statusor.ok());
    234   EXPECT_THAT(statusor.status().error_message(),
    235               HasSubstr("parameter 0 already registered"));
    236 }
    237 
    238 TEST_F(XlaBuilderTest, Call) {
    239   XlaBuilder b_call("the_only_to_apply");
    240   auto p0 = Parameter(&b_call, 0, ShapeUtil::MakeShape(F32, {}), "p0");
    241   auto p1 = Parameter(&b_call, 1, ShapeUtil::MakeShape(F32, {}), "p1");
    242   Add(p0, p1);
    243   TF_ASSERT_OK_AND_ASSIGN(auto call, b_call.Build());
    244   XlaBuilder b(TestName());
    245   auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "x");
    246   auto y = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {}), "y");
    247   auto one = ConstantR0<float>(&b, 1);
    248   auto two = ConstantR0<float>(&b, 2);
    249   Add(Call(&b, call, {x, y}), Call(&b, call, {one, two}));
    250   TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
    251   auto root = module->entry_computation()->root_instruction();
    252   EXPECT_THAT(root, op::Add(op::Call(op::Parameter(), op::Parameter()),
    253                             op::Call(op::Constant(), op::Constant())));
    254 }
    255 
    256 TEST_F(XlaBuilderTest, BinopHasDegenerateBroadcast) {
    257   XlaBuilder b(TestName());
    258   auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {1, 2, 3}), "x");
    259   auto y = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {1, 2, 1}), "y");
    260   Add(x, y);
    261   TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
    262 
    263   // Expected:
    264   //
    265   //  x: f32[1,2,3]  y: f32[1,2,1]
    266   //      |               |
    267   //      |          reshape: f32[1,2]
    268   //      |               |
    269   //      |          broadcast: f32[1,2,3]
    270   //       \             /
    271   //            add
    272   auto root = module->entry_computation()->root_instruction();
    273   EXPECT_THAT(root, op::Add(op::Parameter(0),
    274                             op::Broadcast(op::Reshape(op::Parameter(1)))));
    275 }
    276 
    277 TEST_F(XlaBuilderTest, BinopHasInDimAndDegenerateBroadcast) {
    278   XlaBuilder b(TestName());
    279   auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 3}), "x");
    280   auto y = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {2, 1, 4}), "y");
    281   Add(x, y, /*broadcast_dimensions=*/{0, 1});
    282   TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
    283 
    284   // The binary operation has in-dim broadcast and degenerate broadcast, should
    285   // first do the in-dim broadcast then convert the degnerate broadcast into a
    286   // reshape and a broadcast.
    287   //
    288   // Expected:
    289   //
    290   //  x: f32[2,3]            y: f32[2,1,4]
    291   //      |                        |
    292   //  broadcast: f32[2,3,4]  reshape: f32[2,4]
    293   //      |                        |
    294   //      |                  broadcast: f32[2,3,4]
    295   //       \                      /
    296   //                 add
    297   auto root = module->entry_computation()->root_instruction();
    298   EXPECT_THAT(root, op::Add(op::Broadcast(op::Parameter(0)),
    299                             op::Broadcast(op::Reshape(op::Parameter(1)))));
    300 }
    301 
    302 TEST_F(XlaBuilderTest, BroadcastInDim) {
    303   XlaBuilder b(TestName());
    304   auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 3}), "x");
    305   BroadcastInDim(x, {2, 4, 3},
    306                  /*broadcast_dimensions=*/{0, 2});
    307   TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
    308   auto root = module->entry_computation()->root_instruction();
    309   EXPECT_THAT(root, op::Broadcast());
    310 }
    311 
    312 TEST_F(XlaBuilderTest, BroadcastInDimWithDegeneratedDim) {
    313   XlaBuilder b(TestName());
    314   auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 1, 4}), "x");
    315   BroadcastInDim(x, {2, 3, 4},
    316                  /*broadcast_dimensions=*/{0, 1, 2});
    317   TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
    318   EXPECT_THAT(module->entry_computation()->root_instruction(),
    319               op::Broadcast(op::Reshape(op::Broadcast())));
    320 }
    321 
    322 TEST_F(XlaBuilderTest, OperandFromWrongBuilder) {
    323   XlaBuilder b1("b1");
    324   auto p0 = Parameter(&b1, 0, ShapeUtil::MakeShape(F32, {}), "p0");
    325   XlaBuilder builder("main");
    326   auto p = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "p");
    327   Add(p, p0);
    328   auto statusor = builder.Build();
    329   ASSERT_FALSE(statusor.ok());
    330   EXPECT_THAT(
    331       statusor.status().error_message(),
    332       HasSubstr(
    333           "built by builder 'b1', but is trying to use it in builder 'main'"));
    334 }
    335 
    336 TEST_F(XlaBuilderTest, ReshapeDefaultOrder) {
    337   XlaBuilder b(TestName());
    338   auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 3, 5, 7}), "x");
    339   Reshape(x, /*new_sizes=*/{6, 35});
    340   TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
    341   auto root = module->entry_computation()->root_instruction();
    342   EXPECT_THAT(root, op::Reshape(op::Parameter()));
    343 }
    344 
    345 TEST_F(XlaBuilderTest, ReshapeHasTranspose) {
    346   XlaBuilder b(TestName());
    347   auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 3, 5, 7}), "x");
    348   Reshape(x, /*dimensions=*/{3, 2, 1, 0}, /*new_sizes=*/{6, 35});
    349   TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
    350   auto root = module->entry_computation()->root_instruction();
    351   EXPECT_THAT(root, op::Reshape(op::Transpose(op::Parameter())));
    352 }
    353 
    354 TEST_F(XlaBuilderTest, Transpose) {
    355   XlaBuilder b(TestName());
    356   auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x");
    357   Transpose(x, /*permutation=*/{1, 0});
    358   TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
    359   auto root = module->entry_computation()->root_instruction();
    360   EXPECT_THAT(root, op::Transpose(op::Parameter()));
    361 }
    362 
    363 TEST_F(XlaBuilderTest, AllToAll) {
    364   XlaBuilder b(TestName());
    365   auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x");
    366   AllToAll(x, /*split_dimension=*/1, /*concat_dimension=*/0,
    367            /*split_count=*/2);
    368   TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
    369   auto root = module->entry_computation()->root_instruction();
    370 
    371   // AllToAll is decomposed into slices -> all-to-all -> gte -> concat.
    372   EXPECT_EQ(root->opcode(), HloOpcode::kConcatenate);
    373   EXPECT_EQ(root->operand(0)->operand(0)->opcode(), HloOpcode::kAllToAll);
    374   EXPECT_TRUE(
    375       ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {8, 8})));
    376 }
    377 
    378 TEST_F(XlaBuilderTest, CollectivePermute) {
    379   XlaBuilder b(TestName());
    380   auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x");
    381   CollectivePermute(x, {{0, 1}, {1, 2}, {2, 3}});
    382   TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
    383   auto root = module->entry_computation()->root_instruction();
    384   EXPECT_EQ(root->opcode(), HloOpcode::kCollectivePermute);
    385 }
    386 
    387 TEST_F(XlaBuilderTest, GetDimensionSize) {
    388   XlaBuilder b(TestName());
    389   auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x");
    390   GetDimensionSize(x, 1);
    391   TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
    392   auto root = module->entry_computation()->root_instruction();
    393   EXPECT_EQ(root->opcode(), HloOpcode::kGetDimensionSize);
    394 }
    395 
    396 TEST_F(XlaBuilderTest, ReportError) {
    397   XlaBuilder b(TestName());
    398   auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x");
    399   Add(b.ReportError(InvalidArgument("a test error")), x);
    400   auto statusor = b.Build();
    401   ASSERT_FALSE(statusor.ok());
    402   EXPECT_THAT(statusor.status().error_message(), HasSubstr("a test error"));
    403 }
    404 
    405 TEST_F(XlaBuilderTest, ReportErrorOrReturnHandlesNonErrors) {
    406   XlaBuilder b(TestName());
    407   StatusOr<XlaOp> op(ConstantR0<float>(&b, 1.0));
    408   Add(b.ReportErrorOrReturn(op), ConstantR0<float>(&b, 2.0));
    409   TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
    410   auto root = module->entry_computation()->root_instruction();
    411   EXPECT_THAT(root, op::Add(op::Constant(), op::Constant()));
    412 }
    413 
    414 TEST_F(XlaBuilderTest, ReportErrorOrReturnHandlesErrors) {
    415   XlaBuilder b(TestName());
    416   StatusOr<XlaOp> op(InvalidArgument("a test error"));
    417   Add(b.ReportErrorOrReturn(op), ConstantR0<float>(&b, 2.0));
    418   auto statusor = b.Build();
    419   ASSERT_FALSE(statusor.ok());
    420   EXPECT_THAT(statusor.status().error_message(), HasSubstr("a test error"));
    421 }
    422 
    423 TEST_F(XlaBuilderTest, BuildWithSpecificRoot) {
    424   XlaBuilder b(TestName());
    425   XlaOp constant = ConstantR0<float>(&b, 1.0);
    426   Add(constant, ConstantR0<float>(&b, 2.0));
    427   TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b, /*root=*/constant));
    428   auto root = module->entry_computation()->root_instruction();
    429   EXPECT_THAT(root, op::Constant());
    430 }
    431 
    432 TEST_F(XlaBuilderTest, BuildWithSpecificRootAndMultipleParameters) {
    433   // Specifying a particular root in Build should still include all entry
    434   // parameters.
    435   XlaBuilder b(TestName());
    436   const Shape shape = ShapeUtil::MakeShape(F32, {42, 123});
    437   XlaOp x = Parameter(&b, 0, shape, "x");
    438   XlaOp y = Parameter(&b, 1, shape, "y");
    439   XlaOp z = Parameter(&b, 2, shape, "z");
    440   Add(x, Sub(y, z));
    441   TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b, /*root=*/x));
    442   auto root = module->entry_computation()->root_instruction();
    443   EXPECT_THAT(root, op::Parameter());
    444   EXPECT_EQ(module->entry_computation()->num_parameters(), 3);
    445   EXPECT_EQ(module->entry_computation()->instruction_count(), 5);
    446 }
    447 
    448 TEST_F(XlaBuilderTest, BuildWithSpecificRootWithWrongBuilder) {
    449   XlaBuilder b(TestName());
    450   XlaBuilder other_b(TestName());
    451   const Shape shape = ShapeUtil::MakeShape(F32, {42, 123});
    452 
    453   Parameter(&b, 0, shape, "param");
    454   XlaOp other_param = Parameter(&other_b, 0, shape, "other_param");
    455 
    456   Status status = b.Build(other_param).status();
    457   ASSERT_IS_NOT_OK(status);
    458   EXPECT_THAT(
    459       status.error_message(),
    460       ::testing::HasSubstr("root operation is not in this computation"));
    461 }
    462 
    463 TEST_F(XlaBuilderTest, ProtoMatches) {
    464   std::vector<XlaComputation> computations;
    465   for (int i = 0; i < 2; ++i) {
    466     XlaBuilder b_call("the_only_to_apply");
    467     auto p0 = Parameter(&b_call, 0, ShapeUtil::MakeShape(F32, {}), "p0");
    468     auto p1 = Parameter(&b_call, 1, ShapeUtil::MakeShape(F32, {}), "p1");
    469     Add(p0, Add(p1, p0));
    470     TF_ASSERT_OK_AND_ASSIGN(auto call, b_call.Build());
    471     XlaBuilder b(TestName());
    472     auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "x");
    473     auto y = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {}), "y");
    474     auto one = ConstantR0<float>(&b, 1);
    475     auto two = ConstantR0<float>(&b, 2);
    476     Add(Call(&b, call, {x, y}), Call(&b, call, {one, two}));
    477     computations.push_back(b.Build().ValueOrDie());
    478   }
    479   auto c0_string = computations[0].proto().SerializeAsString();
    480   auto c1_string = computations[1].proto().SerializeAsString();
    481   EXPECT_EQ(c0_string, c1_string);
    482 }
    483 
    484 TEST_F(XlaBuilderTest, DynamicParameter) {
    485   XlaBuilder b(TestName());
    486   Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
    487       {ShapeUtil::MakeShape(F32, {5}), ShapeUtil::MakeShape(F32, {6}, {true})});
    488   auto p0 = Parameter(&b, 0, tuple_param_shape, "p0");
    489   Parameter(&b, 1, ShapeUtil::MakeShape(U32, {}), "p1");
    490   ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/1,
    491                                    /*dynamic_size_param_index=*/{},
    492                                    /*target_param_num=*/0,
    493                                    /*target_param_index=*/{1},
    494                                    /*target_dim_num=*/0));
    495   TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b, /*root=*/p0));
    496   const Shape& param_shape = module->entry_computation()
    497                                  ->parameter_instruction(0)
    498                                  ->shape()
    499                                  .tuple_shapes(1);
    500   EXPECT_TRUE(param_shape.is_dynamic_dimension(0));
    501 }
    502 
    503 TEST_F(XlaBuilderTest, DynamicUnary) {
    504   XlaBuilder b(TestName());
    505   Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
    506       {ShapeUtil::MakeShape(F32, {5}, {true}), ShapeUtil::MakeShape(U32, {})});
    507   auto p0 = Parameter(&b, 0, tuple_param_shape, "p0");
    508   ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
    509                                    /*dynamic_size_param_index=*/{1},
    510                                    /*target_param_num=*/0,
    511                                    /*target_param_index=*/{0},
    512                                    /*target_dim_num=*/0));
    513   auto gte = GetTupleElement(p0, 0);
    514   Neg(gte);
    515   TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
    516   const Shape& result_shape =
    517       module->entry_computation()->root_instruction()->shape();
    518   EXPECT_TRUE(result_shape.is_dynamic_dimension(0));
    519 }
    520 
    521 TEST_F(XlaBuilderTest, DynamicBinary) {
    522   XlaBuilder b(TestName());
    523   Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
    524       {ShapeUtil::MakeShape(F32, {5}, {true}),
    525        ShapeUtil::MakeShape(F32, {5}, {true}), ShapeUtil::MakeShape(U32, {})});
    526   auto p0 = Parameter(&b, 0, tuple_param_shape, "p0");
    527   ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
    528                                    /*dynamic_size_param_index=*/{2},
    529                                    /*target_param_num=*/0,
    530                                    /*target_param_index=*/{0},
    531                                    /*target_dim_num=*/0));
    532   ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
    533                                    /*dynamic_size_param_index=*/{2},
    534                                    /*target_param_num=*/0,
    535                                    /*target_param_index=*/{1},
    536                                    /*target_dim_num=*/0));
    537   auto gte0 = GetTupleElement(p0, 0);
    538   auto gte1 = GetTupleElement(p0, 1);
    539   Add(gte0, gte1);
    540   TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
    541   const Shape& result_shape =
    542       module->entry_computation()->root_instruction()->shape();
    543   EXPECT_TRUE(result_shape.is_dynamic_dimension(0));
    544 }
    545 
    546 TEST_F(XlaBuilderTest, DynamicBinaryHasBroadcast) {
    547   XlaBuilder b(TestName());
    548   Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
    549       {ShapeUtil::MakeShape(F32, {5, 4}, {true, false}),
    550        ShapeUtil::MakeShape(F32, {5}, {true}), ShapeUtil::MakeShape(U32, {})});
    551   auto p0 = Parameter(&b, 0, tuple_param_shape, "p0");
    552   ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
    553                                    /*dynamic_size_param_index=*/{2},
    554                                    /*target_param_num=*/0,
    555                                    /*target_param_index=*/{0},
    556                                    /*target_dim_num=*/0));
    557   ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
    558                                    /*dynamic_size_param_index=*/{2},
    559                                    /*target_param_num=*/0,
    560                                    /*target_param_index=*/{1},
    561                                    /*target_dim_num=*/0));
    562   auto gte0 = GetTupleElement(p0, 0);
    563   auto gte1 = GetTupleElement(p0, 1);
    564   Add(gte0, gte1, {0});
    565   TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
    566   const Shape& result_shape =
    567       module->entry_computation()->root_instruction()->shape();
    568   EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), {true, false}))
    569       << result_shape;
    570 }
    571 
    572 TEST_F(XlaBuilderTest, DynamicBroadcast) {
    573   XlaBuilder b(TestName());
    574   Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
    575       {ShapeUtil::MakeShape(F32, {5, 4}, {true, false}),
    576        ShapeUtil::MakeShape(U32, {})});
    577   auto p0 = Parameter(&b, 0, tuple_param_shape, "p0");
    578   ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
    579                                    /*dynamic_size_param_index=*/{1},
    580                                    /*target_param_num=*/0,
    581                                    /*target_param_index=*/{0},
    582                                    /*target_dim_num=*/0));
    583   auto gte = GetTupleElement(p0, 0);
    584   BroadcastInDim(gte, /*out_dim_size=*/{3, 5, 4},
    585                  /*broadcast_dimensions=*/{1, 2});
    586   TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
    587   const Shape& result_shape =
    588       module->entry_computation()->root_instruction()->shape();
    589   EXPECT_TRUE(
    590       ContainersEqual(result_shape.dynamic_dimensions(), {false, true, false}))
    591       << result_shape;
    592 }
    593 
    594 TEST_F(XlaBuilderTest, DynamicBinaryHasDegenerateBroadcast) {
    595   XlaBuilder b(TestName());
    596   Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
    597       {ShapeUtil::MakeShape(F32, {10}, {true}),
    598        ShapeUtil::MakeShape(F32, {1, 15}), ShapeUtil::MakeShape(U32, {})});
    599   auto p0 = Parameter(&b, 0, tuple_param_shape, "p0");
    600   ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
    601                                    /*dynamic_size_param_index=*/{1},
    602                                    /*target_param_num=*/0,
    603                                    /*target_param_index=*/{0},
    604                                    /*target_dim_num=*/0));
    605   auto gte0 = GetTupleElement(p0, 0);
    606   auto gte1 = GetTupleElement(p0, 1);
    607   Add(gte0, gte1, /*broadcast_dimensions=*/{0});  // f32[<=10, 15]
    608   TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
    609   const Shape& result_shape =
    610       module->entry_computation()->root_instruction()->shape();
    611   EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), {true, false}))
    612       << result_shape;
    613 }
    614 
    615 TEST_F(XlaBuilderTest, DynamicSelectOnlyPredDynamic) {
    616   XlaBuilder b(TestName());
    617   Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
    618       {ShapeUtil::MakeShape(PRED, {10}, {true}),
    619        ShapeUtil::MakeShape(F32, {10}), ShapeUtil::MakeShape(U32, {})});
    620   auto p0 = Parameter(&b, 0, tuple_param_shape, "p0");
    621   ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
    622                                    /*dynamic_size_param_index=*/{1},
    623                                    /*target_param_num=*/0,
    624                                    /*target_param_index=*/{0},
    625                                    /*target_dim_num=*/0));
    626   auto gte0 = GetTupleElement(p0, 0);
    627   auto gte1 = GetTupleElement(p0, 1);
    628 
    629   Select(gte0, gte1, gte1);
    630 
    631   TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
    632   const Shape& result_shape =
    633       module->entry_computation()->root_instruction()->shape();
    634   EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), {true}))
    635       << result_shape;
    636 }
    637 
    638 TEST_F(XlaBuilderTest, DynamicPad) {
    639   XlaBuilder b(TestName());
    640   Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
    641       {ShapeUtil::MakeShape(F32, {5, 4}, {true, false}),
    642        ShapeUtil::MakeShape(U32, {})});
    643   auto p0 = Parameter(&b, 0, tuple_param_shape, "p0");
    644   auto pad_val = ConstantR0<float>(&b, -1);
    645   ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
    646                                    /*dynamic_size_param_index=*/{1},
    647                                    /*target_param_num=*/0,
    648                                    /*target_param_index=*/{0},
    649                                    /*target_dim_num=*/0));
    650   auto gte = GetTupleElement(p0, 0);
    651   PaddingConfig padding_config;
    652   for (int i = 0; i < 2; i++) {
    653     auto dimension = padding_config.add_dimensions();
    654     dimension->set_edge_padding_low(0);
    655     dimension->set_edge_padding_high(0);
    656     dimension->set_interior_padding(0);
    657   }
    658   Pad(gte, pad_val, padding_config);
    659   TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
    660   const Shape& result_shape =
    661       module->entry_computation()->root_instruction()->shape();
    662   EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), {true, false}))
    663       << result_shape;
    664 }
    665 
    666 TEST_F(XlaBuilderTest, DynamicConvolution) {
    667   XlaBuilder b(TestName());
    668   Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
    669       {ShapeUtil::MakeShape(F32, {1, 2, 2, 128}, {true, false, false, false}),
    670        ShapeUtil::MakeShape(F32, {2, 2, 128, 8}, {false, false, true, false}),
    671        ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeShape(U32, {})});
    672   auto p0 = Parameter(&b, 0, tuple_param_shape, "p0");
    673   ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
    674                                    /*dynamic_size_param_index=*/{2},
    675                                    /*target_param_num=*/0,
    676                                    /*target_param_index=*/{0},
    677                                    /*target_dim_num=*/0));
    678   ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
    679                                    /*dynamic_size_param_index=*/{3},
    680                                    /*target_param_num=*/0,
    681                                    /*target_param_index=*/{1},
    682                                    /*target_dim_num=*/2));
    683   auto input = GetTupleElement(p0, 0);
    684   auto filter = GetTupleElement(p0, 1);
    685   ConvolutionDimensionNumbers dnums;
    686   dnums.set_input_batch_dimension(0);
    687   dnums.set_output_batch_dimension(0);
    688   dnums.add_input_spatial_dimensions(1);
    689   dnums.add_output_spatial_dimensions(1);
    690   dnums.add_input_spatial_dimensions(2);
    691   dnums.add_output_spatial_dimensions(2);
    692   dnums.set_input_feature_dimension(3);
    693   dnums.set_output_feature_dimension(3);
    694   dnums.add_kernel_spatial_dimensions(0);
    695   dnums.add_kernel_spatial_dimensions(1);
    696   dnums.set_kernel_input_feature_dimension(2);
    697   dnums.set_kernel_output_feature_dimension(3);
    698   ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
    699                             /*feature_group_count=*/1);
    700   TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
    701   const Shape& result_shape =
    702       module->entry_computation()->root_instruction()->shape();
    703   EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(),
    704                               {true, false, false, false}))
    705       << result_shape;
    706 }
    707 
    708 TEST_F(XlaBuilderTest, DynamicDot) {
    709   XlaBuilder b(TestName());
    710   Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
    711       {ShapeUtil::MakeShape(F32, {2, 3, 4}, {true, true, false}),
    712        ShapeUtil::MakeShape(F32, {2, 4, 5}, {true, false, false}),
    713        ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeShape(U32, {})});
    714   auto p0 = Parameter(&b, 0, tuple_param_shape, "p0");
    715   ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
    716                                    /*dynamic_size_param_index=*/{2},
    717                                    /*target_param_num=*/0,
    718                                    /*target_param_index=*/{0},
    719                                    /*target_dim_num=*/0));
    720   ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
    721                                    /*dynamic_size_param_index=*/{2},
    722                                    /*target_param_num=*/0,
    723                                    /*target_param_index=*/{1},
    724                                    /*target_dim_num=*/0));
    725   ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
    726                                    /*dynamic_size_param_index=*/{3},
    727                                    /*target_param_num=*/0,
    728                                    /*target_param_index=*/{0},
    729                                    /*target_dim_num=*/1));
    730 
    731   auto lhs = GetTupleElement(p0, 0);
    732   auto rhs = GetTupleElement(p0, 1);
    733   DotDimensionNumbers dnums;
    734   dnums.add_lhs_contracting_dimensions(2);
    735   dnums.add_rhs_contracting_dimensions(1);
    736   dnums.add_lhs_batch_dimensions(0);
    737   dnums.add_rhs_batch_dimensions(0);
    738   DotGeneral(lhs, rhs, dnums);
    739   TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
    740   const Shape& result_shape =
    741       module->entry_computation()->root_instruction()->shape();
    742   EXPECT_TRUE(
    743       ContainersEqual(result_shape.dynamic_dimensions(), {true, true, false}))
    744       << result_shape;
    745 }
    746 
    747 TEST_F(XlaBuilderTest, DynamicReduce) {
    748   XlaBuilder b(TestName());
    749   Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
    750       {ShapeUtil::MakeShape(F32, {5, 4, 3}, {false, true, false}),
    751        ShapeUtil::MakeShape(U32, {})});
    752   auto p0 = Parameter(&b, 0, tuple_param_shape, "p0");
    753   auto init = ConstantR0<float>(&b, 0);
    754   ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
    755                                    /*dynamic_size_param_index=*/{1},
    756                                    /*target_param_num=*/0,
    757                                    /*target_param_index=*/{0},
    758                                    /*target_dim_num=*/1));
    759   auto gte = GetTupleElement(p0, 0);
    760   XlaBuilder bsum(TestName());
    761   Add(Parameter(&bsum, 0, ShapeUtil::MakeShape(F32, {}), "x"),
    762       Parameter(&bsum, 1, ShapeUtil::MakeShape(F32, {}), "y"));
    763   TF_ASSERT_OK_AND_ASSIGN(auto sum, bsum.Build());
    764   Reduce(gte, init, sum, {0});
    765   TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
    766   const Shape& result_shape =
    767       module->entry_computation()->root_instruction()->shape();
    768   EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), {true, false}))
    769       << result_shape;
    770 }
    771 
    772 TEST_F(XlaBuilderTest, DynamicReduceWindow) {
    773   XlaBuilder b(TestName());
    774   Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
    775       {ShapeUtil::MakeShape(F32, {2, 4, 8}, {true, false, false}),
    776        ShapeUtil::MakeShape(U32, {})});
    777   auto p0 = Parameter(&b, 0, tuple_param_shape, "p0");
    778   auto init = ConstantR0<float>(&b, 0.f);
    779   ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
    780                                    /*dynamic_size_param_index=*/{1},
    781                                    /*target_param_num=*/0,
    782                                    /*target_param_index=*/{0},
    783                                    /*target_dim_num=*/0));
    784   auto gte = GetTupleElement(p0, 0);
    785   XlaBuilder bsum(TestName());
    786   Add(Parameter(&bsum, 0, ShapeUtil::MakeShape(F32, {}), "x"),
    787       Parameter(&bsum, 1, ShapeUtil::MakeShape(F32, {}), "y"));
    788   TF_ASSERT_OK_AND_ASSIGN(auto sum, bsum.Build());
    789   ReduceWindow(gte, init, sum, /*window_dimensions=*/{1, 2, 4},
    790                /*window_strides=*/{1, 1, 1}, Padding::kValid);
    791   TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
    792   const Shape& result_shape =
    793       module->entry_computation()->root_instruction()->shape();
    794   EXPECT_TRUE(
    795       ContainersEqual(result_shape.dynamic_dimensions(), {true, false, false}))
    796       << result_shape;
    797 }
    798 
    799 TEST_F(XlaBuilderTest, DynamicSelectAndScatter) {
    800   XlaBuilder b(TestName());
    801   Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
    802       {ShapeUtil::MakeShape(F32, {2, 4, 8}, {true, false, false}),
    803        ShapeUtil::MakeShape(F32, {2, 2, 2}, {true, false, false}),
    804        ShapeUtil::MakeShape(U32, {})});
    805   auto p0 = Parameter(&b, 0, tuple_param_shape, "p0");
    806   auto init = ConstantR0<float>(&b, 0.f);
    807   XlaBuilder bsum(TestName());
    808   Add(Parameter(&bsum, 0, ShapeUtil::MakeShape(F32, {}), "x"),
    809       Parameter(&bsum, 1, ShapeUtil::MakeShape(F32, {}), "y"));
    810   TF_ASSERT_OK_AND_ASSIGN(auto sum, bsum.Build());
    811   XlaBuilder bge(TestName());
    812   Ge(Parameter(&bge, 0, ShapeUtil::MakeShape(F32, {}), "x"),
    813      Parameter(&bge, 1, ShapeUtil::MakeShape(F32, {}), "y"));
    814   TF_ASSERT_OK_AND_ASSIGN(auto ge, bge.Build());
    815 
    816   ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
    817                                    /*dynamic_size_param_index=*/{2},
    818                                    /*target_param_num=*/0,
    819                                    /*target_param_index=*/{0},
    820                                    /*target_dim_num=*/0));
    821   ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
    822                                    /*dynamic_size_param_index=*/{2},
    823                                    /*target_param_num=*/0,
    824                                    /*target_param_index=*/{1},
    825                                    /*target_dim_num=*/0));
    826   auto gte0 = GetTupleElement(p0, 0);
    827   auto source = GetTupleElement(p0, 1);
    828   SelectAndScatter(gte0, ge, {1, 2, 4}, {1, 2, 4}, Padding::kValid, source,
    829                    init, sum);
    830   TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
    831   const Shape& result_shape =
    832       module->entry_computation()->root_instruction()->shape();
    833   EXPECT_TRUE(
    834       ContainersEqual(result_shape.dynamic_dimensions(), {true, false, false}))
    835       << result_shape;
    836 }
    837 
    838 TEST_F(XlaBuilderTest, DynamicReshape) {
    839   XlaBuilder b(TestName());
    840   Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
    841       {ShapeUtil::MakeShape(F32, {2, 3, 4, 5, 6},
    842                             {false, false, true, true, false}),
    843        ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeShape(U32, {})});
    844   auto p0 = Parameter(&b, 0, tuple_param_shape, "p0");
    845   ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
    846                                    /*dynamic_size_param_index=*/{1},
    847                                    /*target_param_num=*/0,
    848                                    /*target_param_index=*/{0},
    849                                    /*target_dim_num=*/2));
    850   ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
    851                                    /*dynamic_size_param_index=*/{2},
    852                                    /*target_param_num=*/0,
    853                                    /*target_param_index=*/{0},
    854                                    /*target_dim_num=*/3));
    855   auto gte = GetTupleElement(p0, 0);  // f32[2, 3, <=4, <=5, 6]
    856   Reshape(gte, /*new_sizes=*/{6, 4, 1, 5, 2, 3});
    857   TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
    858   const Shape& result_shape =
    859       module->entry_computation()->root_instruction()->shape();
    860   EXPECT_TRUE(result_shape.is_dynamic_dimension(1));
    861   EXPECT_TRUE(result_shape.is_dynamic_dimension(3));
    862   EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(),
    863                               {false, true, false, true, false, false}))
    864       << result_shape;
    865 }
    866 
    867 TEST_F(XlaBuilderTest, DynamicSelect) {
    868   XlaBuilder b(TestName());
    869   Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
    870       {ShapeUtil::MakeShape(F32, {4, 5, 6}, {false, true, false}),
    871        ShapeUtil::MakeShape(F32, {4, 5, 6}, {false, true, false}),
    872        ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeShape(U32, {})});
    873   auto p0 = Parameter(&b, 0, tuple_param_shape, "p0");
    874   auto pred = Parameter(&b, 1, ShapeUtil::MakeShape(PRED, {}), "pred");
    875   ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
    876                                    /*dynamic_size_param_index=*/{2},
    877                                    /*target_param_num=*/0,
    878                                    /*target_param_index=*/{0},
    879                                    /*target_dim_num=*/1));
    880   ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
    881                                    /*dynamic_size_param_index=*/{3},
    882                                    /*target_param_num=*/0,
    883                                    /*target_param_index=*/{1},
    884                                    /*target_dim_num=*/1));
    885   auto gte0 = GetTupleElement(p0, 0);
    886   auto gte1 = GetTupleElement(p0, 1);
    887   Select(pred, gte0, gte1);
    888   TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
    889   const Shape& result_shape =
    890       module->entry_computation()->root_instruction()->shape();
    891   EXPECT_TRUE(result_shape.is_dynamic_dimension(1));
    892   EXPECT_FALSE(result_shape.is_dynamic_dimension(2));
    893   EXPECT_TRUE(
    894       ContainersEqual(result_shape.dynamic_dimensions(), {false, true, false}))
    895       << result_shape;
    896 }
    897 
    898 TEST_F(XlaBuilderTest, DynamicSelectNotCompatible) {
    899   XlaBuilder b(TestName());
    900   Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
    901       {ShapeUtil::MakeShape(F32, {4, 5, 6}, {false, true, false}),
    902        ShapeUtil::MakeShape(F32, {4, 5, 6}, {false, false, true}),
    903        ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeShape(U32, {})});
    904   auto p0 = Parameter(&b, 0, tuple_param_shape, "p0");
    905   auto pred = Parameter(&b, 1, ShapeUtil::MakeShape(PRED, {}), "pred");
    906   ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
    907                                    /*dynamic_size_param_index=*/{2},
    908                                    /*target_param_num=*/0,
    909                                    /*target_param_index=*/{0},
    910                                    /*target_dim_num=*/1));
    911   ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
    912                                    /*dynamic_size_param_index=*/{3},
    913                                    /*target_param_num=*/0,
    914                                    /*target_param_index=*/{1},
    915                                    /*target_dim_num=*/2));
    916   auto gte0 = GetTupleElement(p0, 0);  // f32[4,<=5,6]
    917   auto gte1 = GetTupleElement(p0, 1);  // f32[4,5,<=6]
    918   Select(pred, gte0, gte1);
    919   Status status = BuildHloModule(&b).status();
    920   ASSERT_IS_NOT_OK(status);
    921   EXPECT_THAT(status.error_message(),
    922               ::testing::HasSubstr("Operands to select must be the same shape; "
    923                                    "got f32[4,<=5,6] and f32[4,5,<=6]"));
    924 }
    925 
    926 TEST_F(XlaBuilderTest, DynamicTranspose) {
    927   XlaBuilder b(TestName());
    928   Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
    929       {ShapeUtil::MakeShape(F32, {3, 5}, {true, false}),
    930        ShapeUtil::MakeShape(U32, {})});
    931   auto p0 = Parameter(&b, 0, tuple_param_shape, "p0");
    932   ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
    933                                    /*dynamic_size_param_index=*/{1},
    934                                    /*target_param_num=*/0,
    935                                    /*target_param_index=*/{0},
    936                                    /*target_dim_num=*/0));
    937   auto gte = GetTupleElement(p0, 0);
    938   Transpose(gte, /*permutation=*/{1, 0});
    939   TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
    940   const Shape& result_shape =
    941       module->entry_computation()->root_instruction()->shape();
    942   EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), {false, true}))
    943       << result_shape;
    944 }
    945 
    946 TEST_F(XlaBuilderTest, AfterAllWithNonTokenOperands) {
    947   XlaBuilder b(TestName());
    948   AfterAll(&b, {CreateToken(&b), ConstantR0<float>(&b, 1.0)});
    949   Status status = b.Build().status();
    950   ASSERT_IS_NOT_OK(status);
    951   EXPECT_THAT(status.error_message(),
    952               ::testing::HasSubstr("All operands to AfterAll must be tokens"));
    953 }
    954 
    955 TEST_F(XlaBuilderTest, CheckInputOutputAlias) {
    956   XlaBuilder b(TestName());
    957   auto p0 = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {8, 4}), "p0");
    958   auto p1 = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {8, 4}), "p1");
    959   auto add = Add(p0, p1);
    960   auto sub = Sub(p0, p1);
    961   auto root = Tuple(&b, {add, sub});
    962 
    963   b.SetUpAlias({1}, 0, {});
    964   b.SetUpAlias({0}, 1, {});
    965 
    966   TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b, root));
    967 
    968   const HloInputOutputAliasConfig& config = module->input_output_alias_config();
    969   EXPECT_TRUE(config.ParameterHasAlias(0, {}));
    970   EXPECT_TRUE(config.ParameterHasAlias(1, {}));
    971 
    972   auto alias_p0 = config.GetAliasedOutput(0, {});
    973   ASSERT_TRUE(alias_p0.has_value());
    974   EXPECT_EQ(*alias_p0, ShapeIndex({1}));
    975 
    976   auto alias_p1 = config.GetAliasedOutput(1, {});
    977   ASSERT_TRUE(alias_p1.has_value());
    978   EXPECT_EQ(*alias_p1, ShapeIndex({0}));
    979 }
    980 
    981 }  // namespace
    982 }  // namespace xla
    983