Home | History | Annotate | Download | only in tests
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include <memory>
     17 #include <string>
     18 #include <vector>
     19 
     20 #include "tensorflow/compiler/xla/client/client_library.h"
     21 #include "tensorflow/compiler/xla/client/computation.h"
     22 #include "tensorflow/compiler/xla/client/computation_builder.h"
     23 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
     24 #include "tensorflow/compiler/xla/client/local_client.h"
     25 #include "tensorflow/compiler/xla/literal_util.h"
     26 #include "tensorflow/compiler/xla/service/platform_util.h"
     27 #include "tensorflow/compiler/xla/shape_util.h"
     28 #include "tensorflow/compiler/xla/status_macros.h"
     29 #include "tensorflow/compiler/xla/statusor.h"
     30 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
     31 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
     32 #include "tensorflow/compiler/xla/tests/test_macros.h"
     33 #include "tensorflow/compiler/xla/xla_data.pb.h"
     34 #include "tensorflow/core/lib/core/status_test_util.h"
     35 #include "tensorflow/core/platform/logging.h"
     36 #include "tensorflow/core/platform/test.h"
     37 #include "tensorflow/core/platform/test_benchmark.h"
     38 #include "tensorflow/core/platform/types.h"
     39 
     40 namespace se = ::perftools::gputools;
     41 
     42 namespace xla {
     43 namespace {
     44 
     45 class WhileTest : public ClientLibraryTestBase {};
     46 
     47 // Tests a while node when the result type T is S32.
     48 //
     49 // int32 result = 0;
     50 // while (result < 5) {
     51 //   result = result + 1;
     52 // }
     53 TEST_F(WhileTest, WhileWithScalarS32Result) {
     54   auto result_shape = ShapeUtil::MakeShape(S32, {});
     55 
     56   // Create a computation for the condition: repeat for 5 iterations.
     57   Computation condition;
     58   {
     59     ComputationBuilder builder(client_, "condition");
     60     auto prev = builder.Parameter(0, result_shape, "prev");
     61     builder.Gt(builder.ConstantR0<int32>(5), prev);
     62     condition = builder.Build().ConsumeValueOrDie();
     63   }
     64 
     65   // Create a computation for the body: add 1 to the result variable.
     66   Computation body;
     67   {
     68     ComputationBuilder builder(client_, "body");
     69     auto prev = builder.Parameter(0, result_shape, "prev");
     70     auto input = builder.ConstantR0<int32>(1);
     71     auto result = builder.Add(input, prev);
     72     body = builder.Build().ConsumeValueOrDie();
     73   }
     74 
     75   // Create a While node with computations for the condition and the body.
     76   ComputationBuilder builder(client_, TestName());
     77   auto init = builder.ConstantR0<int32>(0);
     78   auto result = builder.While(condition, body, init);
     79   auto shape = builder.GetShape(result).ConsumeValueOrDie();
     80 
     81   ComputeAndCompareR0<int32>(&builder, 5, {});
     82 }
     83 
     84 // Tests a while node when the result type T is S64.
     85 //
     86 // int32 result = 0;
     87 // while (result < 5) {
     88 //   result = result + 1;
     89 // }
     90 TEST_F(WhileTest, WhileWithScalarS64Result) {
     91   auto result_shape = ShapeUtil::MakeShape(S64, {});
     92 
     93   // Create a computation for the condition: repeat for 5 iterations.
     94   Computation condition;
     95   {
     96     ComputationBuilder builder(client_, "condition");
     97     auto prev = builder.Parameter(0, result_shape, "prev");
     98     builder.Gt(builder.ConstantR0<int64>(5), prev);
     99     condition = builder.Build().ConsumeValueOrDie();
    100   }
    101 
    102   // Create a computation for the body: add 1 to the result variable.
    103   Computation body;
    104   {
    105     ComputationBuilder builder(client_, "body");
    106     auto prev = builder.Parameter(0, result_shape, "prev");
    107     auto input = builder.ConstantR0<int64>(1);
    108     auto result = builder.Add(input, prev);
    109     body = builder.Build().ConsumeValueOrDie();
    110   }
    111 
    112   // Create a While node with computations for the condition and the body.
    113   ComputationBuilder builder(client_, TestName());
    114   auto init = builder.ConstantR0<int64>(0);
    115   auto result = builder.While(condition, body, init);
    116   auto shape = builder.GetShape(result).ConsumeValueOrDie();
    117 
    118   ComputeAndCompareR0<int64>(&builder, 5, {});
    119 }
    120 
    121 TEST_F(WhileTest, WhileWithScalarResultNonConstInit) {
    122   auto result_shape = ShapeUtil::MakeShape(S32, {});
    123   auto orig_shape = ShapeUtil::MakeShape(S32, {2});
    124 
    125   // Create a computation for the condition: repeat for 5 iterations.
    126   Computation condition;
    127   {
    128     ComputationBuilder builder(client_, "condition");
    129     auto prev = builder.Parameter(0, result_shape, "prev");
    130     builder.Gt(builder.ConstantR0<int32>(5), prev);
    131     condition = builder.Build().ConsumeValueOrDie();
    132   }
    133 
    134   // Create a computation for the body: add 1 to the result variable.
    135   Computation body;
    136   {
    137     ComputationBuilder builder(client_, "body");
    138     auto prev = builder.Parameter(0, result_shape, "prev");
    139     auto input = builder.ConstantR0<int32>(1);
    140     auto result = builder.Add(input, prev);
    141     body = builder.Build().ConsumeValueOrDie();
    142   }
    143 
    144   // Create a While node with computations for the condition and the body.
    145   ComputationBuilder builder(client_, TestName());
    146   auto init = builder.Reduce(builder.ConstantR1<int32>(2, 1),
    147                              builder.ConstantR0<int32>(0),
    148                              CreateScalarAddComputation(S32, &builder), {0});
    149   auto result = builder.While(condition, body, init);
    150   auto shape = builder.GetShape(result).ConsumeValueOrDie();
    151 
    152   ComputeAndCompareR0<int32>(&builder, 5, {});
    153 }
    154 
    155 TEST_F(WhileTest, WhileWithPredicateResult) {
    156   auto result_shape = ShapeUtil::MakeShape(PRED, {});
    157 
    158   // Create a computation for the condition: run until condition is true.
    159   Computation condition;
    160   {
    161     ComputationBuilder builder(client_, "condition");
    162     auto prev = builder.Parameter(0, result_shape, "prev");
    163     builder.Ne(builder.ConstantR0<bool>(true), prev);
    164     condition = builder.Build().ConsumeValueOrDie();
    165   }
    166 
    167   // Create a computation for the body: or condition with true.
    168   Computation body;
    169   {
    170     ComputationBuilder builder(client_, "body");
    171     auto prev = builder.Parameter(0, result_shape, "prev");
    172     auto result = builder.Or(prev, builder.ConstantR0<bool>(true));
    173     body = builder.Build().ConsumeValueOrDie();
    174   }
    175 
    176   // Create a While node with computations for the condition and the body.
    177   ComputationBuilder builder(client_, TestName());
    178   auto init = builder.Ne(builder.ConstantR0<bool>(false),
    179                          builder.ConstantR0<bool>(true));
    180   auto result = builder.While(condition, body, init);
    181 
    182   ComputeAndCompareR0<bool>(&builder, true, {});
    183 }
    184 
    185 // Tests a while node when the result type T is a vector.
    186 //
    187 // All constants are chosen to produce exact results.
    188 // vector<float> result(0);
    189 // while (result.sum() < 15.5f) {
    190 //   result = result + vector<float>(0);
    191 // }
    192 // TODO(b/29185393): does not terminate on CPU.
    193 TEST_F(WhileTest, DISABLED_WhileWithEmptyVectorResult) {
    194   Shape result_shape = ShapeUtil::MakeShape(F32, {0});
    195 
    196   // Create a computation for the reduction.
    197   Computation add;
    198   {
    199     ComputationBuilder builder(client_, "add");
    200     auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
    201     auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
    202     builder.Add(x, y);
    203     add = builder.Build().ConsumeValueOrDie();
    204   }
    205 
    206   // Create a computation for the condition.
    207   // Repeat until the sum of the result vector is less than 15.5f.
    208   Computation condition;
    209   {
    210     ComputationBuilder builder(client_, "condition");
    211     auto prev = builder.Parameter(0, result_shape, "prev");
    212     auto sum = builder.Reduce(prev, builder.ConstantR0<float>(0.0f), add,
    213                               /*dimensions_to_reduce=*/{0});
    214     auto test = builder.Gt(builder.ConstantR0<float>(15.5f), sum);
    215     condition = builder.Build().ConsumeValueOrDie();
    216   }
    217 
    218   // Create a computation for the body.
    219   // Add a constant vector of 1.f to the result vector.
    220   Computation body;
    221   {
    222     ComputationBuilder builder(client_, "body");
    223     auto prev = builder.Parameter(0, result_shape, "prev");
    224     auto input = builder.ConstantR1<float>({});
    225     auto result = builder.Add(input, prev);
    226     body = builder.Build().ConsumeValueOrDie();
    227   }
    228 
    229   // Create a While node with computations for the condition and the body.
    230   ComputationBuilder builder(client_, "while");
    231   auto init = builder.ConstantR1<float>({});
    232   auto result = builder.While(condition, body, init);
    233   VLOG(2) << "while = " << ShapeUtil::HumanString(
    234                                *builder.GetShape(result).ConsumeValueOrDie());
    235 
    236   ComputeAndCompareR1<float>(&builder, {}, {}, ErrorSpec(0.0001));
    237 }
    238 
    239 // Tests a while node when the result type T is a vector.
    240 //
    241 // All constants are chosen to produce exact results.
    242 // vector<float> result(8, 0.0f);
    243 // while (result.sum() < 15.5f) {
    244 //   result = result + vector<float>(8, 0.125f);
    245 // }
    246 TEST_F(WhileTest, WhileWithVectorResult) {
    247   Shape result_shape = ShapeUtil::MakeShape(F32, {8});
    248 
    249   // Create a computation for the reduction.
    250   Computation add;
    251   {
    252     ComputationBuilder builder(client_, "add");
    253     auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
    254     auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
    255     builder.Add(x, y);
    256     add = builder.Build().ConsumeValueOrDie();
    257   }
    258 
    259   // Create a computation for the condition.
    260   // Repeat until the sum of the result vector is less than 5.5f.
    261   Computation condition;
    262   {
    263     ComputationBuilder builder(client_, "condition");
    264     auto prev = builder.Parameter(0, result_shape, "prev");
    265     auto sum = builder.Reduce(prev, builder.ConstantR0<float>(0.0f), add,
    266                               /*dimensions_to_reduce=*/{0});
    267     auto test = builder.Gt(builder.ConstantR0<float>(15.5f), sum);
    268     condition = builder.Build().ConsumeValueOrDie();
    269   }
    270 
    271   // Create a computation for the body.
    272   // Add a constant vector of 1.f to the result vector.
    273   Computation body;
    274   {
    275     ComputationBuilder builder(client_, "body");
    276     auto prev = builder.Parameter(0, result_shape, "prev");
    277     auto input = builder.ConstantR1<float>(8, 0.125f);
    278     auto result = builder.Add(input, prev);
    279     body = builder.Build().ConsumeValueOrDie();
    280   }
    281 
    282   // Create a While node with computations for the condition and the body.
    283   ComputationBuilder builder(client_, "while");
    284   auto init = builder.ConstantR1<float>(8, 0.f);
    285   auto result = builder.While(condition, body, init);
    286   VLOG(2) << "while = " << ShapeUtil::HumanString(
    287                                *builder.GetShape(result).ConsumeValueOrDie());
    288 
    289   // Individual elements with increase by 1/8 each time through the loop, so
    290   // the sum will increase by 1.0.  It will first be >15.5 when the elements
    291   // have all reached 2.0.
    292   std::vector<float> expected = {2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f};
    293   ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
    294 }
    295 
    296 // Tests a while node when the result type is a vector which is part
    297 // of the result tuple.
    298 //
    299 // All constants are chosen to produce exact results.
    300 // vector<float> result(8, 0.0f);
    301 // while (result.sum() < 15.5f) {
    302 //   result = result + vector<float>(8, 0.125f);
    303 // }
    304 // tuple = tuple { while }
    305 TEST_F(WhileTest, WhileWithVectorResultIntoTuple) {
    306   Shape result_shape = ShapeUtil::MakeShape(F32, {8});
    307 
    308   // Create a computation for the reduction.
    309   Computation add;
    310   {
    311     ComputationBuilder builder(client_, "add");
    312     auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
    313     auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
    314     builder.Add(x, y);
    315     add = builder.Build().ConsumeValueOrDie();
    316   }
    317 
    318   // Create a computation for the condition.
    319   // Repeat until the sum of the result vector is less than 5.5f.
    320   Computation condition;
    321   {
    322     ComputationBuilder builder(client_, "condition");
    323     auto prev = builder.Parameter(0, result_shape, "prev");
    324     auto sum = builder.Reduce(prev, builder.ConstantR0<float>(0.0f), add,
    325                               /*dimensions_to_reduce=*/{0});
    326     auto test = builder.Gt(builder.ConstantR0<float>(15.5f), sum);
    327     condition = builder.Build().ConsumeValueOrDie();
    328   }
    329 
    330   // Create a computation for the body.
    331   // Add a constant vector of 1.f to the result vector.
    332   Computation body;
    333   {
    334     ComputationBuilder builder(client_, "body");
    335     auto prev = builder.Parameter(0, result_shape, "prev");
    336     auto input = builder.ConstantR1<float>(8, 0.125f);
    337     auto result = builder.Add(input, prev);
    338     body = builder.Build().ConsumeValueOrDie();
    339   }
    340 
    341   // Create a While node with computations for the condition and the body.
    342   ComputationBuilder builder(client_, "while");
    343   auto init = builder.ConstantR1<float>(8, 0.f);
    344   auto result = builder.While(condition, body, init);
    345   VLOG(2) << "while = "
    346           << ShapeUtil::HumanString(
    347                  *builder.GetShape(result).ConsumeValueOrDie());
    348   builder.Tuple({result});
    349 
    350   // Individual elements with increase by 1/8 each time through the loop, so
    351   // the sum will increase by 1.0.  It will first be >15.5 when the elements
    352   // have all reached 2.0.
    353   auto expected_data =
    354       Literal::CreateR1<float>({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f});
    355   auto expected = Literal::MakeTuple({expected_data.get()});
    356   VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
    357   ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
    358 }
    359 
    360 TEST_F(WhileTest, WhileWithPermutationAndTupleResult) {
    361   std::vector<Shape> shape_elements = {
    362       ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {3}),
    363       ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(F32, {3})};
    364   Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements);
    365 
    366   // Create a computation for the condition.
    367   // Repeat for N iterations.
    368   const int N = 2;
    369   Computation condition;
    370   {
    371     ComputationBuilder builder(client_, "condition");
    372     auto prev = builder.Parameter(0, result_shape, "prev");
    373     auto iteration = builder.GetTupleElement(prev, 0);
    374     builder.Gt(builder.ConstantR0<int32>(N), iteration);
    375     condition = builder.Build().ConsumeValueOrDie();
    376   }
    377 
    378   // Create a computation for the body.
    379   // Add 1 to the iteration variable and permute the weights.
    380   Computation body;
    381   {
    382     ComputationBuilder builder(client_, "body");
    383     auto prev = builder.Parameter(0, result_shape, "prev");
    384     auto iteration = builder.GetTupleElement(prev, 0);
    385     auto w1 = builder.GetTupleElement(prev, 1);
    386     auto w2 = builder.GetTupleElement(prev, 2);
    387     auto w3 = builder.GetTupleElement(prev, 3);
    388     auto result = builder.Tuple(
    389         {builder.Add(iteration, builder.ConstantR0<int32>(1)), w3, w1, w2});
    390     body = builder.Build().ConsumeValueOrDie();
    391   }
    392 
    393   // Create a While node with computations for the condition and the body.
    394   ComputationBuilder builder(client_, "while");
    395   auto init = builder.Tuple(
    396       {builder.ConstantR0<int32>(0), builder.ConstantR1<float>(3, 1.f),
    397        builder.ConstantR1<float>(3, 2.f), builder.ConstantR1<float>(3, 3.f)});
    398   auto result = builder.While(condition, body, init);
    399   VLOG(2) << "result = "
    400           << ShapeUtil::HumanString(
    401                  *builder.GetShape(result).ConsumeValueOrDie());
    402 
    403   auto expected_counter = Literal::CreateR0<int32>(N);
    404   auto expected_w1 = Literal::CreateR1<float>({1.0f, 1.0f, 1.0f});
    405   auto expected_w2 = Literal::CreateR1<float>({2.0f, 2.0f, 2.0f});
    406   auto expected_w3 = Literal::CreateR1<float>({3.0f, 3.0f, 3.0f});
    407   auto expected = Literal::MakeTuple({expected_counter.get(), expected_w2.get(),
    408                                       expected_w3.get(), expected_w1.get()});
    409   VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
    410   ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
    411 }
    412 
    413 TEST_F(WhileTest, WhileWithPermutationAndVectorResult) {
    414   std::vector<Shape> shape_elements = {
    415       ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {3}),
    416       ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(F32, {3})};
    417   Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements);
    418 
    419   // Create a computation for the condition.
    420   // Repeat for N iterations.
    421   const int N = 2;
    422   Computation condition;
    423   {
    424     ComputationBuilder builder(client_, "condition");
    425     auto prev = builder.Parameter(0, result_shape, "prev");
    426     auto iteration = builder.GetTupleElement(prev, 0);
    427     builder.Gt(builder.ConstantR0<int32>(N), iteration);
    428     condition = builder.Build().ConsumeValueOrDie();
    429   }
    430 
    431   // Create a computation for the body.
    432   // Add 1 to the iteration variable permute the weights.
    433   Computation body;
    434   {
    435     ComputationBuilder builder(client_, "body");
    436     auto prev = builder.Parameter(0, result_shape, "prev");
    437     auto iteration = builder.GetTupleElement(prev, 0);
    438     auto w1 = builder.GetTupleElement(prev, 1);
    439     auto w2 = builder.GetTupleElement(prev, 2);
    440     auto w3 = builder.GetTupleElement(prev, 3);
    441     auto result = builder.Tuple(
    442         {builder.Add(iteration, builder.ConstantR0<int32>(1)), w3, w1, w2});
    443     body = builder.Build().ConsumeValueOrDie();
    444   }
    445 
    446   // Create a While node with computations for the condition and the body.
    447   ComputationBuilder builder(client_, "while");
    448   auto init = builder.Tuple(
    449       {builder.ConstantR0<int32>(0), builder.ConstantR1<float>(3, 1.f),
    450        builder.ConstantR1<float>(3, 2.f), builder.ConstantR1<float>(3, 3.f)});
    451   auto xla_while = builder.While(condition, body, init);
    452 
    453   auto add12 = builder.Add(builder.GetTupleElement(xla_while, 1),
    454                            builder.GetTupleElement(xla_while, 2));
    455   auto result = builder.Add(add12, builder.GetTupleElement(xla_while, 3));
    456   VLOG(2) << "result = "
    457           << ShapeUtil::HumanString(
    458                  *builder.GetShape(result).ConsumeValueOrDie());
    459   std::vector<float> expected = {6.f, 6.f, 6.f};
    460   ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
    461 }
    462 
    463 // Tests a while node when the result type T is a Tuple.
    464 //
    465 // tuple<int32, vector<float>> result(0, vector<float>(10, 0.0f));
    466 // while (get<0>(result) < 5) {
    467 //   get<0>(result) = get<0>(result) + 1;
    468 //   get<1>(result) = get<1>(result) + vector<float>(10, 1.0f);
    469 // }
    470 TEST_F(WhileTest, WhileWithTupleResult) {
    471   std::vector<Shape> shape_elements = {ShapeUtil::MakeShape(S32, {}),
    472                                        ShapeUtil::MakeShape(F32, {10})};
    473   Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements);
    474 
    475   // Create a computation for the condition.
    476   // Repeat for 5 iterations.
    477   Computation condition;
    478   {
    479     ComputationBuilder builder(client_, "condition");
    480     auto prev = builder.Parameter(0, result_shape, "prev");
    481     auto iteration = builder.GetTupleElement(prev, 0);
    482     builder.Gt(builder.ConstantR0<int32>(5), iteration);
    483     condition = builder.Build().ConsumeValueOrDie();
    484   }
    485 
    486   // Create a computation for the body.
    487   // Add 1 to the iteration variable and add a constant vector of 1.0f to
    488   // the weight variable, both of which are tuple elements.
    489   Computation body;
    490   {
    491     ComputationBuilder builder(client_, "body");
    492     auto prev = builder.Parameter(0, result_shape, "prev");
    493     auto iteration = builder.GetTupleElement(prev, 0);
    494     auto weights = builder.GetTupleElement(prev, 1);
    495     auto input = builder.ConstantR1<float>(10, 1.f);
    496     auto new_weights = builder.Add(weights, input);
    497     auto result = builder.Tuple(
    498         {builder.Add(iteration, builder.ConstantR0<int32>(1)), new_weights});
    499     body = builder.Build().ConsumeValueOrDie();
    500   }
    501 
    502   // Create a While node with computations for the condition and the body.
    503   ComputationBuilder builder(client_, "while");
    504   auto init = builder.Tuple(
    505       {builder.ConstantR0<int32>(0), builder.ConstantR1<float>(10, 0.f)});
    506   auto result = builder.While(condition, body, init);
    507   VLOG(2) << "while = " << ShapeUtil::HumanString(
    508                                *builder.GetShape(result).ConsumeValueOrDie());
    509 
    510   auto expected_counter = Literal::CreateR0<int32>(5);
    511   auto expected_data = Literal::CreateR1<float>(
    512       {5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f});
    513   auto expected =
    514       Literal::MakeTuple({expected_counter.get(), expected_data.get()});
    515   VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
    516   ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
    517 }
    518 
    519 TEST_F(WhileTest, WhileWithPredicateTupleResult) {
    520   std::vector<Shape> shape_elements = {ShapeUtil::MakeShape(S32, {}),
    521                                        ShapeUtil::MakeShape(PRED, {})};
    522   Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements);
    523 
    524   // Create a computation for the condition.
    525   // Repeat for 5 iterations.
    526   Computation condition;
    527   {
    528     ComputationBuilder builder(client_, "condition");
    529     auto prev = builder.Parameter(0, result_shape, "prev");
    530     auto iteration = builder.GetTupleElement(prev, 0);
    531     builder.Gt(builder.ConstantR0<int32>(5), iteration);
    532     condition = builder.Build().ConsumeValueOrDie();
    533   }
    534 
    535   // Create a computation for the body.
    536   // Add 1 to the iteration variable and or the predicate with true
    537   Computation body;
    538   {
    539     ComputationBuilder builder(client_, "body");
    540     auto prev = builder.Parameter(0, result_shape, "prev");
    541     auto iteration = builder.GetTupleElement(prev, 0);
    542     auto pred = builder.GetTupleElement(prev, 1);
    543     auto new_pred = builder.Or(pred, builder.ConstantR0<bool>(true));
    544     auto result = builder.Tuple(
    545         {builder.Add(iteration, builder.ConstantR0<int32>(1)), new_pred});
    546     body = builder.Build().ConsumeValueOrDie();
    547   }
    548 
    549   // Create a While node with computations for the condition and the body.
    550   ComputationBuilder builder(client_, "while");
    551   auto init = builder.Tuple({builder.ConstantR0<int32>(0),
    552                              builder.Ne(builder.ConstantR0<bool>(false),
    553                                         builder.ConstantR0<bool>(true))});
    554   auto result = builder.While(condition, body, init);
    555   VLOG(2) << "while = "
    556           << ShapeUtil::HumanString(
    557                  *builder.GetShape(result).ConsumeValueOrDie());
    558 
    559   auto expected_counter = Literal::CreateR0<int32>(5);
    560   auto expected_predicate = Literal::CreateR0<bool>(true);
    561   auto expected =
    562       Literal::MakeTuple({expected_counter.get(), expected_predicate.get()});
    563   ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0));
    564 }
    565 
    566 TEST_F(WhileTest, WhileWithTupleConstantScalarResult) {
    567   std::vector<Shape> shape_elements = {ShapeUtil::MakeShape(S32, {}),
    568                                        ShapeUtil::MakeShape(S32, {})};
    569   Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements);
    570 
    571   // Create a computation for the condition.
    572   // Repeat for 5 iterations.
    573   Computation condition;
    574   {
    575     ComputationBuilder builder(client_, "condition");
    576     auto prev = builder.Parameter(0, result_shape, "prev");
    577     auto iteration = builder.GetTupleElement(prev, 0);
    578     builder.Gt(builder.ConstantR0<int32>(5), iteration);
    579     condition = builder.Build().ConsumeValueOrDie();
    580   }
    581 
    582   // Create a computation for the body.
    583   // Add 1 to the iteration variable and set the other tuple element to a
    584   // constant.
    585   Computation body;
    586   {
    587     ComputationBuilder builder(client_, "body");
    588     auto prev = builder.Parameter(0, result_shape, "prev");
    589     auto iteration = builder.GetTupleElement(prev, 0);
    590     auto result =
    591         builder.Tuple({builder.Add(iteration, builder.ConstantR0<int32>(1)),
    592                        builder.ConstantR0<int32>(7)});
    593     body = builder.Build().ConsumeValueOrDie();
    594   }
    595 
    596   // Create a While node with computations for the condition and the body.
    597   ComputationBuilder builder(client_, "while");
    598   auto init = builder.Tuple(
    599       {builder.ConstantR0<int32>(0), builder.ConstantR0<int32>(7)});
    600   auto result = builder.While(condition, body, init);
    601   VLOG(2) << "while = "
    602           << ShapeUtil::HumanString(
    603                  *builder.GetShape(result).ConsumeValueOrDie());
    604 
    605   auto expected_counter = Literal::CreateR0<int32>(5);
    606   auto expected_data = Literal::CreateR0<int32>(7);
    607   auto expected =
    608       Literal::MakeTuple({expected_counter.get(), expected_data.get()});
    609   VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
    610   ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
    611 }
    612 
    613 // Tests two while nodes when the result type T is a Tuple and the second
    614 // while node uses the result of the first while node which is used in two
    615 // nodes.
    616 // tuple<int32, vector<float>> w0(0, vector<float>(10, 0.0f));
    617 // w0 = while (get<0>(w0) < c1) {
    618 //        get<0>(w0) = get<0>(w0) + 1;
    619 //        get<1>(w0) = get<1>(w0) + vector<float>(10, 1.0f);
    620 //      }
    621 // tuple<int32, vector<float>> w1(get<0>(w0), get<1>(w0));
    622 // w1 = while (get<0>(w1) < c2) {
    623 //        get<0>(w1) = get<0>(w1) + 1;
    624 //        get<1>(w1) = get<1>(w1) + vector<float>(10, 1.0f);
    625 //      }
    626 // result = get<1>(w0) + get<1>(w1)
    627 TEST_F(WhileTest, TwoWhileWithTupleResult) {
    628   std::vector<Shape> shape_elements = {ShapeUtil::MakeShape(S32, {}),
    629                                        ShapeUtil::MakeShape(F32, {10})};
    630   Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements);
    631 
    632   // Create a computation for the condition.
    633   // Repeat for 5 iterations.
    634   Computation condition;
    635   const int c1 = 5;
    636   {
    637     ComputationBuilder builder(client_, "condition");
    638     auto prev = builder.Parameter(0, result_shape, "prev");
    639     auto iteration = builder.GetTupleElement(prev, 0);
    640     builder.Lt(iteration, builder.ConstantR0<int32>(c1));
    641     TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build());
    642   }
    643 
    644   Computation condition2;
    645   const int c2 = 7;
    646   {
    647     ComputationBuilder builder(client_, "condition2");
    648     auto prev = builder.Parameter(0, result_shape, "prev");
    649     auto iteration = builder.GetTupleElement(prev, 0);
    650     builder.Lt(iteration, builder.ConstantR0<int32>(c2));
    651     TF_ASSERT_OK_AND_ASSIGN(condition2, builder.Build());
    652   }
    653 
    654   // Create a computation for the body.
    655   // Add 1 to the iteration variable and add a constant vector of 1.0f to
    656   // the weight variable, both of which are tuple elements.
    657   Computation body;
    658   {
    659     ComputationBuilder builder(client_, "body");
    660     auto prev = builder.Parameter(0, result_shape, "prev");
    661     auto iteration = builder.GetTupleElement(prev, 0);
    662     auto weights = builder.GetTupleElement(prev, 1);
    663     auto input = builder.ConstantR1<float>(10, 1.f);
    664     auto new_weights = builder.Add(weights, input);
    665     auto result = builder.Tuple(
    666         {builder.Add(iteration, builder.ConstantR0<int32>(1)), new_weights});
    667     TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
    668   }
    669 
    670   Computation body2;
    671   {
    672     ComputationBuilder builder(client_, "body");
    673     auto prev = builder.Parameter(0, result_shape, "prev");
    674     auto iteration = builder.GetTupleElement(prev, 0);
    675     auto weights = builder.GetTupleElement(prev, 1);
    676     auto input = builder.ConstantR1<float>(10, 1.f);
    677     auto new_weights = builder.Add(weights, input);
    678     auto result = builder.Tuple(
    679         {builder.Add(iteration, builder.ConstantR0<int32>(1)), new_weights});
    680     TF_ASSERT_OK_AND_ASSIGN(body2, builder.Build());
    681   }
    682 
    683   // Create a While node with computations for the condition and the body.
    684   ComputationBuilder builder(client_, "while");
    685   auto init = builder.Tuple(
    686       {builder.ConstantR0<int32>(0), builder.ConstantR1<float>(10, 0.f)});
    687   auto while1 = builder.While(condition, body, init);
    688 
    689   auto while2 = builder.While(condition2, body2, while1);
    690 
    691   auto while_result1 = builder.GetTupleElement(while1, 1);
    692   auto while_result2 = builder.GetTupleElement(while2, 1);
    693   VLOG(2) << "while_result2 = "
    694           << ShapeUtil::HumanString(
    695                  *builder.GetShape(while_result2).ConsumeValueOrDie());
    696   auto result = builder.Add(while_result1, while_result2);
    697   VLOG(2) << "result = "
    698           << ShapeUtil::HumanString(
    699                  *builder.GetShape(result).ConsumeValueOrDie());
    700   const float sum = c1 + c2;
    701   std::vector<float> expected(10, sum);
    702   ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
    703 }
    704 
    705 // Test while nodes that share the while body computation.
    706 TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) {
    707   std::vector<Shape> shape_elements = {ShapeUtil::MakeShape(S32, {}),
    708                                        ShapeUtil::MakeShape(F32, {10})};
    709   Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements);
    710 
    711   // Create a computation for the condition.
    712   // Repeat for 5 iterations.
    713   Computation condition;
    714   const int c1 = 5;
    715   {
    716     ComputationBuilder builder(client_, "condition");
    717     auto prev = builder.Parameter(0, result_shape, "prev");
    718     auto iteration = builder.GetTupleElement(prev, 0);
    719     builder.Lt(iteration, builder.ConstantR0<int32>(c1));
    720     TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build());
    721   }
    722 
    723   Computation condition2;
    724   const int c2 = 7;
    725   {
    726     ComputationBuilder builder(client_, "condition2");
    727     auto prev = builder.Parameter(0, result_shape, "prev");
    728     auto iteration = builder.GetTupleElement(prev, 0);
    729     builder.Lt(iteration, builder.ConstantR0<int32>(c2));
    730     TF_ASSERT_OK_AND_ASSIGN(condition2, builder.Build());
    731   }
    732 
    733   // Create a computation for the body.
    734   // Add 1 to the iteration variable and add a constant vector of 1.0f to
    735   // the weight variable, both of which are tuple elements.
    736   Computation body;
    737   {
    738     ComputationBuilder builder(client_, "body");
    739     auto prev = builder.Parameter(0, result_shape, "prev");
    740     auto iteration = builder.GetTupleElement(prev, 0);
    741     auto weights = builder.GetTupleElement(prev, 1);
    742     auto input = builder.ConstantR1<float>(10, 1.f);
    743     auto new_weights = builder.Add(weights, input);
    744     auto result = builder.Tuple(
    745         {builder.Add(iteration, builder.ConstantR0<int32>(1)), new_weights});
    746     TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
    747   }
    748 
    749   // Create a While node with computations for the condition and the body.
    750   ComputationBuilder builder(client_, "while");
    751   auto init = builder.Tuple(
    752       {builder.ConstantR0<int32>(0), builder.ConstantR1<float>(10, 0.f)});
    753   auto while1 = builder.While(condition, body, init);
    754 
    755   auto while2 = builder.While(condition2, body, while1);
    756 
    757   auto while_result1 = builder.GetTupleElement(while1, 1);
    758   auto while_result2 = builder.GetTupleElement(while2, 1);
    759   VLOG(2) << "while_result2 = "
    760           << ShapeUtil::HumanString(
    761                  *builder.GetShape(while_result2).ConsumeValueOrDie());
    762   auto result = builder.Add(while_result1, while_result2);
    763   VLOG(2) << "result = "
    764           << ShapeUtil::HumanString(
    765                  *builder.GetShape(result).ConsumeValueOrDie());
    766   const float sum = c1 + c2;
    767   std::vector<float> expected(10, sum);
    768   ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
    769 }
    770 
    771 // Test while nodes that share the while body computation.
    772 // TODO(b/37245345): Fails on GPU backend.
    773 TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) {
    774   std::vector<Shape> shape_elements = {ShapeUtil::MakeShape(S32, {}),
    775                                        ShapeUtil::MakeShape(F32, {10})};
    776   Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements);
    777 
    778   // Create a computation for the condition.
    779   // Repeat for 5 iterations.
    780   Computation condition;
    781   const int c1 = 5;
    782   {
    783     ComputationBuilder builder(client_, "condition");
    784     auto prev = builder.Parameter(0, result_shape, "prev");
    785     auto iteration = builder.GetTupleElement(prev, 0);
    786     builder.Lt(iteration, builder.ConstantR0<int32>(c1));
    787     TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build());
    788   }
    789 
    790   Computation condition2;
    791   const int c2 = 7;
    792   {
    793     ComputationBuilder builder(client_, "condition2");
    794     auto prev = builder.Parameter(0, result_shape, "prev");
    795     auto iteration = builder.GetTupleElement(prev, 0);
    796     builder.Lt(iteration, builder.ConstantR0<int32>(c2));
    797     TF_ASSERT_OK_AND_ASSIGN(condition2, builder.Build());
    798   }
    799 
    800   // Create a computation for the body.
    801   // Add 1 to the iteration variable and add a constant vector of 1.0f to
    802   // the weight variable, both of which are tuple elements.
    803   Computation body;
    804   {
    805     ComputationBuilder builder(client_, "body");
    806     auto prev = builder.Parameter(0, result_shape, "prev");
    807     auto iteration = builder.GetTupleElement(prev, 0);
    808     auto weights = builder.GetTupleElement(prev, 1);
    809     auto input = builder.ConstantR1<float>(10, 1.f);
    810     auto new_weights = builder.Add(weights, input);
    811     auto result = builder.Tuple(
    812         {builder.Add(iteration, builder.ConstantR0<int32>(1)), new_weights});
    813     TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
    814   }
    815 
    816   // Create a While node with computations for the condition and the body.
    817   ComputationBuilder builder(client_, "while");
    818   auto init = builder.Tuple(
    819       {builder.ConstantR0<int32>(0), builder.ConstantR1<float>(10, 0.f)});
    820   auto while1 = builder.While(condition, body, init);
    821   auto while2 = builder.While(condition2, body, init);
    822 
    823   auto while_result1 = builder.GetTupleElement(while1, 1);
    824   auto while_result2 = builder.GetTupleElement(while2, 1);
    825   VLOG(2) << "while_result2 = "
    826           << ShapeUtil::HumanString(
    827                  *builder.GetShape(while_result2).ConsumeValueOrDie());
    828   auto result = builder.Add(while_result1, while_result2);
    829   VLOG(2) << "result = "
    830           << ShapeUtil::HumanString(
    831                  *builder.GetShape(result).ConsumeValueOrDie());
    832   const float sum = c1 + c2;
    833   std::vector<float> expected(10, sum);
    834   ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
    835 }
    836 
    837 // WhileTest that uses DynamicUpdateSlice instruction in body computation.
    838 // Loop state tuple element 1 has as its single user operand(0) of
    839 // DynamicUpdateSlice, which will trigger in-place dynamic slice update on GPU.
    840 XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) {
    841   std::vector<Shape> shape_elements = {ShapeUtil::MakeShape(S32, {}),
    842                                        ShapeUtil::MakeShape(F32, {10})};
    843   Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements);
    844 
    845   // Create a computation for the condition.
    846   // Repeat for 5 iterations.
    847   Computation condition;
    848   {
    849     ComputationBuilder builder(client_, "condition");
    850     auto prev = builder.Parameter(0, result_shape, "prev");
    851     auto iteration = builder.GetTupleElement(prev, 0);
    852     builder.Gt(builder.ConstantR0<int32>(5), iteration);
    853     condition = builder.Build().ConsumeValueOrDie();
    854   }
    855 
    856   // Create a computation for the body.
    857   // Add 1 to the iteration variable and add a constant vector of 1.0f to
    858   // the weight variable, both of which are tuple elements.
    859   Computation body;
    860   {
    861     ComputationBuilder builder(client_, "body");
    862     auto prev = builder.Parameter(0, result_shape, "prev");
    863     // TupleElement 0
    864     auto iteration = builder.GetTupleElement(prev, 0);
    865     auto out0 = builder.Add(iteration, builder.ConstantR0<int32>(1));
    866     // TupleElement 1
    867     auto input = builder.GetTupleElement(prev, 1);
    868     // Update.
    869     auto update = builder.ConvertElementType(builder.Broadcast(out0, {2}), F32);
    870     // Starts = iteration * 2;
    871     auto starts = builder.Reshape(
    872         builder.Mul(iteration, builder.ConstantR0<int32>(2)), {1});
    873     // UpdateSlice.
    874     auto out1 = builder.DynamicUpdateSlice(input, update, starts);
    875 
    876     auto result = builder.Tuple({out0, out1});
    877     body = builder.Build().ConsumeValueOrDie();
    878   }
    879 
    880   // Create a While node with computations for the condition and the body.
    881   ComputationBuilder builder(client_, "while");
    882   auto init = builder.Tuple(
    883       {builder.ConstantR0<int32>(0), builder.ConstantR1<float>(10, 0.f)});
    884   auto result = builder.While(condition, body, init);
    885   VLOG(2) << "while = "
    886           << ShapeUtil::HumanString(
    887                  *builder.GetShape(result).ConsumeValueOrDie());
    888 
    889   auto expected_counter = Literal::CreateR0<int32>(5);
    890   auto expected_data = Literal::CreateR1<float>(
    891       {1.0f, 1.0f, 2.0f, 2.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f});
    892   auto expected =
    893       Literal::MakeTuple({expected_counter.get(), expected_data.get()});
    894   VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
    895   ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
    896 }
    897 
    898 // Tests a while node when the result type T is a vector of S32.
    899 //
    900 // int32 result = (0, 0, 0, 0, 0, 0);
    901 // while (result[0] < count) {
    902 //   result += (1, U[0, 100], U[0, 100], U[0, 100], U[0, 100], U[0, 100]);
    903 // }
    904 //
    905 // This test misuses a vector WhileTest.WhileLoopsWithSharedBodyto represent a
    906 // pair:
    907 //   ((iteration, (random vector))).
    908 //
    909 // Note: this test currently only tests generating random values within a loop.
    910 // Per backend the values generated can be different as the different backends
    911 // use different random number generators.
    912 // TODO(b/32240857): Extend test to verify outputs.
    913 TEST_F(WhileTest, WhileWithPrngScalarResult) {
    914   auto v6s32 = ShapeUtil::MakeShape(S32, {6});
    915 
    916   // Create a computation for the condition: repeat for count iterations.
    917   auto build_condition = [this, v6s32](int count) {
    918     ComputationBuilder builder(client_, TestName());
    919     auto prev = builder.Reshape(
    920         builder.Slice(builder.Parameter(0, v6s32, "prev"), {0}, {1}, {1}), {0},
    921           {});
    922     builder.Gt(builder.ConstantR0<int32>(count), prev);
    923     return builder.Build().ConsumeValueOrDie();
    924   };
    925 
    926   // Create a computation for the body: add 1 to the result variable.
    927   Computation body;
    928   {
    929     ComputationBuilder builder(client_, "body");
    930     auto prev = builder.Parameter(0, v6s32, "prev");
    931     auto inc = builder.ConcatInDim(
    932         {builder.ConstantR1<int32>({1}),
    933          builder.RngUniform(builder.ConstantR0<int32>(0),
    934                             builder.ConstantR0<int32>(100),
    935                             ShapeUtil::MakeShape(S32, {5}))},
    936         0);
    937     auto result = builder.Add(inc, prev);
    938     body = builder.Build().ConsumeValueOrDie();
    939   }
    940 
    941   // Create a While node with computations for the condition and the body.
    942   auto while_loop = [this, &body, build_condition](int count) {
    943     ComputationBuilder builder(client_, TestName());
    944     auto init = builder.ConstantR1<int32>({0, 0, 0, 0, 0, 0});
    945     auto result = builder.While(build_condition(count), body, init);
    946     auto shape = builder.GetShape(result).ConsumeValueOrDie();
    947     return builder.Build();
    948   };
    949 
    950   for (int i = 1; i < 4; ++i) {
    951     TF_ASSERT_OK_AND_ASSIGN(auto computation, while_loop(i));
    952 
    953     ExecutionOptions execution_options = execution_options_;
    954     execution_options.set_seed(65);
    955     TF_ASSERT_OK_AND_ASSIGN(
    956         auto result,
    957         client_->ExecuteAndTransfer(computation, {}, &execution_options));
    958   }
    959 }
    960 
    961 TEST_F(WhileTest, WhileThatSwapsParameterWithTupleElement) {
    962   auto element_shape = ShapeUtil::MakeShape(F32, {2});
    963 
    964   ComputationBuilder outer(client_, "outer");
    965   auto p = outer.Parameter(0, element_shape, "param");
    966   auto t = outer.Tuple({p, outer.ConstantR1<float>({1, 1})});
    967 
    968   TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr<Shape> tuple_shape,
    969                           outer.GetShape(t));
    970 
    971   ComputationBuilder cond(client_, "cond");
    972   auto cond_t = cond.Parameter(0, *tuple_shape, "t");
    973   TF_ASSERT_OK(Any(cond.Eq(cond.GetTupleElement(cond_t, 0),
    974                            cond.ConstantR1<float>({42, 42})),
    975                    &cond)
    976                    .status());
    977 
    978   ComputationBuilder body(client_, "body");
    979   auto body_t = body.Parameter(0, *tuple_shape, "t");
    980   auto e = body.GetTupleElement(body_t, 1);
    981   body.Tuple({e, e});
    982 
    983   TF_ASSERT_OK_AND_ASSIGN(auto cond_computation, cond.Build());
    984   TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body.Build());
    985   outer.While(cond_computation, body_computation, t);
    986 
    987   auto expected_element = Literal::CreateR1<float>({1, 1});
    988   auto expected =
    989       Literal::MakeTuple({expected_element.get(), expected_element.get()});
    990   TF_ASSERT_OK_AND_ASSIGN(
    991       std::unique_ptr<GlobalData> parameter_data,
    992       client_->TransferToServer(*Literal::CreateR1<float>({42, 42})));
    993   ComputeAndCompareTuple(&outer, *expected, {parameter_data.get()},
    994                          ErrorSpec(1e-6));
    995 }
    996 
    997 TEST_F(WhileTest, WhileThatSwapsParameterWithBroadcast) {
    998   auto element_shape = ShapeUtil::MakeShape(F32, {2});
    999 
   1000   ComputationBuilder outer(client_, "outer");
   1001   auto p = outer.Parameter(0, element_shape, "param");
   1002 
   1003   ComputationBuilder cond(client_, "cond");
   1004   auto cond_t = cond.Parameter(0, element_shape, "t");
   1005   TF_ASSERT_OK(
   1006       Any(cond.Eq(cond_t, cond.ConstantR1<float>({42, 42})), &cond).status());
   1007 
   1008   ComputationBuilder body(client_, "body");
   1009   auto body_t = body.Parameter(0, element_shape, "t");
   1010   auto e = body.Broadcast(body.ConstantR0<float>(1.0), {2});
   1011 
   1012   TF_ASSERT_OK_AND_ASSIGN(auto cond_computation, cond.Build());
   1013   TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body.Build());
   1014   outer.While(cond_computation, body_computation, p);
   1015 
   1016   TF_ASSERT_OK_AND_ASSIGN(
   1017       std::unique_ptr<GlobalData> parameter_data,
   1018       client_->TransferToServer(*Literal::CreateR1<float>({42, 42})));
   1019   ComputeAndCompareR1<float>(&outer, {1.0f, 1.0f}, {parameter_data.get()},
   1020                              ErrorSpec(1e-6));
   1021 }
   1022 
   1023 TEST_F(WhileTest, WhileThatTurnsScalarParameterToTupleElement) {
   1024   auto element_shape = ShapeUtil::MakeShape(F32, {});
   1025 
   1026   ComputationBuilder outer(client_, "outer");
   1027   auto p = outer.Parameter(0, element_shape, "param");
   1028 
   1029   ComputationBuilder cond(client_, "cond");
   1030   auto cond_t = cond.Parameter(0, element_shape, "t");
   1031   cond.Eq(cond_t, cond.ConstantR0<float>(42));
   1032 
   1033   ComputationBuilder body(client_, "body");
   1034   auto body_t = body.Parameter(0, element_shape, "t");
   1035   auto tuple =
   1036       body.Tuple({body_t, body.Add(body_t, body.ConstantR0<float>(1))});
   1037   auto e = body.GetTupleElement(tuple, 1);
   1038 
   1039   TF_ASSERT_OK_AND_ASSIGN(auto cond_computation, cond.Build());
   1040   TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body.Build());
   1041   outer.While(cond_computation, body_computation, p);
   1042 
   1043   TF_ASSERT_OK_AND_ASSIGN(
   1044       std::unique_ptr<GlobalData> parameter_data,
   1045       client_->TransferToServer(*Literal::CreateR0<float>(42)));
   1046   ComputeAndCompareR0<float>(&outer, 43.0f, {parameter_data.get()},
   1047                              ErrorSpec(1e-6));
   1048 }
   1049 
   1050 // Tests loop where the init value comes from two sources (constant and
   1051 // parameter).
   1052 //
   1053 // int32 result = (0, 1);
   1054 // while (result[0] + result[1] < 30) {
   1055 //   result[0] = result[0] + 1;
   1056 //   result[1] = result[1] + 1;
   1057 // }
   1058 TEST_F(WhileTest, WhileWithMixedTupleElements) {
   1059   auto result_shape = ShapeUtil::MakeTupleShape(
   1060       {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(S32, {})});
   1061 
   1062   ComputationBuilder outer(client_, "outer");
   1063   auto p =
   1064       outer.Tuple({outer.ConstantR0<int32>(0),
   1065                    outer.Parameter(0, ShapeUtil::MakeShape(S32, {}), "t")});
   1066 
   1067   ComputationBuilder cond(client_, "cond");
   1068   auto params = cond.Parameter(0, result_shape, "prev");
   1069   auto cond_t = cond.Add(cond.GetTupleElement(params, 1),
   1070                          cond.GetTupleElement(params, 0));
   1071   cond.Lt(cond_t, cond.ConstantR0<int32>(30));
   1072 
   1073   ComputationBuilder body(client_, "body");
   1074   auto body_t = body.Parameter(0, result_shape, "t");
   1075 
   1076   auto tuple = body.Tuple(
   1077       {body.Add(body.GetTupleElement(params, 0), body.ConstantR0<int32>(1)),
   1078        body.Add(body.GetTupleElement(params, 1), body.ConstantR0<int32>(1))});
   1079 
   1080   TF_ASSERT_OK_AND_ASSIGN(auto cond_computation, cond.Build());
   1081   TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body.Build());
   1082   outer.While(cond_computation, body_computation, p);
   1083 
   1084   TF_ASSERT_OK_AND_ASSIGN(
   1085       std::unique_ptr<GlobalData> parameter_data,
   1086       client_->TransferToServer(*Literal::CreateR0<int32>(1)));
   1087 
   1088   auto add1 = Literal::CreateR0<int32>(15);
   1089   auto add2 = Literal::CreateR0<int32>(16);
   1090   auto expected = Literal::MakeTuple({add1.get(), add2.get()});
   1091   ComputeAndCompareTuple(&outer, *expected, {parameter_data.get()},
   1092                          ErrorSpec(1e-6));
   1093 }
   1094 
   1095 // Tests nested while loops.
   1096 //
   1097 // int32 result = 0;
   1098 // while (result < 30) {
   1099 //   int i = 0;
   1100 //   while (i < 7) {
   1101 //     result = result + 2;
   1102 //     i = i + 1;
   1103 //   }
   1104 // }
   1105 XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) {
   1106   auto outer_result_shape = ShapeUtil::MakeShape(S32, {});
   1107   auto inner_result_shape = ShapeUtil::MakeTupleShape(
   1108       {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(S32, {})});
   1109 
   1110   Computation inner_condition;
   1111   {
   1112     ComputationBuilder builder(client_, "inner_condition");
   1113     auto params = builder.Parameter(0, inner_result_shape, "prev");
   1114     auto i = builder.GetTupleElement(params, 0);
   1115     builder.Lt(i, builder.ConstantR0<int32>(7));
   1116     inner_condition = builder.Build().ConsumeValueOrDie();
   1117   }
   1118 
   1119   // Creates a computation for the outer loop condition:
   1120   // repeat while result < 30.
   1121   Computation outer_condition;
   1122   {
   1123     ComputationBuilder builder(client_, "outer_condition");
   1124     auto prev = builder.Parameter(0, outer_result_shape, "prev");
   1125     builder.Lt(prev, builder.ConstantR0<int32>(30));
   1126     outer_condition = builder.Build().ConsumeValueOrDie();
   1127   }
   1128 
   1129   // Creates a computation for the inner loop body: add 1 to `i`, and add 2 to
   1130   // `result`.
   1131   Computation inner_body;
   1132   {
   1133     ComputationBuilder builder(client_, "inner_body");
   1134     auto params = builder.Parameter(0, inner_result_shape, "prev");
   1135     auto i = builder.GetTupleElement(params, 0);
   1136     auto result = builder.GetTupleElement(params, 1);
   1137     i = builder.Add(builder.ConstantR0<int32>(1), i);
   1138     result = builder.Add(builder.ConstantR0<int32>(2), result);
   1139     auto output = builder.Tuple({i, result});
   1140     inner_body = builder.Build().ConsumeValueOrDie();
   1141   }
   1142 
   1143   // Creates a computation for the outer loop: run the inner loop with i = 0.
   1144   Computation outer_body;
   1145   {
   1146     ComputationBuilder builder(client_, "outer_body");
   1147     auto prev = builder.Parameter(0, outer_result_shape, "prev");
   1148     auto init = builder.Tuple({builder.ConstantR0<int32>(0), prev});
   1149     auto result = builder.While(inner_condition, inner_body, init);
   1150     auto output = builder.GetTupleElement(result, 1);
   1151     outer_body = builder.Build().ConsumeValueOrDie();
   1152   }
   1153 
   1154   // Create a While node with computations for the condition and the body.
   1155   ComputationBuilder builder(client_, TestName());
   1156   auto init = builder.ConstantR0<int32>(0);
   1157   auto result = builder.While(outer_condition, outer_body, init);
   1158   auto shape = builder.GetShape(result).ConsumeValueOrDie();
   1159 
   1160   ComputeAndCompareR0<int32>(&builder, 42, {});
   1161 }
   1162 
   1163 // Tests a while node when the result type T is S32.
   1164 // f = lambda result: tuple({result < 5})
   1165 // int32 result = 0;
   1166 // while (f(result).get<0>()) {
   1167 //   result = result + 1;
   1168 // }
   1169 TEST_F(WhileTest, WhileWithCallInsideCondition) {
   1170   auto result_shape = ShapeUtil::MakeShape(S32, {});
   1171 
   1172   // Create a computation for the condition: repeat for 5 iterations.
   1173   Computation condition_callee;
   1174   {
   1175     ComputationBuilder builder(client_, "condition_callee");
   1176     auto prev = builder.Parameter(0, result_shape, "prev");
   1177     builder.Tuple({builder.Gt(builder.ConstantR0<int32>(5), prev)});
   1178 
   1179     condition_callee = builder.Build().ConsumeValueOrDie();
   1180   }
   1181 
   1182   Computation condition;
   1183   {
   1184     ComputationBuilder builder(client_, "condition");
   1185     auto prev = builder.Parameter(0, result_shape, "prev");
   1186     auto result = builder.Call(condition_callee, {prev});
   1187     builder.GetTupleElement(result, 0);
   1188     condition = builder.Build().ConsumeValueOrDie();
   1189   }
   1190 
   1191   // Create a computation for the body: add 1 to the result variable.
   1192   Computation body;
   1193   {
   1194     ComputationBuilder builder(client_, "body");
   1195     auto prev = builder.Parameter(0, result_shape, "prev");
   1196     auto input = builder.ConstantR0<int32>(1);
   1197     auto result = builder.Add(input, prev);
   1198     body = builder.Build().ConsumeValueOrDie();
   1199   }
   1200 
   1201   // Create a While node with computations for the condition and the body.
   1202   ComputationBuilder builder(client_, TestName());
   1203   auto init = builder.ConstantR0<int32>(0);
   1204   auto result = builder.While(condition, body, init);
   1205   auto shape = builder.GetShape(result).ConsumeValueOrDie();
   1206 
   1207   ComputeAndCompareR0<int32>(&builder, 5, {});
   1208 }
   1209 
   1210 TEST_F(WhileTest, WhileWithLoopInvariantOperation) {
   1211   auto matrix_shape = ShapeUtil::MakeShape(F32, {2, 2});
   1212   auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
   1213   auto while_shape = ShapeUtil::MakeTupleShape(
   1214       {scalar_s32, matrix_shape, matrix_shape, matrix_shape});
   1215 
   1216   // Create a computation for the condition: repeat for 5 iterations.
   1217   Computation condition;
   1218   {
   1219     ComputationBuilder builder(client_, "condition");
   1220     auto state = builder.Parameter(0, while_shape, "state");
   1221     builder.Gt(builder.ConstantR0<int32>(5), builder.GetTupleElement(state, 0));
   1222     TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build());
   1223   }
   1224 
   1225   Computation body;
   1226   {
   1227     ComputationBuilder builder(client_, "body");
   1228     auto state = builder.Parameter(0, while_shape, "state");
   1229     auto indvar = builder.GetTupleElement(state, 0);
   1230     auto input_0 = builder.GetTupleElement(state, 1);
   1231     auto input_1 = builder.GetTupleElement(state, 2);
   1232     auto output = builder.Tanh(builder.Dot(input_0, input_1));
   1233     auto indvar_next = builder.Add(indvar, builder.ConstantR0<int32>(1));
   1234     auto tuple_result = builder.Tuple({indvar_next, input_0, input_1, output});
   1235     TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
   1236   }
   1237 
   1238   ComputationBuilder builder(client_, TestName());
   1239   auto matrix_input = builder.Parameter(0, matrix_shape, "matrix");
   1240   auto init = builder.Tuple(
   1241       {builder.ConstantR0<int32>(0), matrix_input, matrix_input, matrix_input});
   1242   auto while_instruction = builder.While(condition, body, init);
   1243   builder.GetTupleElement(while_instruction, 3);
   1244 
   1245   TF_ASSERT_OK_AND_ASSIGN(auto param_value,
   1246                           client_->TransferToServer(*Literal::CreateR2<float>(
   1247                               {{1.0, 2.0}, {-1.0, -2.0}})));
   1248 
   1249   ComputeAndCompareR2<float>(
   1250       &builder, {{-0.76159416, -0.96402758}, {0.76159416, 0.96402758}},
   1251       {param_value.get()}, ErrorSpec(4e-5));
   1252 }
   1253 
   1254 void BM_WhileLoop(int num_iters) {
   1255   // Benchmark a simple kernel to measure while loop overheads.
   1256   tensorflow::testing::StopTiming();
   1257 
   1258   se::Platform* platform = PlatformUtil::GetDefaultPlatform().ValueOrDie();
   1259   auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie();
   1260   StreamExecutorMemoryAllocator allocator(platform, executors);
   1261   LocalClient* client =
   1262       ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie();
   1263 
   1264   const int64 seq_len = 100;
   1265   Shape loop_state_shape = ShapeUtil::MakeTupleShape(
   1266       {ShapeUtil::MakeShape(S32, {}),
   1267        ShapeUtil::MakeShape(F32, {seq_len, 1024, 1024})});
   1268 
   1269   // Create while condition computation with 'loop_limit'.
   1270   const int32 loop_limit = 100;
   1271   Computation condition;
   1272   {
   1273     ComputationBuilder builder(client, "condition");
   1274     auto prev = builder.Parameter(0, loop_state_shape, "prev");
   1275     auto iteration = builder.GetTupleElement(prev, 0);
   1276     builder.Lt(iteration, builder.ConstantR0<int32>(loop_limit));
   1277     condition = builder.Build().ConsumeValueOrDie();
   1278   }
   1279 
   1280   // Create while body computation with unit loop increment.
   1281   Computation body;
   1282   {
   1283     ComputationBuilder builder(client, "body");
   1284     auto prev = builder.Parameter(0, loop_state_shape, "prev");
   1285     // TupleElement 0
   1286     auto iteration = builder.GetTupleElement(prev, 0);
   1287     auto out0 = builder.Add(iteration, builder.ConstantR0<int32>(1));
   1288     // TupleElement 1
   1289     auto input = builder.GetTupleElement(prev, 1);
   1290     // Update.
   1291     auto one = builder.ConstantR0<float>(1.0);
   1292     auto update = builder.Broadcast(one, {1, 1024, 1024});
   1293     // Starts = iteration * 2;
   1294     auto starts = builder.ConstantR1<int32>({0, 0, 0});
   1295     // UpdateSlice.
   1296     auto out1 = builder.DynamicUpdateSlice(input, update, starts);
   1297     auto result = builder.Tuple({out0, out1});
   1298     body = builder.Build().ConsumeValueOrDie();
   1299   }
   1300 
   1301   // Create a While instruction.
   1302   ComputationBuilder builder(client, "while");
   1303   auto zero = builder.ConstantR0<float>(0.0);
   1304   auto input = builder.Broadcast(zero, {seq_len, 1024, 1024});
   1305   auto init = builder.Tuple({builder.ConstantR0<int32>(0), input});
   1306   builder.While(condition, body, init);
   1307   auto computation = builder.Build().ConsumeValueOrDie();
   1308 
   1309   std::unique_ptr<LocalExecutable> executable =
   1310       client->Compile(computation, {}, ExecutableBuildOptions())
   1311           .ConsumeValueOrDie();
   1312 
   1313   // Run some warm-up executions.
   1314   ExecutableRunOptions options;
   1315   options.set_allocator(&allocator);
   1316   const int kWarmups = 2;
   1317   for (int i = 0; i < kWarmups; ++i) {
   1318     auto result = executable->Run({}, options);
   1319     ASSERT_TRUE(result.ok());
   1320   }
   1321 
   1322   // Run benchmark.
   1323   tensorflow::testing::StartTiming();
   1324   for (int i = 0; i < num_iters; ++i) {
   1325     auto result = executable->Run({}, options);
   1326     ASSERT_TRUE(result.ok());
   1327   }
   1328 }
   1329 
   1330 // TODO(b/32470510): Benchmark fails on parallel CPU backend.
   1331 #ifndef XLA_TEST_BACKEND_CPU_PARALLEL
   1332 BENCHMARK(BM_WhileLoop);
   1333 #endif
   1334 
   1335 }  // namespace
   1336 }  // namespace xla
   1337