1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include <memory> 17 #include <string> 18 #include <vector> 19 20 #include "tensorflow/compiler/xla/client/client_library.h" 21 #include "tensorflow/compiler/xla/client/computation.h" 22 #include "tensorflow/compiler/xla/client/computation_builder.h" 23 #include "tensorflow/compiler/xla/client/lib/arithmetic.h" 24 #include "tensorflow/compiler/xla/client/local_client.h" 25 #include "tensorflow/compiler/xla/literal_util.h" 26 #include "tensorflow/compiler/xla/service/platform_util.h" 27 #include "tensorflow/compiler/xla/shape_util.h" 28 #include "tensorflow/compiler/xla/status_macros.h" 29 #include "tensorflow/compiler/xla/statusor.h" 30 #include "tensorflow/compiler/xla/tests/client_library_test_base.h" 31 #include "tensorflow/compiler/xla/tests/literal_test_util.h" 32 #include "tensorflow/compiler/xla/tests/test_macros.h" 33 #include "tensorflow/compiler/xla/xla_data.pb.h" 34 #include "tensorflow/core/lib/core/status_test_util.h" 35 #include "tensorflow/core/platform/logging.h" 36 #include "tensorflow/core/platform/test.h" 37 #include "tensorflow/core/platform/test_benchmark.h" 38 #include "tensorflow/core/platform/types.h" 39 40 namespace se = ::perftools::gputools; 41 42 namespace xla { 43 namespace { 44 45 class WhileTest : public ClientLibraryTestBase {}; 46 47 // Tests a while node when the result type T is S32. 48 // 49 // int32 result = 0; 50 // while (result < 5) { 51 // result = result + 1; 52 // } 53 TEST_F(WhileTest, WhileWithScalarS32Result) { 54 auto result_shape = ShapeUtil::MakeShape(S32, {}); 55 56 // Create a computation for the condition: repeat for 5 iterations. 57 Computation condition; 58 { 59 ComputationBuilder builder(client_, "condition"); 60 auto prev = builder.Parameter(0, result_shape, "prev"); 61 builder.Gt(builder.ConstantR0<int32>(5), prev); 62 condition = builder.Build().ConsumeValueOrDie(); 63 } 64 65 // Create a computation for the body: add 1 to the result variable. 66 Computation body; 67 { 68 ComputationBuilder builder(client_, "body"); 69 auto prev = builder.Parameter(0, result_shape, "prev"); 70 auto input = builder.ConstantR0<int32>(1); 71 auto result = builder.Add(input, prev); 72 body = builder.Build().ConsumeValueOrDie(); 73 } 74 75 // Create a While node with computations for the condition and the body. 76 ComputationBuilder builder(client_, TestName()); 77 auto init = builder.ConstantR0<int32>(0); 78 auto result = builder.While(condition, body, init); 79 auto shape = builder.GetShape(result).ConsumeValueOrDie(); 80 81 ComputeAndCompareR0<int32>(&builder, 5, {}); 82 } 83 84 // Tests a while node when the result type T is S64. 85 // 86 // int32 result = 0; 87 // while (result < 5) { 88 // result = result + 1; 89 // } 90 TEST_F(WhileTest, WhileWithScalarS64Result) { 91 auto result_shape = ShapeUtil::MakeShape(S64, {}); 92 93 // Create a computation for the condition: repeat for 5 iterations. 94 Computation condition; 95 { 96 ComputationBuilder builder(client_, "condition"); 97 auto prev = builder.Parameter(0, result_shape, "prev"); 98 builder.Gt(builder.ConstantR0<int64>(5), prev); 99 condition = builder.Build().ConsumeValueOrDie(); 100 } 101 102 // Create a computation for the body: add 1 to the result variable. 103 Computation body; 104 { 105 ComputationBuilder builder(client_, "body"); 106 auto prev = builder.Parameter(0, result_shape, "prev"); 107 auto input = builder.ConstantR0<int64>(1); 108 auto result = builder.Add(input, prev); 109 body = builder.Build().ConsumeValueOrDie(); 110 } 111 112 // Create a While node with computations for the condition and the body. 113 ComputationBuilder builder(client_, TestName()); 114 auto init = builder.ConstantR0<int64>(0); 115 auto result = builder.While(condition, body, init); 116 auto shape = builder.GetShape(result).ConsumeValueOrDie(); 117 118 ComputeAndCompareR0<int64>(&builder, 5, {}); 119 } 120 121 TEST_F(WhileTest, WhileWithScalarResultNonConstInit) { 122 auto result_shape = ShapeUtil::MakeShape(S32, {}); 123 auto orig_shape = ShapeUtil::MakeShape(S32, {2}); 124 125 // Create a computation for the condition: repeat for 5 iterations. 126 Computation condition; 127 { 128 ComputationBuilder builder(client_, "condition"); 129 auto prev = builder.Parameter(0, result_shape, "prev"); 130 builder.Gt(builder.ConstantR0<int32>(5), prev); 131 condition = builder.Build().ConsumeValueOrDie(); 132 } 133 134 // Create a computation for the body: add 1 to the result variable. 135 Computation body; 136 { 137 ComputationBuilder builder(client_, "body"); 138 auto prev = builder.Parameter(0, result_shape, "prev"); 139 auto input = builder.ConstantR0<int32>(1); 140 auto result = builder.Add(input, prev); 141 body = builder.Build().ConsumeValueOrDie(); 142 } 143 144 // Create a While node with computations for the condition and the body. 145 ComputationBuilder builder(client_, TestName()); 146 auto init = builder.Reduce(builder.ConstantR1<int32>(2, 1), 147 builder.ConstantR0<int32>(0), 148 CreateScalarAddComputation(S32, &builder), {0}); 149 auto result = builder.While(condition, body, init); 150 auto shape = builder.GetShape(result).ConsumeValueOrDie(); 151 152 ComputeAndCompareR0<int32>(&builder, 5, {}); 153 } 154 155 TEST_F(WhileTest, WhileWithPredicateResult) { 156 auto result_shape = ShapeUtil::MakeShape(PRED, {}); 157 158 // Create a computation for the condition: run until condition is true. 159 Computation condition; 160 { 161 ComputationBuilder builder(client_, "condition"); 162 auto prev = builder.Parameter(0, result_shape, "prev"); 163 builder.Ne(builder.ConstantR0<bool>(true), prev); 164 condition = builder.Build().ConsumeValueOrDie(); 165 } 166 167 // Create a computation for the body: or condition with true. 168 Computation body; 169 { 170 ComputationBuilder builder(client_, "body"); 171 auto prev = builder.Parameter(0, result_shape, "prev"); 172 auto result = builder.Or(prev, builder.ConstantR0<bool>(true)); 173 body = builder.Build().ConsumeValueOrDie(); 174 } 175 176 // Create a While node with computations for the condition and the body. 177 ComputationBuilder builder(client_, TestName()); 178 auto init = builder.Ne(builder.ConstantR0<bool>(false), 179 builder.ConstantR0<bool>(true)); 180 auto result = builder.While(condition, body, init); 181 182 ComputeAndCompareR0<bool>(&builder, true, {}); 183 } 184 185 // Tests a while node when the result type T is a vector. 186 // 187 // All constants are chosen to produce exact results. 188 // vector<float> result(0); 189 // while (result.sum() < 15.5f) { 190 // result = result + vector<float>(0); 191 // } 192 // TODO(b/29185393): does not terminate on CPU. 193 TEST_F(WhileTest, DISABLED_WhileWithEmptyVectorResult) { 194 Shape result_shape = ShapeUtil::MakeShape(F32, {0}); 195 196 // Create a computation for the reduction. 197 Computation add; 198 { 199 ComputationBuilder builder(client_, "add"); 200 auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); 201 auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); 202 builder.Add(x, y); 203 add = builder.Build().ConsumeValueOrDie(); 204 } 205 206 // Create a computation for the condition. 207 // Repeat until the sum of the result vector is less than 15.5f. 208 Computation condition; 209 { 210 ComputationBuilder builder(client_, "condition"); 211 auto prev = builder.Parameter(0, result_shape, "prev"); 212 auto sum = builder.Reduce(prev, builder.ConstantR0<float>(0.0f), add, 213 /*dimensions_to_reduce=*/{0}); 214 auto test = builder.Gt(builder.ConstantR0<float>(15.5f), sum); 215 condition = builder.Build().ConsumeValueOrDie(); 216 } 217 218 // Create a computation for the body. 219 // Add a constant vector of 1.f to the result vector. 220 Computation body; 221 { 222 ComputationBuilder builder(client_, "body"); 223 auto prev = builder.Parameter(0, result_shape, "prev"); 224 auto input = builder.ConstantR1<float>({}); 225 auto result = builder.Add(input, prev); 226 body = builder.Build().ConsumeValueOrDie(); 227 } 228 229 // Create a While node with computations for the condition and the body. 230 ComputationBuilder builder(client_, "while"); 231 auto init = builder.ConstantR1<float>({}); 232 auto result = builder.While(condition, body, init); 233 VLOG(2) << "while = " << ShapeUtil::HumanString( 234 *builder.GetShape(result).ConsumeValueOrDie()); 235 236 ComputeAndCompareR1<float>(&builder, {}, {}, ErrorSpec(0.0001)); 237 } 238 239 // Tests a while node when the result type T is a vector. 240 // 241 // All constants are chosen to produce exact results. 242 // vector<float> result(8, 0.0f); 243 // while (result.sum() < 15.5f) { 244 // result = result + vector<float>(8, 0.125f); 245 // } 246 TEST_F(WhileTest, WhileWithVectorResult) { 247 Shape result_shape = ShapeUtil::MakeShape(F32, {8}); 248 249 // Create a computation for the reduction. 250 Computation add; 251 { 252 ComputationBuilder builder(client_, "add"); 253 auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); 254 auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); 255 builder.Add(x, y); 256 add = builder.Build().ConsumeValueOrDie(); 257 } 258 259 // Create a computation for the condition. 260 // Repeat until the sum of the result vector is less than 5.5f. 261 Computation condition; 262 { 263 ComputationBuilder builder(client_, "condition"); 264 auto prev = builder.Parameter(0, result_shape, "prev"); 265 auto sum = builder.Reduce(prev, builder.ConstantR0<float>(0.0f), add, 266 /*dimensions_to_reduce=*/{0}); 267 auto test = builder.Gt(builder.ConstantR0<float>(15.5f), sum); 268 condition = builder.Build().ConsumeValueOrDie(); 269 } 270 271 // Create a computation for the body. 272 // Add a constant vector of 1.f to the result vector. 273 Computation body; 274 { 275 ComputationBuilder builder(client_, "body"); 276 auto prev = builder.Parameter(0, result_shape, "prev"); 277 auto input = builder.ConstantR1<float>(8, 0.125f); 278 auto result = builder.Add(input, prev); 279 body = builder.Build().ConsumeValueOrDie(); 280 } 281 282 // Create a While node with computations for the condition and the body. 283 ComputationBuilder builder(client_, "while"); 284 auto init = builder.ConstantR1<float>(8, 0.f); 285 auto result = builder.While(condition, body, init); 286 VLOG(2) << "while = " << ShapeUtil::HumanString( 287 *builder.GetShape(result).ConsumeValueOrDie()); 288 289 // Individual elements with increase by 1/8 each time through the loop, so 290 // the sum will increase by 1.0. It will first be >15.5 when the elements 291 // have all reached 2.0. 292 std::vector<float> expected = {2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f}; 293 ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001)); 294 } 295 296 // Tests a while node when the result type is a vector which is part 297 // of the result tuple. 298 // 299 // All constants are chosen to produce exact results. 300 // vector<float> result(8, 0.0f); 301 // while (result.sum() < 15.5f) { 302 // result = result + vector<float>(8, 0.125f); 303 // } 304 // tuple = tuple { while } 305 TEST_F(WhileTest, WhileWithVectorResultIntoTuple) { 306 Shape result_shape = ShapeUtil::MakeShape(F32, {8}); 307 308 // Create a computation for the reduction. 309 Computation add; 310 { 311 ComputationBuilder builder(client_, "add"); 312 auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); 313 auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); 314 builder.Add(x, y); 315 add = builder.Build().ConsumeValueOrDie(); 316 } 317 318 // Create a computation for the condition. 319 // Repeat until the sum of the result vector is less than 5.5f. 320 Computation condition; 321 { 322 ComputationBuilder builder(client_, "condition"); 323 auto prev = builder.Parameter(0, result_shape, "prev"); 324 auto sum = builder.Reduce(prev, builder.ConstantR0<float>(0.0f), add, 325 /*dimensions_to_reduce=*/{0}); 326 auto test = builder.Gt(builder.ConstantR0<float>(15.5f), sum); 327 condition = builder.Build().ConsumeValueOrDie(); 328 } 329 330 // Create a computation for the body. 331 // Add a constant vector of 1.f to the result vector. 332 Computation body; 333 { 334 ComputationBuilder builder(client_, "body"); 335 auto prev = builder.Parameter(0, result_shape, "prev"); 336 auto input = builder.ConstantR1<float>(8, 0.125f); 337 auto result = builder.Add(input, prev); 338 body = builder.Build().ConsumeValueOrDie(); 339 } 340 341 // Create a While node with computations for the condition and the body. 342 ComputationBuilder builder(client_, "while"); 343 auto init = builder.ConstantR1<float>(8, 0.f); 344 auto result = builder.While(condition, body, init); 345 VLOG(2) << "while = " 346 << ShapeUtil::HumanString( 347 *builder.GetShape(result).ConsumeValueOrDie()); 348 builder.Tuple({result}); 349 350 // Individual elements with increase by 1/8 each time through the loop, so 351 // the sum will increase by 1.0. It will first be >15.5 when the elements 352 // have all reached 2.0. 353 auto expected_data = 354 Literal::CreateR1<float>({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f}); 355 auto expected = Literal::MakeTuple({expected_data.get()}); 356 VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); 357 ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); 358 } 359 360 TEST_F(WhileTest, WhileWithPermutationAndTupleResult) { 361 std::vector<Shape> shape_elements = { 362 ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {3}), 363 ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(F32, {3})}; 364 Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); 365 366 // Create a computation for the condition. 367 // Repeat for N iterations. 368 const int N = 2; 369 Computation condition; 370 { 371 ComputationBuilder builder(client_, "condition"); 372 auto prev = builder.Parameter(0, result_shape, "prev"); 373 auto iteration = builder.GetTupleElement(prev, 0); 374 builder.Gt(builder.ConstantR0<int32>(N), iteration); 375 condition = builder.Build().ConsumeValueOrDie(); 376 } 377 378 // Create a computation for the body. 379 // Add 1 to the iteration variable and permute the weights. 380 Computation body; 381 { 382 ComputationBuilder builder(client_, "body"); 383 auto prev = builder.Parameter(0, result_shape, "prev"); 384 auto iteration = builder.GetTupleElement(prev, 0); 385 auto w1 = builder.GetTupleElement(prev, 1); 386 auto w2 = builder.GetTupleElement(prev, 2); 387 auto w3 = builder.GetTupleElement(prev, 3); 388 auto result = builder.Tuple( 389 {builder.Add(iteration, builder.ConstantR0<int32>(1)), w3, w1, w2}); 390 body = builder.Build().ConsumeValueOrDie(); 391 } 392 393 // Create a While node with computations for the condition and the body. 394 ComputationBuilder builder(client_, "while"); 395 auto init = builder.Tuple( 396 {builder.ConstantR0<int32>(0), builder.ConstantR1<float>(3, 1.f), 397 builder.ConstantR1<float>(3, 2.f), builder.ConstantR1<float>(3, 3.f)}); 398 auto result = builder.While(condition, body, init); 399 VLOG(2) << "result = " 400 << ShapeUtil::HumanString( 401 *builder.GetShape(result).ConsumeValueOrDie()); 402 403 auto expected_counter = Literal::CreateR0<int32>(N); 404 auto expected_w1 = Literal::CreateR1<float>({1.0f, 1.0f, 1.0f}); 405 auto expected_w2 = Literal::CreateR1<float>({2.0f, 2.0f, 2.0f}); 406 auto expected_w3 = Literal::CreateR1<float>({3.0f, 3.0f, 3.0f}); 407 auto expected = Literal::MakeTuple({expected_counter.get(), expected_w2.get(), 408 expected_w3.get(), expected_w1.get()}); 409 VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); 410 ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); 411 } 412 413 TEST_F(WhileTest, WhileWithPermutationAndVectorResult) { 414 std::vector<Shape> shape_elements = { 415 ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {3}), 416 ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(F32, {3})}; 417 Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); 418 419 // Create a computation for the condition. 420 // Repeat for N iterations. 421 const int N = 2; 422 Computation condition; 423 { 424 ComputationBuilder builder(client_, "condition"); 425 auto prev = builder.Parameter(0, result_shape, "prev"); 426 auto iteration = builder.GetTupleElement(prev, 0); 427 builder.Gt(builder.ConstantR0<int32>(N), iteration); 428 condition = builder.Build().ConsumeValueOrDie(); 429 } 430 431 // Create a computation for the body. 432 // Add 1 to the iteration variable permute the weights. 433 Computation body; 434 { 435 ComputationBuilder builder(client_, "body"); 436 auto prev = builder.Parameter(0, result_shape, "prev"); 437 auto iteration = builder.GetTupleElement(prev, 0); 438 auto w1 = builder.GetTupleElement(prev, 1); 439 auto w2 = builder.GetTupleElement(prev, 2); 440 auto w3 = builder.GetTupleElement(prev, 3); 441 auto result = builder.Tuple( 442 {builder.Add(iteration, builder.ConstantR0<int32>(1)), w3, w1, w2}); 443 body = builder.Build().ConsumeValueOrDie(); 444 } 445 446 // Create a While node with computations for the condition and the body. 447 ComputationBuilder builder(client_, "while"); 448 auto init = builder.Tuple( 449 {builder.ConstantR0<int32>(0), builder.ConstantR1<float>(3, 1.f), 450 builder.ConstantR1<float>(3, 2.f), builder.ConstantR1<float>(3, 3.f)}); 451 auto xla_while = builder.While(condition, body, init); 452 453 auto add12 = builder.Add(builder.GetTupleElement(xla_while, 1), 454 builder.GetTupleElement(xla_while, 2)); 455 auto result = builder.Add(add12, builder.GetTupleElement(xla_while, 3)); 456 VLOG(2) << "result = " 457 << ShapeUtil::HumanString( 458 *builder.GetShape(result).ConsumeValueOrDie()); 459 std::vector<float> expected = {6.f, 6.f, 6.f}; 460 ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001)); 461 } 462 463 // Tests a while node when the result type T is a Tuple. 464 // 465 // tuple<int32, vector<float>> result(0, vector<float>(10, 0.0f)); 466 // while (get<0>(result) < 5) { 467 // get<0>(result) = get<0>(result) + 1; 468 // get<1>(result) = get<1>(result) + vector<float>(10, 1.0f); 469 // } 470 TEST_F(WhileTest, WhileWithTupleResult) { 471 std::vector<Shape> shape_elements = {ShapeUtil::MakeShape(S32, {}), 472 ShapeUtil::MakeShape(F32, {10})}; 473 Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); 474 475 // Create a computation for the condition. 476 // Repeat for 5 iterations. 477 Computation condition; 478 { 479 ComputationBuilder builder(client_, "condition"); 480 auto prev = builder.Parameter(0, result_shape, "prev"); 481 auto iteration = builder.GetTupleElement(prev, 0); 482 builder.Gt(builder.ConstantR0<int32>(5), iteration); 483 condition = builder.Build().ConsumeValueOrDie(); 484 } 485 486 // Create a computation for the body. 487 // Add 1 to the iteration variable and add a constant vector of 1.0f to 488 // the weight variable, both of which are tuple elements. 489 Computation body; 490 { 491 ComputationBuilder builder(client_, "body"); 492 auto prev = builder.Parameter(0, result_shape, "prev"); 493 auto iteration = builder.GetTupleElement(prev, 0); 494 auto weights = builder.GetTupleElement(prev, 1); 495 auto input = builder.ConstantR1<float>(10, 1.f); 496 auto new_weights = builder.Add(weights, input); 497 auto result = builder.Tuple( 498 {builder.Add(iteration, builder.ConstantR0<int32>(1)), new_weights}); 499 body = builder.Build().ConsumeValueOrDie(); 500 } 501 502 // Create a While node with computations for the condition and the body. 503 ComputationBuilder builder(client_, "while"); 504 auto init = builder.Tuple( 505 {builder.ConstantR0<int32>(0), builder.ConstantR1<float>(10, 0.f)}); 506 auto result = builder.While(condition, body, init); 507 VLOG(2) << "while = " << ShapeUtil::HumanString( 508 *builder.GetShape(result).ConsumeValueOrDie()); 509 510 auto expected_counter = Literal::CreateR0<int32>(5); 511 auto expected_data = Literal::CreateR1<float>( 512 {5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f}); 513 auto expected = 514 Literal::MakeTuple({expected_counter.get(), expected_data.get()}); 515 VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); 516 ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); 517 } 518 519 TEST_F(WhileTest, WhileWithPredicateTupleResult) { 520 std::vector<Shape> shape_elements = {ShapeUtil::MakeShape(S32, {}), 521 ShapeUtil::MakeShape(PRED, {})}; 522 Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); 523 524 // Create a computation for the condition. 525 // Repeat for 5 iterations. 526 Computation condition; 527 { 528 ComputationBuilder builder(client_, "condition"); 529 auto prev = builder.Parameter(0, result_shape, "prev"); 530 auto iteration = builder.GetTupleElement(prev, 0); 531 builder.Gt(builder.ConstantR0<int32>(5), iteration); 532 condition = builder.Build().ConsumeValueOrDie(); 533 } 534 535 // Create a computation for the body. 536 // Add 1 to the iteration variable and or the predicate with true 537 Computation body; 538 { 539 ComputationBuilder builder(client_, "body"); 540 auto prev = builder.Parameter(0, result_shape, "prev"); 541 auto iteration = builder.GetTupleElement(prev, 0); 542 auto pred = builder.GetTupleElement(prev, 1); 543 auto new_pred = builder.Or(pred, builder.ConstantR0<bool>(true)); 544 auto result = builder.Tuple( 545 {builder.Add(iteration, builder.ConstantR0<int32>(1)), new_pred}); 546 body = builder.Build().ConsumeValueOrDie(); 547 } 548 549 // Create a While node with computations for the condition and the body. 550 ComputationBuilder builder(client_, "while"); 551 auto init = builder.Tuple({builder.ConstantR0<int32>(0), 552 builder.Ne(builder.ConstantR0<bool>(false), 553 builder.ConstantR0<bool>(true))}); 554 auto result = builder.While(condition, body, init); 555 VLOG(2) << "while = " 556 << ShapeUtil::HumanString( 557 *builder.GetShape(result).ConsumeValueOrDie()); 558 559 auto expected_counter = Literal::CreateR0<int32>(5); 560 auto expected_predicate = Literal::CreateR0<bool>(true); 561 auto expected = 562 Literal::MakeTuple({expected_counter.get(), expected_predicate.get()}); 563 ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0)); 564 } 565 566 TEST_F(WhileTest, WhileWithTupleConstantScalarResult) { 567 std::vector<Shape> shape_elements = {ShapeUtil::MakeShape(S32, {}), 568 ShapeUtil::MakeShape(S32, {})}; 569 Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); 570 571 // Create a computation for the condition. 572 // Repeat for 5 iterations. 573 Computation condition; 574 { 575 ComputationBuilder builder(client_, "condition"); 576 auto prev = builder.Parameter(0, result_shape, "prev"); 577 auto iteration = builder.GetTupleElement(prev, 0); 578 builder.Gt(builder.ConstantR0<int32>(5), iteration); 579 condition = builder.Build().ConsumeValueOrDie(); 580 } 581 582 // Create a computation for the body. 583 // Add 1 to the iteration variable and set the other tuple element to a 584 // constant. 585 Computation body; 586 { 587 ComputationBuilder builder(client_, "body"); 588 auto prev = builder.Parameter(0, result_shape, "prev"); 589 auto iteration = builder.GetTupleElement(prev, 0); 590 auto result = 591 builder.Tuple({builder.Add(iteration, builder.ConstantR0<int32>(1)), 592 builder.ConstantR0<int32>(7)}); 593 body = builder.Build().ConsumeValueOrDie(); 594 } 595 596 // Create a While node with computations for the condition and the body. 597 ComputationBuilder builder(client_, "while"); 598 auto init = builder.Tuple( 599 {builder.ConstantR0<int32>(0), builder.ConstantR0<int32>(7)}); 600 auto result = builder.While(condition, body, init); 601 VLOG(2) << "while = " 602 << ShapeUtil::HumanString( 603 *builder.GetShape(result).ConsumeValueOrDie()); 604 605 auto expected_counter = Literal::CreateR0<int32>(5); 606 auto expected_data = Literal::CreateR0<int32>(7); 607 auto expected = 608 Literal::MakeTuple({expected_counter.get(), expected_data.get()}); 609 VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); 610 ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); 611 } 612 613 // Tests two while nodes when the result type T is a Tuple and the second 614 // while node uses the result of the first while node which is used in two 615 // nodes. 616 // tuple<int32, vector<float>> w0(0, vector<float>(10, 0.0f)); 617 // w0 = while (get<0>(w0) < c1) { 618 // get<0>(w0) = get<0>(w0) + 1; 619 // get<1>(w0) = get<1>(w0) + vector<float>(10, 1.0f); 620 // } 621 // tuple<int32, vector<float>> w1(get<0>(w0), get<1>(w0)); 622 // w1 = while (get<0>(w1) < c2) { 623 // get<0>(w1) = get<0>(w1) + 1; 624 // get<1>(w1) = get<1>(w1) + vector<float>(10, 1.0f); 625 // } 626 // result = get<1>(w0) + get<1>(w1) 627 TEST_F(WhileTest, TwoWhileWithTupleResult) { 628 std::vector<Shape> shape_elements = {ShapeUtil::MakeShape(S32, {}), 629 ShapeUtil::MakeShape(F32, {10})}; 630 Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); 631 632 // Create a computation for the condition. 633 // Repeat for 5 iterations. 634 Computation condition; 635 const int c1 = 5; 636 { 637 ComputationBuilder builder(client_, "condition"); 638 auto prev = builder.Parameter(0, result_shape, "prev"); 639 auto iteration = builder.GetTupleElement(prev, 0); 640 builder.Lt(iteration, builder.ConstantR0<int32>(c1)); 641 TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); 642 } 643 644 Computation condition2; 645 const int c2 = 7; 646 { 647 ComputationBuilder builder(client_, "condition2"); 648 auto prev = builder.Parameter(0, result_shape, "prev"); 649 auto iteration = builder.GetTupleElement(prev, 0); 650 builder.Lt(iteration, builder.ConstantR0<int32>(c2)); 651 TF_ASSERT_OK_AND_ASSIGN(condition2, builder.Build()); 652 } 653 654 // Create a computation for the body. 655 // Add 1 to the iteration variable and add a constant vector of 1.0f to 656 // the weight variable, both of which are tuple elements. 657 Computation body; 658 { 659 ComputationBuilder builder(client_, "body"); 660 auto prev = builder.Parameter(0, result_shape, "prev"); 661 auto iteration = builder.GetTupleElement(prev, 0); 662 auto weights = builder.GetTupleElement(prev, 1); 663 auto input = builder.ConstantR1<float>(10, 1.f); 664 auto new_weights = builder.Add(weights, input); 665 auto result = builder.Tuple( 666 {builder.Add(iteration, builder.ConstantR0<int32>(1)), new_weights}); 667 TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); 668 } 669 670 Computation body2; 671 { 672 ComputationBuilder builder(client_, "body"); 673 auto prev = builder.Parameter(0, result_shape, "prev"); 674 auto iteration = builder.GetTupleElement(prev, 0); 675 auto weights = builder.GetTupleElement(prev, 1); 676 auto input = builder.ConstantR1<float>(10, 1.f); 677 auto new_weights = builder.Add(weights, input); 678 auto result = builder.Tuple( 679 {builder.Add(iteration, builder.ConstantR0<int32>(1)), new_weights}); 680 TF_ASSERT_OK_AND_ASSIGN(body2, builder.Build()); 681 } 682 683 // Create a While node with computations for the condition and the body. 684 ComputationBuilder builder(client_, "while"); 685 auto init = builder.Tuple( 686 {builder.ConstantR0<int32>(0), builder.ConstantR1<float>(10, 0.f)}); 687 auto while1 = builder.While(condition, body, init); 688 689 auto while2 = builder.While(condition2, body2, while1); 690 691 auto while_result1 = builder.GetTupleElement(while1, 1); 692 auto while_result2 = builder.GetTupleElement(while2, 1); 693 VLOG(2) << "while_result2 = " 694 << ShapeUtil::HumanString( 695 *builder.GetShape(while_result2).ConsumeValueOrDie()); 696 auto result = builder.Add(while_result1, while_result2); 697 VLOG(2) << "result = " 698 << ShapeUtil::HumanString( 699 *builder.GetShape(result).ConsumeValueOrDie()); 700 const float sum = c1 + c2; 701 std::vector<float> expected(10, sum); 702 ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001)); 703 } 704 705 // Test while nodes that share the while body computation. 706 TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) { 707 std::vector<Shape> shape_elements = {ShapeUtil::MakeShape(S32, {}), 708 ShapeUtil::MakeShape(F32, {10})}; 709 Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); 710 711 // Create a computation for the condition. 712 // Repeat for 5 iterations. 713 Computation condition; 714 const int c1 = 5; 715 { 716 ComputationBuilder builder(client_, "condition"); 717 auto prev = builder.Parameter(0, result_shape, "prev"); 718 auto iteration = builder.GetTupleElement(prev, 0); 719 builder.Lt(iteration, builder.ConstantR0<int32>(c1)); 720 TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); 721 } 722 723 Computation condition2; 724 const int c2 = 7; 725 { 726 ComputationBuilder builder(client_, "condition2"); 727 auto prev = builder.Parameter(0, result_shape, "prev"); 728 auto iteration = builder.GetTupleElement(prev, 0); 729 builder.Lt(iteration, builder.ConstantR0<int32>(c2)); 730 TF_ASSERT_OK_AND_ASSIGN(condition2, builder.Build()); 731 } 732 733 // Create a computation for the body. 734 // Add 1 to the iteration variable and add a constant vector of 1.0f to 735 // the weight variable, both of which are tuple elements. 736 Computation body; 737 { 738 ComputationBuilder builder(client_, "body"); 739 auto prev = builder.Parameter(0, result_shape, "prev"); 740 auto iteration = builder.GetTupleElement(prev, 0); 741 auto weights = builder.GetTupleElement(prev, 1); 742 auto input = builder.ConstantR1<float>(10, 1.f); 743 auto new_weights = builder.Add(weights, input); 744 auto result = builder.Tuple( 745 {builder.Add(iteration, builder.ConstantR0<int32>(1)), new_weights}); 746 TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); 747 } 748 749 // Create a While node with computations for the condition and the body. 750 ComputationBuilder builder(client_, "while"); 751 auto init = builder.Tuple( 752 {builder.ConstantR0<int32>(0), builder.ConstantR1<float>(10, 0.f)}); 753 auto while1 = builder.While(condition, body, init); 754 755 auto while2 = builder.While(condition2, body, while1); 756 757 auto while_result1 = builder.GetTupleElement(while1, 1); 758 auto while_result2 = builder.GetTupleElement(while2, 1); 759 VLOG(2) << "while_result2 = " 760 << ShapeUtil::HumanString( 761 *builder.GetShape(while_result2).ConsumeValueOrDie()); 762 auto result = builder.Add(while_result1, while_result2); 763 VLOG(2) << "result = " 764 << ShapeUtil::HumanString( 765 *builder.GetShape(result).ConsumeValueOrDie()); 766 const float sum = c1 + c2; 767 std::vector<float> expected(10, sum); 768 ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001)); 769 } 770 771 // Test while nodes that share the while body computation. 772 // TODO(b/37245345): Fails on GPU backend. 773 TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) { 774 std::vector<Shape> shape_elements = {ShapeUtil::MakeShape(S32, {}), 775 ShapeUtil::MakeShape(F32, {10})}; 776 Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); 777 778 // Create a computation for the condition. 779 // Repeat for 5 iterations. 780 Computation condition; 781 const int c1 = 5; 782 { 783 ComputationBuilder builder(client_, "condition"); 784 auto prev = builder.Parameter(0, result_shape, "prev"); 785 auto iteration = builder.GetTupleElement(prev, 0); 786 builder.Lt(iteration, builder.ConstantR0<int32>(c1)); 787 TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); 788 } 789 790 Computation condition2; 791 const int c2 = 7; 792 { 793 ComputationBuilder builder(client_, "condition2"); 794 auto prev = builder.Parameter(0, result_shape, "prev"); 795 auto iteration = builder.GetTupleElement(prev, 0); 796 builder.Lt(iteration, builder.ConstantR0<int32>(c2)); 797 TF_ASSERT_OK_AND_ASSIGN(condition2, builder.Build()); 798 } 799 800 // Create a computation for the body. 801 // Add 1 to the iteration variable and add a constant vector of 1.0f to 802 // the weight variable, both of which are tuple elements. 803 Computation body; 804 { 805 ComputationBuilder builder(client_, "body"); 806 auto prev = builder.Parameter(0, result_shape, "prev"); 807 auto iteration = builder.GetTupleElement(prev, 0); 808 auto weights = builder.GetTupleElement(prev, 1); 809 auto input = builder.ConstantR1<float>(10, 1.f); 810 auto new_weights = builder.Add(weights, input); 811 auto result = builder.Tuple( 812 {builder.Add(iteration, builder.ConstantR0<int32>(1)), new_weights}); 813 TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); 814 } 815 816 // Create a While node with computations for the condition and the body. 817 ComputationBuilder builder(client_, "while"); 818 auto init = builder.Tuple( 819 {builder.ConstantR0<int32>(0), builder.ConstantR1<float>(10, 0.f)}); 820 auto while1 = builder.While(condition, body, init); 821 auto while2 = builder.While(condition2, body, init); 822 823 auto while_result1 = builder.GetTupleElement(while1, 1); 824 auto while_result2 = builder.GetTupleElement(while2, 1); 825 VLOG(2) << "while_result2 = " 826 << ShapeUtil::HumanString( 827 *builder.GetShape(while_result2).ConsumeValueOrDie()); 828 auto result = builder.Add(while_result1, while_result2); 829 VLOG(2) << "result = " 830 << ShapeUtil::HumanString( 831 *builder.GetShape(result).ConsumeValueOrDie()); 832 const float sum = c1 + c2; 833 std::vector<float> expected(10, sum); 834 ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001)); 835 } 836 837 // WhileTest that uses DynamicUpdateSlice instruction in body computation. 838 // Loop state tuple element 1 has as its single user operand(0) of 839 // DynamicUpdateSlice, which will trigger in-place dynamic slice update on GPU. 840 XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) { 841 std::vector<Shape> shape_elements = {ShapeUtil::MakeShape(S32, {}), 842 ShapeUtil::MakeShape(F32, {10})}; 843 Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); 844 845 // Create a computation for the condition. 846 // Repeat for 5 iterations. 847 Computation condition; 848 { 849 ComputationBuilder builder(client_, "condition"); 850 auto prev = builder.Parameter(0, result_shape, "prev"); 851 auto iteration = builder.GetTupleElement(prev, 0); 852 builder.Gt(builder.ConstantR0<int32>(5), iteration); 853 condition = builder.Build().ConsumeValueOrDie(); 854 } 855 856 // Create a computation for the body. 857 // Add 1 to the iteration variable and add a constant vector of 1.0f to 858 // the weight variable, both of which are tuple elements. 859 Computation body; 860 { 861 ComputationBuilder builder(client_, "body"); 862 auto prev = builder.Parameter(0, result_shape, "prev"); 863 // TupleElement 0 864 auto iteration = builder.GetTupleElement(prev, 0); 865 auto out0 = builder.Add(iteration, builder.ConstantR0<int32>(1)); 866 // TupleElement 1 867 auto input = builder.GetTupleElement(prev, 1); 868 // Update. 869 auto update = builder.ConvertElementType(builder.Broadcast(out0, {2}), F32); 870 // Starts = iteration * 2; 871 auto starts = builder.Reshape( 872 builder.Mul(iteration, builder.ConstantR0<int32>(2)), {1}); 873 // UpdateSlice. 874 auto out1 = builder.DynamicUpdateSlice(input, update, starts); 875 876 auto result = builder.Tuple({out0, out1}); 877 body = builder.Build().ConsumeValueOrDie(); 878 } 879 880 // Create a While node with computations for the condition and the body. 881 ComputationBuilder builder(client_, "while"); 882 auto init = builder.Tuple( 883 {builder.ConstantR0<int32>(0), builder.ConstantR1<float>(10, 0.f)}); 884 auto result = builder.While(condition, body, init); 885 VLOG(2) << "while = " 886 << ShapeUtil::HumanString( 887 *builder.GetShape(result).ConsumeValueOrDie()); 888 889 auto expected_counter = Literal::CreateR0<int32>(5); 890 auto expected_data = Literal::CreateR1<float>( 891 {1.0f, 1.0f, 2.0f, 2.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f}); 892 auto expected = 893 Literal::MakeTuple({expected_counter.get(), expected_data.get()}); 894 VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); 895 ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); 896 } 897 898 // Tests a while node when the result type T is a vector of S32. 899 // 900 // int32 result = (0, 0, 0, 0, 0, 0); 901 // while (result[0] < count) { 902 // result += (1, U[0, 100], U[0, 100], U[0, 100], U[0, 100], U[0, 100]); 903 // } 904 // 905 // This test misuses a vector WhileTest.WhileLoopsWithSharedBodyto represent a 906 // pair: 907 // ((iteration, (random vector))). 908 // 909 // Note: this test currently only tests generating random values within a loop. 910 // Per backend the values generated can be different as the different backends 911 // use different random number generators. 912 // TODO(b/32240857): Extend test to verify outputs. 913 TEST_F(WhileTest, WhileWithPrngScalarResult) { 914 auto v6s32 = ShapeUtil::MakeShape(S32, {6}); 915 916 // Create a computation for the condition: repeat for count iterations. 917 auto build_condition = [this, v6s32](int count) { 918 ComputationBuilder builder(client_, TestName()); 919 auto prev = builder.Reshape( 920 builder.Slice(builder.Parameter(0, v6s32, "prev"), {0}, {1}, {1}), {0}, 921 {}); 922 builder.Gt(builder.ConstantR0<int32>(count), prev); 923 return builder.Build().ConsumeValueOrDie(); 924 }; 925 926 // Create a computation for the body: add 1 to the result variable. 927 Computation body; 928 { 929 ComputationBuilder builder(client_, "body"); 930 auto prev = builder.Parameter(0, v6s32, "prev"); 931 auto inc = builder.ConcatInDim( 932 {builder.ConstantR1<int32>({1}), 933 builder.RngUniform(builder.ConstantR0<int32>(0), 934 builder.ConstantR0<int32>(100), 935 ShapeUtil::MakeShape(S32, {5}))}, 936 0); 937 auto result = builder.Add(inc, prev); 938 body = builder.Build().ConsumeValueOrDie(); 939 } 940 941 // Create a While node with computations for the condition and the body. 942 auto while_loop = [this, &body, build_condition](int count) { 943 ComputationBuilder builder(client_, TestName()); 944 auto init = builder.ConstantR1<int32>({0, 0, 0, 0, 0, 0}); 945 auto result = builder.While(build_condition(count), body, init); 946 auto shape = builder.GetShape(result).ConsumeValueOrDie(); 947 return builder.Build(); 948 }; 949 950 for (int i = 1; i < 4; ++i) { 951 TF_ASSERT_OK_AND_ASSIGN(auto computation, while_loop(i)); 952 953 ExecutionOptions execution_options = execution_options_; 954 execution_options.set_seed(65); 955 TF_ASSERT_OK_AND_ASSIGN( 956 auto result, 957 client_->ExecuteAndTransfer(computation, {}, &execution_options)); 958 } 959 } 960 961 TEST_F(WhileTest, WhileThatSwapsParameterWithTupleElement) { 962 auto element_shape = ShapeUtil::MakeShape(F32, {2}); 963 964 ComputationBuilder outer(client_, "outer"); 965 auto p = outer.Parameter(0, element_shape, "param"); 966 auto t = outer.Tuple({p, outer.ConstantR1<float>({1, 1})}); 967 968 TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr<Shape> tuple_shape, 969 outer.GetShape(t)); 970 971 ComputationBuilder cond(client_, "cond"); 972 auto cond_t = cond.Parameter(0, *tuple_shape, "t"); 973 TF_ASSERT_OK(Any(cond.Eq(cond.GetTupleElement(cond_t, 0), 974 cond.ConstantR1<float>({42, 42})), 975 &cond) 976 .status()); 977 978 ComputationBuilder body(client_, "body"); 979 auto body_t = body.Parameter(0, *tuple_shape, "t"); 980 auto e = body.GetTupleElement(body_t, 1); 981 body.Tuple({e, e}); 982 983 TF_ASSERT_OK_AND_ASSIGN(auto cond_computation, cond.Build()); 984 TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body.Build()); 985 outer.While(cond_computation, body_computation, t); 986 987 auto expected_element = Literal::CreateR1<float>({1, 1}); 988 auto expected = 989 Literal::MakeTuple({expected_element.get(), expected_element.get()}); 990 TF_ASSERT_OK_AND_ASSIGN( 991 std::unique_ptr<GlobalData> parameter_data, 992 client_->TransferToServer(*Literal::CreateR1<float>({42, 42}))); 993 ComputeAndCompareTuple(&outer, *expected, {parameter_data.get()}, 994 ErrorSpec(1e-6)); 995 } 996 997 TEST_F(WhileTest, WhileThatSwapsParameterWithBroadcast) { 998 auto element_shape = ShapeUtil::MakeShape(F32, {2}); 999 1000 ComputationBuilder outer(client_, "outer"); 1001 auto p = outer.Parameter(0, element_shape, "param"); 1002 1003 ComputationBuilder cond(client_, "cond"); 1004 auto cond_t = cond.Parameter(0, element_shape, "t"); 1005 TF_ASSERT_OK( 1006 Any(cond.Eq(cond_t, cond.ConstantR1<float>({42, 42})), &cond).status()); 1007 1008 ComputationBuilder body(client_, "body"); 1009 auto body_t = body.Parameter(0, element_shape, "t"); 1010 auto e = body.Broadcast(body.ConstantR0<float>(1.0), {2}); 1011 1012 TF_ASSERT_OK_AND_ASSIGN(auto cond_computation, cond.Build()); 1013 TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body.Build()); 1014 outer.While(cond_computation, body_computation, p); 1015 1016 TF_ASSERT_OK_AND_ASSIGN( 1017 std::unique_ptr<GlobalData> parameter_data, 1018 client_->TransferToServer(*Literal::CreateR1<float>({42, 42}))); 1019 ComputeAndCompareR1<float>(&outer, {1.0f, 1.0f}, {parameter_data.get()}, 1020 ErrorSpec(1e-6)); 1021 } 1022 1023 TEST_F(WhileTest, WhileThatTurnsScalarParameterToTupleElement) { 1024 auto element_shape = ShapeUtil::MakeShape(F32, {}); 1025 1026 ComputationBuilder outer(client_, "outer"); 1027 auto p = outer.Parameter(0, element_shape, "param"); 1028 1029 ComputationBuilder cond(client_, "cond"); 1030 auto cond_t = cond.Parameter(0, element_shape, "t"); 1031 cond.Eq(cond_t, cond.ConstantR0<float>(42)); 1032 1033 ComputationBuilder body(client_, "body"); 1034 auto body_t = body.Parameter(0, element_shape, "t"); 1035 auto tuple = 1036 body.Tuple({body_t, body.Add(body_t, body.ConstantR0<float>(1))}); 1037 auto e = body.GetTupleElement(tuple, 1); 1038 1039 TF_ASSERT_OK_AND_ASSIGN(auto cond_computation, cond.Build()); 1040 TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body.Build()); 1041 outer.While(cond_computation, body_computation, p); 1042 1043 TF_ASSERT_OK_AND_ASSIGN( 1044 std::unique_ptr<GlobalData> parameter_data, 1045 client_->TransferToServer(*Literal::CreateR0<float>(42))); 1046 ComputeAndCompareR0<float>(&outer, 43.0f, {parameter_data.get()}, 1047 ErrorSpec(1e-6)); 1048 } 1049 1050 // Tests loop where the init value comes from two sources (constant and 1051 // parameter). 1052 // 1053 // int32 result = (0, 1); 1054 // while (result[0] + result[1] < 30) { 1055 // result[0] = result[0] + 1; 1056 // result[1] = result[1] + 1; 1057 // } 1058 TEST_F(WhileTest, WhileWithMixedTupleElements) { 1059 auto result_shape = ShapeUtil::MakeTupleShape( 1060 {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(S32, {})}); 1061 1062 ComputationBuilder outer(client_, "outer"); 1063 auto p = 1064 outer.Tuple({outer.ConstantR0<int32>(0), 1065 outer.Parameter(0, ShapeUtil::MakeShape(S32, {}), "t")}); 1066 1067 ComputationBuilder cond(client_, "cond"); 1068 auto params = cond.Parameter(0, result_shape, "prev"); 1069 auto cond_t = cond.Add(cond.GetTupleElement(params, 1), 1070 cond.GetTupleElement(params, 0)); 1071 cond.Lt(cond_t, cond.ConstantR0<int32>(30)); 1072 1073 ComputationBuilder body(client_, "body"); 1074 auto body_t = body.Parameter(0, result_shape, "t"); 1075 1076 auto tuple = body.Tuple( 1077 {body.Add(body.GetTupleElement(params, 0), body.ConstantR0<int32>(1)), 1078 body.Add(body.GetTupleElement(params, 1), body.ConstantR0<int32>(1))}); 1079 1080 TF_ASSERT_OK_AND_ASSIGN(auto cond_computation, cond.Build()); 1081 TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body.Build()); 1082 outer.While(cond_computation, body_computation, p); 1083 1084 TF_ASSERT_OK_AND_ASSIGN( 1085 std::unique_ptr<GlobalData> parameter_data, 1086 client_->TransferToServer(*Literal::CreateR0<int32>(1))); 1087 1088 auto add1 = Literal::CreateR0<int32>(15); 1089 auto add2 = Literal::CreateR0<int32>(16); 1090 auto expected = Literal::MakeTuple({add1.get(), add2.get()}); 1091 ComputeAndCompareTuple(&outer, *expected, {parameter_data.get()}, 1092 ErrorSpec(1e-6)); 1093 } 1094 1095 // Tests nested while loops. 1096 // 1097 // int32 result = 0; 1098 // while (result < 30) { 1099 // int i = 0; 1100 // while (i < 7) { 1101 // result = result + 2; 1102 // i = i + 1; 1103 // } 1104 // } 1105 XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) { 1106 auto outer_result_shape = ShapeUtil::MakeShape(S32, {}); 1107 auto inner_result_shape = ShapeUtil::MakeTupleShape( 1108 {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(S32, {})}); 1109 1110 Computation inner_condition; 1111 { 1112 ComputationBuilder builder(client_, "inner_condition"); 1113 auto params = builder.Parameter(0, inner_result_shape, "prev"); 1114 auto i = builder.GetTupleElement(params, 0); 1115 builder.Lt(i, builder.ConstantR0<int32>(7)); 1116 inner_condition = builder.Build().ConsumeValueOrDie(); 1117 } 1118 1119 // Creates a computation for the outer loop condition: 1120 // repeat while result < 30. 1121 Computation outer_condition; 1122 { 1123 ComputationBuilder builder(client_, "outer_condition"); 1124 auto prev = builder.Parameter(0, outer_result_shape, "prev"); 1125 builder.Lt(prev, builder.ConstantR0<int32>(30)); 1126 outer_condition = builder.Build().ConsumeValueOrDie(); 1127 } 1128 1129 // Creates a computation for the inner loop body: add 1 to `i`, and add 2 to 1130 // `result`. 1131 Computation inner_body; 1132 { 1133 ComputationBuilder builder(client_, "inner_body"); 1134 auto params = builder.Parameter(0, inner_result_shape, "prev"); 1135 auto i = builder.GetTupleElement(params, 0); 1136 auto result = builder.GetTupleElement(params, 1); 1137 i = builder.Add(builder.ConstantR0<int32>(1), i); 1138 result = builder.Add(builder.ConstantR0<int32>(2), result); 1139 auto output = builder.Tuple({i, result}); 1140 inner_body = builder.Build().ConsumeValueOrDie(); 1141 } 1142 1143 // Creates a computation for the outer loop: run the inner loop with i = 0. 1144 Computation outer_body; 1145 { 1146 ComputationBuilder builder(client_, "outer_body"); 1147 auto prev = builder.Parameter(0, outer_result_shape, "prev"); 1148 auto init = builder.Tuple({builder.ConstantR0<int32>(0), prev}); 1149 auto result = builder.While(inner_condition, inner_body, init); 1150 auto output = builder.GetTupleElement(result, 1); 1151 outer_body = builder.Build().ConsumeValueOrDie(); 1152 } 1153 1154 // Create a While node with computations for the condition and the body. 1155 ComputationBuilder builder(client_, TestName()); 1156 auto init = builder.ConstantR0<int32>(0); 1157 auto result = builder.While(outer_condition, outer_body, init); 1158 auto shape = builder.GetShape(result).ConsumeValueOrDie(); 1159 1160 ComputeAndCompareR0<int32>(&builder, 42, {}); 1161 } 1162 1163 // Tests a while node when the result type T is S32. 1164 // f = lambda result: tuple({result < 5}) 1165 // int32 result = 0; 1166 // while (f(result).get<0>()) { 1167 // result = result + 1; 1168 // } 1169 TEST_F(WhileTest, WhileWithCallInsideCondition) { 1170 auto result_shape = ShapeUtil::MakeShape(S32, {}); 1171 1172 // Create a computation for the condition: repeat for 5 iterations. 1173 Computation condition_callee; 1174 { 1175 ComputationBuilder builder(client_, "condition_callee"); 1176 auto prev = builder.Parameter(0, result_shape, "prev"); 1177 builder.Tuple({builder.Gt(builder.ConstantR0<int32>(5), prev)}); 1178 1179 condition_callee = builder.Build().ConsumeValueOrDie(); 1180 } 1181 1182 Computation condition; 1183 { 1184 ComputationBuilder builder(client_, "condition"); 1185 auto prev = builder.Parameter(0, result_shape, "prev"); 1186 auto result = builder.Call(condition_callee, {prev}); 1187 builder.GetTupleElement(result, 0); 1188 condition = builder.Build().ConsumeValueOrDie(); 1189 } 1190 1191 // Create a computation for the body: add 1 to the result variable. 1192 Computation body; 1193 { 1194 ComputationBuilder builder(client_, "body"); 1195 auto prev = builder.Parameter(0, result_shape, "prev"); 1196 auto input = builder.ConstantR0<int32>(1); 1197 auto result = builder.Add(input, prev); 1198 body = builder.Build().ConsumeValueOrDie(); 1199 } 1200 1201 // Create a While node with computations for the condition and the body. 1202 ComputationBuilder builder(client_, TestName()); 1203 auto init = builder.ConstantR0<int32>(0); 1204 auto result = builder.While(condition, body, init); 1205 auto shape = builder.GetShape(result).ConsumeValueOrDie(); 1206 1207 ComputeAndCompareR0<int32>(&builder, 5, {}); 1208 } 1209 1210 TEST_F(WhileTest, WhileWithLoopInvariantOperation) { 1211 auto matrix_shape = ShapeUtil::MakeShape(F32, {2, 2}); 1212 auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); 1213 auto while_shape = ShapeUtil::MakeTupleShape( 1214 {scalar_s32, matrix_shape, matrix_shape, matrix_shape}); 1215 1216 // Create a computation for the condition: repeat for 5 iterations. 1217 Computation condition; 1218 { 1219 ComputationBuilder builder(client_, "condition"); 1220 auto state = builder.Parameter(0, while_shape, "state"); 1221 builder.Gt(builder.ConstantR0<int32>(5), builder.GetTupleElement(state, 0)); 1222 TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); 1223 } 1224 1225 Computation body; 1226 { 1227 ComputationBuilder builder(client_, "body"); 1228 auto state = builder.Parameter(0, while_shape, "state"); 1229 auto indvar = builder.GetTupleElement(state, 0); 1230 auto input_0 = builder.GetTupleElement(state, 1); 1231 auto input_1 = builder.GetTupleElement(state, 2); 1232 auto output = builder.Tanh(builder.Dot(input_0, input_1)); 1233 auto indvar_next = builder.Add(indvar, builder.ConstantR0<int32>(1)); 1234 auto tuple_result = builder.Tuple({indvar_next, input_0, input_1, output}); 1235 TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); 1236 } 1237 1238 ComputationBuilder builder(client_, TestName()); 1239 auto matrix_input = builder.Parameter(0, matrix_shape, "matrix"); 1240 auto init = builder.Tuple( 1241 {builder.ConstantR0<int32>(0), matrix_input, matrix_input, matrix_input}); 1242 auto while_instruction = builder.While(condition, body, init); 1243 builder.GetTupleElement(while_instruction, 3); 1244 1245 TF_ASSERT_OK_AND_ASSIGN(auto param_value, 1246 client_->TransferToServer(*Literal::CreateR2<float>( 1247 {{1.0, 2.0}, {-1.0, -2.0}}))); 1248 1249 ComputeAndCompareR2<float>( 1250 &builder, {{-0.76159416, -0.96402758}, {0.76159416, 0.96402758}}, 1251 {param_value.get()}, ErrorSpec(4e-5)); 1252 } 1253 1254 void BM_WhileLoop(int num_iters) { 1255 // Benchmark a simple kernel to measure while loop overheads. 1256 tensorflow::testing::StopTiming(); 1257 1258 se::Platform* platform = PlatformUtil::GetDefaultPlatform().ValueOrDie(); 1259 auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie(); 1260 StreamExecutorMemoryAllocator allocator(platform, executors); 1261 LocalClient* client = 1262 ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie(); 1263 1264 const int64 seq_len = 100; 1265 Shape loop_state_shape = ShapeUtil::MakeTupleShape( 1266 {ShapeUtil::MakeShape(S32, {}), 1267 ShapeUtil::MakeShape(F32, {seq_len, 1024, 1024})}); 1268 1269 // Create while condition computation with 'loop_limit'. 1270 const int32 loop_limit = 100; 1271 Computation condition; 1272 { 1273 ComputationBuilder builder(client, "condition"); 1274 auto prev = builder.Parameter(0, loop_state_shape, "prev"); 1275 auto iteration = builder.GetTupleElement(prev, 0); 1276 builder.Lt(iteration, builder.ConstantR0<int32>(loop_limit)); 1277 condition = builder.Build().ConsumeValueOrDie(); 1278 } 1279 1280 // Create while body computation with unit loop increment. 1281 Computation body; 1282 { 1283 ComputationBuilder builder(client, "body"); 1284 auto prev = builder.Parameter(0, loop_state_shape, "prev"); 1285 // TupleElement 0 1286 auto iteration = builder.GetTupleElement(prev, 0); 1287 auto out0 = builder.Add(iteration, builder.ConstantR0<int32>(1)); 1288 // TupleElement 1 1289 auto input = builder.GetTupleElement(prev, 1); 1290 // Update. 1291 auto one = builder.ConstantR0<float>(1.0); 1292 auto update = builder.Broadcast(one, {1, 1024, 1024}); 1293 // Starts = iteration * 2; 1294 auto starts = builder.ConstantR1<int32>({0, 0, 0}); 1295 // UpdateSlice. 1296 auto out1 = builder.DynamicUpdateSlice(input, update, starts); 1297 auto result = builder.Tuple({out0, out1}); 1298 body = builder.Build().ConsumeValueOrDie(); 1299 } 1300 1301 // Create a While instruction. 1302 ComputationBuilder builder(client, "while"); 1303 auto zero = builder.ConstantR0<float>(0.0); 1304 auto input = builder.Broadcast(zero, {seq_len, 1024, 1024}); 1305 auto init = builder.Tuple({builder.ConstantR0<int32>(0), input}); 1306 builder.While(condition, body, init); 1307 auto computation = builder.Build().ConsumeValueOrDie(); 1308 1309 std::unique_ptr<LocalExecutable> executable = 1310 client->Compile(computation, {}, ExecutableBuildOptions()) 1311 .ConsumeValueOrDie(); 1312 1313 // Run some warm-up executions. 1314 ExecutableRunOptions options; 1315 options.set_allocator(&allocator); 1316 const int kWarmups = 2; 1317 for (int i = 0; i < kWarmups; ++i) { 1318 auto result = executable->Run({}, options); 1319 ASSERT_TRUE(result.ok()); 1320 } 1321 1322 // Run benchmark. 1323 tensorflow::testing::StartTiming(); 1324 for (int i = 0; i < num_iters; ++i) { 1325 auto result = executable->Run({}, options); 1326 ASSERT_TRUE(result.ok()); 1327 } 1328 } 1329 1330 // TODO(b/32470510): Benchmark fails on parallel CPU backend. 1331 #ifndef XLA_TEST_BACKEND_CPU_PARALLEL 1332 BENCHMARK(BM_WhileLoop); 1333 #endif 1334 1335 } // namespace 1336 } // namespace xla 1337