1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/client/xla_builder.h" 17 18 #include <string> 19 20 #include "tensorflow/compiler/xla/client/xla_computation.h" 21 #include "tensorflow/compiler/xla/debug_options_flags.h" 22 #include "tensorflow/compiler/xla/service/hlo_matchers.h" 23 #include "tensorflow/compiler/xla/service/hlo_module.h" 24 #include "tensorflow/compiler/xla/shape_util.h" 25 #include "tensorflow/compiler/xla/status_macros.h" 26 #include "tensorflow/compiler/xla/test.h" 27 #include "tensorflow/compiler/xla/test_helpers.h" 28 #include "tensorflow/compiler/xla/util.h" 29 #include "tensorflow/compiler/xla/xla_data.pb.h" 30 31 namespace xla { 32 33 namespace { 34 35 namespace op = xla::testing::opcode_matchers; 36 37 using ::testing::HasSubstr; 38 39 // TODO(b/74197823): Move the tests to service/. 
40 class XlaBuilderTest : public ::testing::Test { 41 protected: 42 StatusOr<std::unique_ptr<HloModule>> BuildHloModule(XlaBuilder* b) { 43 TF_ASSIGN_OR_RETURN(XlaComputation computation, 44 b->Build(/*remove_dynamic_dimensions=*/false)); 45 const HloModuleProto& proto = computation.proto(); 46 TF_ASSIGN_OR_RETURN(const auto& config, 47 HloModule::CreateModuleConfigFromProto( 48 proto, GetDebugOptionsFromFlags())); 49 return HloModule::CreateFromProto(proto, config); 50 } 51 52 // Overload which explicitly specifies the root instruction. 53 StatusOr<std::unique_ptr<HloModule>> BuildHloModule(XlaBuilder* b, 54 XlaOp root) { 55 TF_ASSIGN_OR_RETURN(XlaComputation computation, 56 b->Build(root, /*remove_dynamic_dimensions=*/false)); 57 const HloModuleProto& proto = computation.proto(); 58 TF_ASSIGN_OR_RETURN(const auto& config, 59 HloModule::CreateModuleConfigFromProto( 60 proto, GetDebugOptionsFromFlags())); 61 return HloModule::CreateFromProto(proto, config); 62 } 63 64 // Returns the name of the test currently being run. 
65 string TestName() const { 66 return ::testing::UnitTest::GetInstance()->current_test_info()->name(); 67 } 68 }; 69 70 TEST_F(XlaBuilderTest, OnePlusTwo) { 71 XlaBuilder b(TestName()); 72 Add(ConstantR0<float>(&b, 1.0), ConstantR0<float>(&b, 2.0)); 73 TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); 74 auto root = module->entry_computation()->root_instruction(); 75 EXPECT_THAT(root, op::Add(op::Constant(), op::Constant())); 76 } 77 78 TEST_F(XlaBuilderTest, UnaryOperatorsBuildExpectedHLO) { 79 auto test_unary_operator = 80 [&](std::function<XlaOp(XlaOp)> op, 81 ::testing::Matcher<const ::xla::HloInstruction*> matches_pattern) { 82 XlaBuilder b(TestName()); 83 op(ConstantR0<int32>(&b, 1)); 84 TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); 85 auto root = module->entry_computation()->root_instruction(); 86 EXPECT_THAT(root, matches_pattern); 87 }; 88 test_unary_operator([](XlaOp x) { return -x; }, op::Negate(op::Constant())); 89 test_unary_operator([](XlaOp x) { return ~x; }, op::Not(op::Constant())); 90 } 91 92 TEST_F(XlaBuilderTest, BinaryOperatorsBuildExpectedHLO) { 93 auto test_binary_operator = 94 [&](std::function<XlaOp(XlaOp, XlaOp)> op, 95 ::testing::Matcher<const ::xla::HloInstruction*> matches_pattern) { 96 XlaBuilder b(TestName()); 97 op(ConstantR0<int32>(&b, 1), ConstantR0<int32>(&b, 2)); 98 TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); 99 auto root = module->entry_computation()->root_instruction(); 100 EXPECT_THAT(root, matches_pattern); 101 }; 102 103 test_binary_operator([](XlaOp x, XlaOp y) { return x + y; }, 104 op::Add(op::Constant(), op::Constant())); 105 test_binary_operator([](XlaOp x, XlaOp y) { return x - y; }, 106 op::Subtract(op::Constant(), op::Constant())); 107 test_binary_operator([](XlaOp x, XlaOp y) { return x * y; }, 108 op::Multiply(op::Constant(), op::Constant())); 109 test_binary_operator([](XlaOp x, XlaOp y) { return x / y; }, 110 op::Divide(op::Constant(), op::Constant())); 111 112 
test_binary_operator([](XlaOp x, XlaOp y) { return x & y; }, 113 op::And(op::Constant(), op::Constant())); 114 test_binary_operator([](XlaOp x, XlaOp y) { return x | y; }, 115 op::Or(op::Constant(), op::Constant())); 116 test_binary_operator([](XlaOp x, XlaOp y) { return x ^ y; }, 117 op::Xor(op::Constant(), op::Constant())); 118 test_binary_operator([](XlaOp x, XlaOp y) { return x << y; }, 119 op::ShiftLeft(op::Constant(), op::Constant())); 120 test_binary_operator( 121 [](XlaOp x, XlaOp y) { return x >> y; }, 122 op::ShiftRightArithmetic(op::Constant(), op::Constant())); 123 124 auto test_unsigned_binary_operator = 125 [&](std::function<XlaOp(XlaOp, XlaOp)> op, 126 ::testing::Matcher<const ::xla::HloInstruction*> matches_pattern) { 127 XlaBuilder b(TestName()); 128 op(ConstantR0<uint32>(&b, 1), ConstantR0<uint32>(&b, 2)); 129 TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); 130 auto root = module->entry_computation()->root_instruction(); 131 EXPECT_THAT(root, matches_pattern); 132 }; 133 test_unsigned_binary_operator( 134 [](XlaOp x, XlaOp y) { return x >> y; }, 135 op::ShiftRightLogical(op::Constant(), op::Constant())); 136 } 137 138 TEST_F(XlaBuilderTest, VariadicAnd) { 139 XlaBuilder b(TestName()); 140 Shape s = ShapeUtil::MakeShape(PRED, {}); 141 And(Parameter(&b, 0, s, "p0"), Parameter(&b, 1, s, "p1"), 142 Parameter(&b, 2, s, "p2")); 143 TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); 144 // Don't specify in the test whether And(x, y, z) is right- or 145 // left-associative; accept either one. 
146 EXPECT_THAT( 147 module->entry_computation()->root_instruction(), 148 ::testing::AnyOf(op::And(op::Parameter(0), 149 op::And(op::Parameter(1), op::Parameter(2))), 150 op::And(op::And(op::Parameter(0), op::Parameter(1)), 151 op::Parameter(2)))); 152 } 153 154 TEST_F(XlaBuilderTest, VariadicOr) { 155 XlaBuilder b(TestName()); 156 Shape s = ShapeUtil::MakeShape(PRED, {}); 157 Or(Parameter(&b, 0, s, "p0"), Parameter(&b, 1, s, "p1"), 158 Parameter(&b, 2, s, "p2")); 159 TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); 160 // Don't specify in the test whether Or(x, y, z) is right- or 161 // left-associative; accept either one. 162 EXPECT_THAT( 163 module->entry_computation()->root_instruction(), 164 ::testing::AnyOf( 165 op::Or(op::Parameter(0), op::Or(op::Parameter(1), op::Parameter(2))), 166 op::Or(op::Or(op::Parameter(0), op::Parameter(1)), 167 op::Parameter(2)))); 168 } 169 170 TEST_F(XlaBuilderTest, ShiftRightOperatorOnNonIntegerProducesError) { 171 XlaBuilder b(TestName()); 172 ConstantR0<float>(&b, 1) >> ConstantR0<float>(&b, 2); 173 auto statusor = b.Build(); 174 ASSERT_FALSE(statusor.ok()); 175 EXPECT_THAT( 176 statusor.status().error_message(), 177 HasSubstr("Argument to >> operator does not have an integral type")); 178 } 179 180 TEST_F(XlaBuilderTest, ParamPlusConstantHasScalarBroadcast) { 181 XlaBuilder b(TestName()); 182 auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {3, 5}), "x"); 183 Add(x, ConstantR0<float>(&b, 1.0)); 184 TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); 185 auto root = module->entry_computation()->root_instruction(); 186 EXPECT_THAT(root, op::Add(op::Parameter(), op::Broadcast(op::Constant()))); 187 } 188 189 TEST_F(XlaBuilderTest, ParamPlusParamHasBroadcast) { 190 XlaBuilder b(TestName()); 191 const auto& x_shape = ShapeUtil::MakeShape(S32, {2, 4, 6}); 192 const auto& y_shape = ShapeUtil::MakeShape(S32, {2, 4}); 193 auto x = Parameter(&b, 0, x_shape, "x"); 194 auto y = Parameter(&b, 1, y_shape, "y"); 195 auto 
add = Add(x, y, /*broadcast_dimensions=*/{0, 1}); 196 197 TF_ASSERT_OK_AND_ASSIGN(auto add_shape, b.GetShape(add)); 198 EXPECT_TRUE(ShapeUtil::Equal(add_shape, x_shape)); 199 200 TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); 201 auto root = module->entry_computation()->root_instruction(); 202 EXPECT_THAT(root, op::Add(op::Parameter(0), op::Broadcast(op::Parameter(1)))); 203 } 204 205 TEST_F(XlaBuilderTest, XPlusX) { 206 XlaBuilder b(TestName()); 207 auto x = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {1, 3, 5, 7}), "x"); 208 Add(x, x); 209 TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); 210 auto root = module->entry_computation()->root_instruction(); 211 EXPECT_THAT(root, op::Add(op::Parameter(0), op::Parameter(0))); 212 } 213 214 TEST_F(XlaBuilderTest, ShapeInferenceError) { 215 XlaBuilder b(TestName()); 216 auto x = Parameter(&b, 0, ShapeUtil::MakeShape(U32, {2, 4, 6}), "x"); 217 auto y = Parameter(&b, 1, ShapeUtil::MakeShape(U32, {2, 4}), "y"); 218 Add(x, y); 219 auto statusor = BuildHloModule(&b); 220 ASSERT_FALSE(statusor.ok()); 221 EXPECT_THAT(statusor.status().error_message(), HasSubstr("shape inference")); 222 } 223 224 TEST_F(XlaBuilderTest, ParameterAlreadyRegistered) { 225 XlaBuilder b_call("add"); 226 Parameter(&b_call, 0, ShapeUtil::MakeShape(PRED, {}), "x"); 227 228 XlaBuilder b(TestName()); 229 auto x = Parameter(&b, 0, ShapeUtil::MakeShape(PRED, {}), "x"); 230 auto y = Parameter(&b, 0, ShapeUtil::MakeShape(PRED, {}), "y"); 231 Add(x, y); 232 auto statusor = BuildHloModule(&b); 233 ASSERT_FALSE(statusor.ok()); 234 EXPECT_THAT(statusor.status().error_message(), 235 HasSubstr("parameter 0 already registered")); 236 } 237 238 TEST_F(XlaBuilderTest, Call) { 239 XlaBuilder b_call("the_only_to_apply"); 240 auto p0 = Parameter(&b_call, 0, ShapeUtil::MakeShape(F32, {}), "p0"); 241 auto p1 = Parameter(&b_call, 1, ShapeUtil::MakeShape(F32, {}), "p1"); 242 Add(p0, p1); 243 TF_ASSERT_OK_AND_ASSIGN(auto call, b_call.Build()); 244 XlaBuilder 
b(TestName()); 245 auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "x"); 246 auto y = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {}), "y"); 247 auto one = ConstantR0<float>(&b, 1); 248 auto two = ConstantR0<float>(&b, 2); 249 Add(Call(&b, call, {x, y}), Call(&b, call, {one, two})); 250 TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); 251 auto root = module->entry_computation()->root_instruction(); 252 EXPECT_THAT(root, op::Add(op::Call(op::Parameter(), op::Parameter()), 253 op::Call(op::Constant(), op::Constant()))); 254 } 255 256 TEST_F(XlaBuilderTest, BinopHasDegenerateBroadcast) { 257 XlaBuilder b(TestName()); 258 auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {1, 2, 3}), "x"); 259 auto y = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {1, 2, 1}), "y"); 260 Add(x, y); 261 TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); 262 263 // Expected: 264 // 265 // x: f32[1,2,3] y: f32[1,2,1] 266 // | | 267 // | reshape: f32[1,2] 268 // | | 269 // | broadcast: f32[1,2,3] 270 // \ / 271 // add 272 auto root = module->entry_computation()->root_instruction(); 273 EXPECT_THAT(root, op::Add(op::Parameter(0), 274 op::Broadcast(op::Reshape(op::Parameter(1))))); 275 } 276 277 TEST_F(XlaBuilderTest, BinopHasInDimAndDegenerateBroadcast) { 278 XlaBuilder b(TestName()); 279 auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 3}), "x"); 280 auto y = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {2, 1, 4}), "y"); 281 Add(x, y, /*broadcast_dimensions=*/{0, 1}); 282 TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); 283 284 // The binary operation has in-dim broadcast and degenerate broadcast, should 285 // first do the in-dim broadcast then convert the degnerate broadcast into a 286 // reshape and a broadcast. 
287 // 288 // Expected: 289 // 290 // x: f32[2,3] y: f32[2,1,4] 291 // | | 292 // broadcast: f32[2,3,4] reshape: f32[2,4] 293 // | | 294 // | broadcast: f32[2,3,4] 295 // \ / 296 // add 297 auto root = module->entry_computation()->root_instruction(); 298 EXPECT_THAT(root, op::Add(op::Broadcast(op::Parameter(0)), 299 op::Broadcast(op::Reshape(op::Parameter(1))))); 300 } 301 302 TEST_F(XlaBuilderTest, BroadcastInDim) { 303 XlaBuilder b(TestName()); 304 auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 3}), "x"); 305 BroadcastInDim(x, {2, 4, 3}, 306 /*broadcast_dimensions=*/{0, 2}); 307 TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); 308 auto root = module->entry_computation()->root_instruction(); 309 EXPECT_THAT(root, op::Broadcast()); 310 } 311 312 TEST_F(XlaBuilderTest, BroadcastInDimWithDegeneratedDim) { 313 XlaBuilder b(TestName()); 314 auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 1, 4}), "x"); 315 BroadcastInDim(x, {2, 3, 4}, 316 /*broadcast_dimensions=*/{0, 1, 2}); 317 TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); 318 EXPECT_THAT(module->entry_computation()->root_instruction(), 319 op::Broadcast(op::Reshape(op::Broadcast()))); 320 } 321 322 TEST_F(XlaBuilderTest, OperandFromWrongBuilder) { 323 XlaBuilder b1("b1"); 324 auto p0 = Parameter(&b1, 0, ShapeUtil::MakeShape(F32, {}), "p0"); 325 XlaBuilder builder("main"); 326 auto p = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "p"); 327 Add(p, p0); 328 auto statusor = builder.Build(); 329 ASSERT_FALSE(statusor.ok()); 330 EXPECT_THAT( 331 statusor.status().error_message(), 332 HasSubstr( 333 "built by builder 'b1', but is trying to use it in builder 'main'")); 334 } 335 336 TEST_F(XlaBuilderTest, ReshapeDefaultOrder) { 337 XlaBuilder b(TestName()); 338 auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 3, 5, 7}), "x"); 339 Reshape(x, /*new_sizes=*/{6, 35}); 340 TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); 341 auto root = 
module->entry_computation()->root_instruction(); 342 EXPECT_THAT(root, op::Reshape(op::Parameter())); 343 } 344 345 TEST_F(XlaBuilderTest, ReshapeHasTranspose) { 346 XlaBuilder b(TestName()); 347 auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 3, 5, 7}), "x"); 348 Reshape(x, /*dimensions=*/{3, 2, 1, 0}, /*new_sizes=*/{6, 35}); 349 TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); 350 auto root = module->entry_computation()->root_instruction(); 351 EXPECT_THAT(root, op::Reshape(op::Transpose(op::Parameter()))); 352 } 353 354 TEST_F(XlaBuilderTest, Transpose) { 355 XlaBuilder b(TestName()); 356 auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x"); 357 Transpose(x, /*permutation=*/{1, 0}); 358 TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); 359 auto root = module->entry_computation()->root_instruction(); 360 EXPECT_THAT(root, op::Transpose(op::Parameter())); 361 } 362 363 TEST_F(XlaBuilderTest, AllToAll) { 364 XlaBuilder b(TestName()); 365 auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x"); 366 AllToAll(x, /*split_dimension=*/1, /*concat_dimension=*/0, 367 /*split_count=*/2); 368 TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); 369 auto root = module->entry_computation()->root_instruction(); 370 371 // AllToAll is decomposed into slices -> all-to-all -> gte -> concat. 
372 EXPECT_EQ(root->opcode(), HloOpcode::kConcatenate); 373 EXPECT_EQ(root->operand(0)->operand(0)->opcode(), HloOpcode::kAllToAll); 374 EXPECT_TRUE( 375 ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {8, 8}))); 376 } 377 378 TEST_F(XlaBuilderTest, CollectivePermute) { 379 XlaBuilder b(TestName()); 380 auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x"); 381 CollectivePermute(x, {{0, 1}, {1, 2}, {2, 3}}); 382 TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); 383 auto root = module->entry_computation()->root_instruction(); 384 EXPECT_EQ(root->opcode(), HloOpcode::kCollectivePermute); 385 } 386 387 TEST_F(XlaBuilderTest, GetDimensionSize) { 388 XlaBuilder b(TestName()); 389 auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x"); 390 GetDimensionSize(x, 1); 391 TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); 392 auto root = module->entry_computation()->root_instruction(); 393 EXPECT_EQ(root->opcode(), HloOpcode::kGetDimensionSize); 394 } 395 396 TEST_F(XlaBuilderTest, ReportError) { 397 XlaBuilder b(TestName()); 398 auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x"); 399 Add(b.ReportError(InvalidArgument("a test error")), x); 400 auto statusor = b.Build(); 401 ASSERT_FALSE(statusor.ok()); 402 EXPECT_THAT(statusor.status().error_message(), HasSubstr("a test error")); 403 } 404 405 TEST_F(XlaBuilderTest, ReportErrorOrReturnHandlesNonErrors) { 406 XlaBuilder b(TestName()); 407 StatusOr<XlaOp> op(ConstantR0<float>(&b, 1.0)); 408 Add(b.ReportErrorOrReturn(op), ConstantR0<float>(&b, 2.0)); 409 TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); 410 auto root = module->entry_computation()->root_instruction(); 411 EXPECT_THAT(root, op::Add(op::Constant(), op::Constant())); 412 } 413 414 TEST_F(XlaBuilderTest, ReportErrorOrReturnHandlesErrors) { 415 XlaBuilder b(TestName()); 416 StatusOr<XlaOp> op(InvalidArgument("a test error")); 417 Add(b.ReportErrorOrReturn(op), ConstantR0<float>(&b, 2.0)); 418 auto 
statusor = b.Build(); 419 ASSERT_FALSE(statusor.ok()); 420 EXPECT_THAT(statusor.status().error_message(), HasSubstr("a test error")); 421 } 422 423 TEST_F(XlaBuilderTest, BuildWithSpecificRoot) { 424 XlaBuilder b(TestName()); 425 XlaOp constant = ConstantR0<float>(&b, 1.0); 426 Add(constant, ConstantR0<float>(&b, 2.0)); 427 TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b, /*root=*/constant)); 428 auto root = module->entry_computation()->root_instruction(); 429 EXPECT_THAT(root, op::Constant()); 430 } 431 432 TEST_F(XlaBuilderTest, BuildWithSpecificRootAndMultipleParameters) { 433 // Specifying a particular root in Build should still include all entry 434 // parameters. 435 XlaBuilder b(TestName()); 436 const Shape shape = ShapeUtil::MakeShape(F32, {42, 123}); 437 XlaOp x = Parameter(&b, 0, shape, "x"); 438 XlaOp y = Parameter(&b, 1, shape, "y"); 439 XlaOp z = Parameter(&b, 2, shape, "z"); 440 Add(x, Sub(y, z)); 441 TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b, /*root=*/x)); 442 auto root = module->entry_computation()->root_instruction(); 443 EXPECT_THAT(root, op::Parameter()); 444 EXPECT_EQ(module->entry_computation()->num_parameters(), 3); 445 EXPECT_EQ(module->entry_computation()->instruction_count(), 5); 446 } 447 448 TEST_F(XlaBuilderTest, BuildWithSpecificRootWithWrongBuilder) { 449 XlaBuilder b(TestName()); 450 XlaBuilder other_b(TestName()); 451 const Shape shape = ShapeUtil::MakeShape(F32, {42, 123}); 452 453 Parameter(&b, 0, shape, "param"); 454 XlaOp other_param = Parameter(&other_b, 0, shape, "other_param"); 455 456 Status status = b.Build(other_param).status(); 457 ASSERT_IS_NOT_OK(status); 458 EXPECT_THAT( 459 status.error_message(), 460 ::testing::HasSubstr("root operation is not in this computation")); 461 } 462 463 TEST_F(XlaBuilderTest, ProtoMatches) { 464 std::vector<XlaComputation> computations; 465 for (int i = 0; i < 2; ++i) { 466 XlaBuilder b_call("the_only_to_apply"); 467 auto p0 = Parameter(&b_call, 0, 
ShapeUtil::MakeShape(F32, {}), "p0"); 468 auto p1 = Parameter(&b_call, 1, ShapeUtil::MakeShape(F32, {}), "p1"); 469 Add(p0, Add(p1, p0)); 470 TF_ASSERT_OK_AND_ASSIGN(auto call, b_call.Build()); 471 XlaBuilder b(TestName()); 472 auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "x"); 473 auto y = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {}), "y"); 474 auto one = ConstantR0<float>(&b, 1); 475 auto two = ConstantR0<float>(&b, 2); 476 Add(Call(&b, call, {x, y}), Call(&b, call, {one, two})); 477 computations.push_back(b.Build().ValueOrDie()); 478 } 479 auto c0_string = computations[0].proto().SerializeAsString(); 480 auto c1_string = computations[1].proto().SerializeAsString(); 481 EXPECT_EQ(c0_string, c1_string); 482 } 483 484 TEST_F(XlaBuilderTest, DynamicParameter) { 485 XlaBuilder b(TestName()); 486 Shape tuple_param_shape = ShapeUtil::MakeTupleShape( 487 {ShapeUtil::MakeShape(F32, {5}), ShapeUtil::MakeShape(F32, {6}, {true})}); 488 auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); 489 Parameter(&b, 1, ShapeUtil::MakeShape(U32, {}), "p1"); 490 ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/1, 491 /*dynamic_size_param_index=*/{}, 492 /*target_param_num=*/0, 493 /*target_param_index=*/{1}, 494 /*target_dim_num=*/0)); 495 TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b, /*root=*/p0)); 496 const Shape& param_shape = module->entry_computation() 497 ->parameter_instruction(0) 498 ->shape() 499 .tuple_shapes(1); 500 EXPECT_TRUE(param_shape.is_dynamic_dimension(0)); 501 } 502 503 TEST_F(XlaBuilderTest, DynamicUnary) { 504 XlaBuilder b(TestName()); 505 Shape tuple_param_shape = ShapeUtil::MakeTupleShape( 506 {ShapeUtil::MakeShape(F32, {5}, {true}), ShapeUtil::MakeShape(U32, {})}); 507 auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); 508 ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, 509 /*dynamic_size_param_index=*/{1}, 510 /*target_param_num=*/0, 511 /*target_param_index=*/{0}, 512 /*target_dim_num=*/0)); 513 auto gte = 
GetTupleElement(p0, 0); 514 Neg(gte); 515 TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); 516 const Shape& result_shape = 517 module->entry_computation()->root_instruction()->shape(); 518 EXPECT_TRUE(result_shape.is_dynamic_dimension(0)); 519 } 520 521 TEST_F(XlaBuilderTest, DynamicBinary) { 522 XlaBuilder b(TestName()); 523 Shape tuple_param_shape = ShapeUtil::MakeTupleShape( 524 {ShapeUtil::MakeShape(F32, {5}, {true}), 525 ShapeUtil::MakeShape(F32, {5}, {true}), ShapeUtil::MakeShape(U32, {})}); 526 auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); 527 ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, 528 /*dynamic_size_param_index=*/{2}, 529 /*target_param_num=*/0, 530 /*target_param_index=*/{0}, 531 /*target_dim_num=*/0)); 532 ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, 533 /*dynamic_size_param_index=*/{2}, 534 /*target_param_num=*/0, 535 /*target_param_index=*/{1}, 536 /*target_dim_num=*/0)); 537 auto gte0 = GetTupleElement(p0, 0); 538 auto gte1 = GetTupleElement(p0, 1); 539 Add(gte0, gte1); 540 TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); 541 const Shape& result_shape = 542 module->entry_computation()->root_instruction()->shape(); 543 EXPECT_TRUE(result_shape.is_dynamic_dimension(0)); 544 } 545 546 TEST_F(XlaBuilderTest, DynamicBinaryHasBroadcast) { 547 XlaBuilder b(TestName()); 548 Shape tuple_param_shape = ShapeUtil::MakeTupleShape( 549 {ShapeUtil::MakeShape(F32, {5, 4}, {true, false}), 550 ShapeUtil::MakeShape(F32, {5}, {true}), ShapeUtil::MakeShape(U32, {})}); 551 auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); 552 ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, 553 /*dynamic_size_param_index=*/{2}, 554 /*target_param_num=*/0, 555 /*target_param_index=*/{0}, 556 /*target_dim_num=*/0)); 557 ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, 558 /*dynamic_size_param_index=*/{2}, 559 /*target_param_num=*/0, 560 /*target_param_index=*/{1}, 561 /*target_dim_num=*/0)); 562 
auto gte0 = GetTupleElement(p0, 0); 563 auto gte1 = GetTupleElement(p0, 1); 564 Add(gte0, gte1, {0}); 565 TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); 566 const Shape& result_shape = 567 module->entry_computation()->root_instruction()->shape(); 568 EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), {true, false})) 569 << result_shape; 570 } 571 572 TEST_F(XlaBuilderTest, DynamicBroadcast) { 573 XlaBuilder b(TestName()); 574 Shape tuple_param_shape = ShapeUtil::MakeTupleShape( 575 {ShapeUtil::MakeShape(F32, {5, 4}, {true, false}), 576 ShapeUtil::MakeShape(U32, {})}); 577 auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); 578 ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, 579 /*dynamic_size_param_index=*/{1}, 580 /*target_param_num=*/0, 581 /*target_param_index=*/{0}, 582 /*target_dim_num=*/0)); 583 auto gte = GetTupleElement(p0, 0); 584 BroadcastInDim(gte, /*out_dim_size=*/{3, 5, 4}, 585 /*broadcast_dimensions=*/{1, 2}); 586 TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); 587 const Shape& result_shape = 588 module->entry_computation()->root_instruction()->shape(); 589 EXPECT_TRUE( 590 ContainersEqual(result_shape.dynamic_dimensions(), {false, true, false})) 591 << result_shape; 592 } 593 594 TEST_F(XlaBuilderTest, DynamicBinaryHasDegenerateBroadcast) { 595 XlaBuilder b(TestName()); 596 Shape tuple_param_shape = ShapeUtil::MakeTupleShape( 597 {ShapeUtil::MakeShape(F32, {10}, {true}), 598 ShapeUtil::MakeShape(F32, {1, 15}), ShapeUtil::MakeShape(U32, {})}); 599 auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); 600 ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, 601 /*dynamic_size_param_index=*/{1}, 602 /*target_param_num=*/0, 603 /*target_param_index=*/{0}, 604 /*target_dim_num=*/0)); 605 auto gte0 = GetTupleElement(p0, 0); 606 auto gte1 = GetTupleElement(p0, 1); 607 Add(gte0, gte1, /*broadcast_dimensions=*/{0}); // f32[<=10, 15] 608 TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); 609 
const Shape& result_shape = 610 module->entry_computation()->root_instruction()->shape(); 611 EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), {true, false})) 612 << result_shape; 613 } 614 615 TEST_F(XlaBuilderTest, DynamicSelectOnlyPredDynamic) { 616 XlaBuilder b(TestName()); 617 Shape tuple_param_shape = ShapeUtil::MakeTupleShape( 618 {ShapeUtil::MakeShape(PRED, {10}, {true}), 619 ShapeUtil::MakeShape(F32, {10}), ShapeUtil::MakeShape(U32, {})}); 620 auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); 621 ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, 622 /*dynamic_size_param_index=*/{1}, 623 /*target_param_num=*/0, 624 /*target_param_index=*/{0}, 625 /*target_dim_num=*/0)); 626 auto gte0 = GetTupleElement(p0, 0); 627 auto gte1 = GetTupleElement(p0, 1); 628 629 Select(gte0, gte1, gte1); 630 631 TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); 632 const Shape& result_shape = 633 module->entry_computation()->root_instruction()->shape(); 634 EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), {true})) 635 << result_shape; 636 } 637 638 TEST_F(XlaBuilderTest, DynamicPad) { 639 XlaBuilder b(TestName()); 640 Shape tuple_param_shape = ShapeUtil::MakeTupleShape( 641 {ShapeUtil::MakeShape(F32, {5, 4}, {true, false}), 642 ShapeUtil::MakeShape(U32, {})}); 643 auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); 644 auto pad_val = ConstantR0<float>(&b, -1); 645 ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, 646 /*dynamic_size_param_index=*/{1}, 647 /*target_param_num=*/0, 648 /*target_param_index=*/{0}, 649 /*target_dim_num=*/0)); 650 auto gte = GetTupleElement(p0, 0); 651 PaddingConfig padding_config; 652 for (int i = 0; i < 2; i++) { 653 auto dimension = padding_config.add_dimensions(); 654 dimension->set_edge_padding_low(0); 655 dimension->set_edge_padding_high(0); 656 dimension->set_interior_padding(0); 657 } 658 Pad(gte, pad_val, padding_config); 659 TF_ASSERT_OK_AND_ASSIGN(auto module, 
BuildHloModule(&b)); 660 const Shape& result_shape = 661 module->entry_computation()->root_instruction()->shape(); 662 EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), {true, false})) 663 << result_shape; 664 } 665 666 TEST_F(XlaBuilderTest, DynamicConvolution) { 667 XlaBuilder b(TestName()); 668 Shape tuple_param_shape = ShapeUtil::MakeTupleShape( 669 {ShapeUtil::MakeShape(F32, {1, 2, 2, 128}, {true, false, false, false}), 670 ShapeUtil::MakeShape(F32, {2, 2, 128, 8}, {false, false, true, false}), 671 ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeShape(U32, {})}); 672 auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); 673 ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, 674 /*dynamic_size_param_index=*/{2}, 675 /*target_param_num=*/0, 676 /*target_param_index=*/{0}, 677 /*target_dim_num=*/0)); 678 ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, 679 /*dynamic_size_param_index=*/{3}, 680 /*target_param_num=*/0, 681 /*target_param_index=*/{1}, 682 /*target_dim_num=*/2)); 683 auto input = GetTupleElement(p0, 0); 684 auto filter = GetTupleElement(p0, 1); 685 ConvolutionDimensionNumbers dnums; 686 dnums.set_input_batch_dimension(0); 687 dnums.set_output_batch_dimension(0); 688 dnums.add_input_spatial_dimensions(1); 689 dnums.add_output_spatial_dimensions(1); 690 dnums.add_input_spatial_dimensions(2); 691 dnums.add_output_spatial_dimensions(2); 692 dnums.set_input_feature_dimension(3); 693 dnums.set_output_feature_dimension(3); 694 dnums.add_kernel_spatial_dimensions(0); 695 dnums.add_kernel_spatial_dimensions(1); 696 dnums.set_kernel_input_feature_dimension(2); 697 dnums.set_kernel_output_feature_dimension(3); 698 ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, 699 /*feature_group_count=*/1); 700 TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); 701 const Shape& result_shape = 702 module->entry_computation()->root_instruction()->shape(); 703 
EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(),
                            {true, false, false, false}))
    << result_shape;
}

// Verifies that dynamic dimensions propagate through DotGeneral: the batch
// dimension (lhs/rhs dim 0) and the lhs non-contracting dimension (lhs dim 1)
// are bound as dynamic via SetDynamicBinding, so the result shape
// [batch, lhs-free, rhs-free] must be dynamic in its first two dimensions.
TEST_F(XlaBuilderTest, DynamicDot) {
  XlaBuilder b(TestName());
  // Tuple parameter: two dynamic-dimension operands plus two U32 scalars
  // that carry the runtime sizes for the dynamic dimensions.
  Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {2, 3, 4}, {true, true, false}),
       ShapeUtil::MakeShape(F32, {2, 4, 5}, {true, false, false}),
       ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeShape(U32, {})});
  auto p0 = Parameter(&b, 0, tuple_param_shape, "p0");
  // Scalar at tuple index {2} sizes the batch dim (dim 0) of both operands.
  ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                   /*dynamic_size_param_index=*/{2},
                                   /*target_param_num=*/0,
                                   /*target_param_index=*/{0},
                                   /*target_dim_num=*/0));
  ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                   /*dynamic_size_param_index=*/{2},
                                   /*target_param_num=*/0,
                                   /*target_param_index=*/{1},
                                   /*target_dim_num=*/0));
  // Scalar at tuple index {3} sizes the lhs non-contracting dim (dim 1).
  ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                   /*dynamic_size_param_index=*/{3},
                                   /*target_param_num=*/0,
                                   /*target_param_index=*/{0},
                                   /*target_dim_num=*/1));

  auto lhs = GetTupleElement(p0, 0);
  auto rhs = GetTupleElement(p0, 1);
  // Batched matmul: contract lhs dim 2 with rhs dim 1; dim 0 is the batch.
  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(2);
  dnums.add_rhs_contracting_dimensions(1);
  dnums.add_lhs_batch_dimensions(0);
  dnums.add_rhs_batch_dimensions(0);
  DotGeneral(lhs, rhs, dnums);
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
  const Shape& result_shape =
      module->entry_computation()->root_instruction()->shape();
  EXPECT_TRUE(
      ContainersEqual(result_shape.dynamic_dimensions(), {true, true, false}))
      << result_shape;
}

// Verifies that reducing away dim 0 of f32[5, <=4, 3] keeps the remaining
// dynamic dimension: the operand's dim 1 (dynamic) becomes result dim 0.
TEST_F(XlaBuilderTest, DynamicReduce) {
  XlaBuilder b(TestName());
  Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {5, 4, 3}, {false, true, false}),
       ShapeUtil::MakeShape(U32, {})});
  auto p0 = Parameter(&b, 0, tuple_param_shape, "p0");
  auto init = ConstantR0<float>(&b, 0);
  // Scalar at tuple index {1} sizes dim 1 of the F32 operand.
  ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                   /*dynamic_size_param_index=*/{1},
                                   /*target_param_num=*/0,
                                   /*target_param_index=*/{0},
                                   /*target_dim_num=*/1));
  auto gte = GetTupleElement(p0, 0);
  // Scalar add computation used as the reducer.
  XlaBuilder bsum(TestName());
  Add(Parameter(&bsum, 0, ShapeUtil::MakeShape(F32, {}), "x"),
      Parameter(&bsum, 1, ShapeUtil::MakeShape(F32, {}), "y"));
  TF_ASSERT_OK_AND_ASSIGN(auto sum, bsum.Build());
  Reduce(gte, init, sum, {0});
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
  const Shape& result_shape =
      module->entry_computation()->root_instruction()->shape();
  EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), {true, false}))
      << result_shape;
}

// Verifies that ReduceWindow preserves the dynamic dimension (dim 0) of its
// operand in the windowed result shape.
TEST_F(XlaBuilderTest, DynamicReduceWindow) {
  XlaBuilder b(TestName());
  Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {2, 4, 8}, {true, false, false}),
       ShapeUtil::MakeShape(U32, {})});
  auto p0 = Parameter(&b, 0, tuple_param_shape, "p0");
  auto init = ConstantR0<float>(&b, 0.f);
  ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                   /*dynamic_size_param_index=*/{1},
                                   /*target_param_num=*/0,
                                   /*target_param_index=*/{0},
                                   /*target_dim_num=*/0));
  auto gte = GetTupleElement(p0, 0);
  // Scalar add computation used as the window reducer.
  XlaBuilder bsum(TestName());
  Add(Parameter(&bsum, 0, ShapeUtil::MakeShape(F32, {}), "x"),
      Parameter(&bsum, 1, ShapeUtil::MakeShape(F32, {}), "y"));
  TF_ASSERT_OK_AND_ASSIGN(auto sum, bsum.Build());
  ReduceWindow(gte, init, sum, /*window_dimensions=*/{1, 2, 4},
               /*window_strides=*/{1, 1, 1}, Padding::kValid);
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
  const Shape& result_shape =
      module->entry_computation()->root_instruction()->shape();
  EXPECT_TRUE(
      ContainersEqual(result_shape.dynamic_dimensions(), {true, false, false}))
      << result_shape;
}

// Verifies that SelectAndScatter, with both the operand and the scatter
// source dynamic in dim 0, produces a result that is dynamic in dim 0.
TEST_F(XlaBuilderTest, DynamicSelectAndScatter) {
  XlaBuilder b(TestName());
  Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {2, 4, 8}, {true, false, false}),
       ShapeUtil::MakeShape(F32, {2, 2, 2}, {true, false, false}),
       ShapeUtil::MakeShape(U32, {})});
  auto p0 = Parameter(&b, 0, tuple_param_shape, "p0");
  auto init = ConstantR0<float>(&b, 0.f);
  // Scalar add computation: the scatter combiner.
  XlaBuilder bsum(TestName());
  Add(Parameter(&bsum, 0, ShapeUtil::MakeShape(F32, {}), "x"),
      Parameter(&bsum, 1, ShapeUtil::MakeShape(F32, {}), "y"));
  TF_ASSERT_OK_AND_ASSIGN(auto sum, bsum.Build());
  // Scalar >= computation: the window select predicate.
  XlaBuilder bge(TestName());
  Ge(Parameter(&bge, 0, ShapeUtil::MakeShape(F32, {}), "x"),
     Parameter(&bge, 1, ShapeUtil::MakeShape(F32, {}), "y"));
  TF_ASSERT_OK_AND_ASSIGN(auto ge, bge.Build());

  // The same size scalar (tuple index {2}) drives dim 0 of both the operand
  // and the source.
  ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                   /*dynamic_size_param_index=*/{2},
                                   /*target_param_num=*/0,
                                   /*target_param_index=*/{0},
                                   /*target_dim_num=*/0));
  ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                   /*dynamic_size_param_index=*/{2},
                                   /*target_param_num=*/0,
                                   /*target_param_index=*/{1},
                                   /*target_dim_num=*/0));
  auto gte0 = GetTupleElement(p0, 0);
  auto source = GetTupleElement(p0, 1);
  SelectAndScatter(gte0, ge, {1, 2, 4}, {1, 2, 4}, Padding::kValid, source,
                   init, sum);
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
  const Shape& result_shape =
      module->entry_computation()->root_instruction()->shape();
  EXPECT_TRUE(
      ContainersEqual(result_shape.dynamic_dimensions(), {true, false, false}))
      << result_shape;
}

// Verifies that Reshape carries dynamic dimensions of the operand
// (dims 2 and 3 of f32[2, 3, <=4, <=5, 6]) into the corresponding
// dimensions of the reshaped result.
TEST_F(XlaBuilderTest, DynamicReshape) {
  XlaBuilder b(TestName());
  Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {2, 3, 4, 5, 6},
                            {false, false, true, true, false}),
       ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeShape(U32, {})});
  auto p0 = Parameter(&b, 0, tuple_param_shape, "p0");
ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                 /*dynamic_size_param_index=*/{1},
                                 /*target_param_num=*/0,
                                 /*target_param_index=*/{0},
                                 /*target_dim_num=*/2));
  ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                   /*dynamic_size_param_index=*/{2},
                                   /*target_param_num=*/0,
                                   /*target_param_index=*/{0},
                                   /*target_dim_num=*/3));
  auto gte = GetTupleElement(p0, 0);  // f32[2, 3, <=4, <=5, 6]
  Reshape(gte, /*new_sizes=*/{6, 4, 1, 5, 2, 3});
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
  const Shape& result_shape =
      module->entry_computation()->root_instruction()->shape();
  // The dynamic sizes 4 and 5 reappear as result dims 1 and 3.
  EXPECT_TRUE(result_shape.is_dynamic_dimension(1));
  EXPECT_TRUE(result_shape.is_dynamic_dimension(3));
  EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(),
                              {false, true, false, true, false, false}))
      << result_shape;
}

// Verifies that Select of two operands with matching dynamic dimensions
// (both dynamic in dim 1) yields a result dynamic in dim 1 only.
TEST_F(XlaBuilderTest, DynamicSelect) {
  XlaBuilder b(TestName());
  Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {4, 5, 6}, {false, true, false}),
       ShapeUtil::MakeShape(F32, {4, 5, 6}, {false, true, false}),
       ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeShape(U32, {})});
  auto p0 = Parameter(&b, 0, tuple_param_shape, "p0");
  auto pred = Parameter(&b, 1, ShapeUtil::MakeShape(PRED, {}), "pred");
  // Bind dim 1 of each branch to its own size scalar ({2} and {3}).
  ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                   /*dynamic_size_param_index=*/{2},
                                   /*target_param_num=*/0,
                                   /*target_param_index=*/{0},
                                   /*target_dim_num=*/1));
  ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                   /*dynamic_size_param_index=*/{3},
                                   /*target_param_num=*/0,
                                   /*target_param_index=*/{1},
                                   /*target_dim_num=*/1));
  auto gte0 = GetTupleElement(p0, 0);
  auto gte1 = GetTupleElement(p0, 1);
  Select(pred, gte0, gte1);
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
  const Shape& result_shape =
      module->entry_computation()->root_instruction()->shape();
  EXPECT_TRUE(result_shape.is_dynamic_dimension(1));
  EXPECT_FALSE(result_shape.is_dynamic_dimension(2));
  EXPECT_TRUE(
      ContainersEqual(result_shape.dynamic_dimensions(), {false, true, false}))
      << result_shape;
}

// Verifies that Select rejects operands whose dynamic dimensions disagree
// (dynamic dim 1 vs. dynamic dim 2) with a shape-mismatch error.
TEST_F(XlaBuilderTest, DynamicSelectNotCompatible) {
  XlaBuilder b(TestName());
  Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {4, 5, 6}, {false, true, false}),
       ShapeUtil::MakeShape(F32, {4, 5, 6}, {false, false, true}),
       ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeShape(U32, {})});
  auto p0 = Parameter(&b, 0, tuple_param_shape, "p0");
  auto pred = Parameter(&b, 1, ShapeUtil::MakeShape(PRED, {}), "pred");
  ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                   /*dynamic_size_param_index=*/{2},
                                   /*target_param_num=*/0,
                                   /*target_param_index=*/{0},
                                   /*target_dim_num=*/1));
  ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                   /*dynamic_size_param_index=*/{3},
                                   /*target_param_num=*/0,
                                   /*target_param_index=*/{1},
                                   /*target_dim_num=*/2));
  auto gte0 = GetTupleElement(p0, 0);  // f32[4,<=5,6]
  auto gte1 = GetTupleElement(p0, 1);  // f32[4,5,<=6]
  Select(pred, gte0, gte1);
  Status status = BuildHloModule(&b).status();
  ASSERT_IS_NOT_OK(status);
  EXPECT_THAT(status.error_message(),
              ::testing::HasSubstr("Operands to select must be the same shape; "
                                   "got f32[4,<=5,6] and f32[4,5,<=6]"));
}

// Verifies that Transpose moves the dynamic dimension along with its data
// dimension: dynamic dim 0 of f32[<=3, 5] becomes dynamic dim 1 after the
// {1, 0} permutation.
TEST_F(XlaBuilderTest, DynamicTranspose) {
  XlaBuilder b(TestName());
  Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {3, 5}, {true, false}),
       ShapeUtil::MakeShape(U32, {})});
  auto p0 = Parameter(&b, 0, tuple_param_shape, "p0");
  ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                   /*dynamic_size_param_index=*/{1},
                                   /*target_param_num=*/0,
                                   /*target_param_index=*/{0},
                                   /*target_dim_num=*/0));
  auto gte = GetTupleElement(p0, 0);
  Transpose(gte, /*permutation=*/{1, 0});
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
  const Shape& result_shape =
      module->entry_computation()->root_instruction()->shape();
  EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), {false, true}))
      << result_shape;
}

// Verifies that AfterAll rejects non-token operands at build time.
TEST_F(XlaBuilderTest, AfterAllWithNonTokenOperands) {
  XlaBuilder b(TestName());
  AfterAll(&b, {CreateToken(&b), ConstantR0<float>(&b, 1.0)});
  Status status = b.Build().status();
  ASSERT_IS_NOT_OK(status);
  EXPECT_THAT(status.error_message(),
              ::testing::HasSubstr("All operands to AfterAll must be tokens"));
}

// Verifies that aliases registered with SetUpAlias survive the round trip
// through the HloModule proto: output index {1} aliases parameter 0 and
// output index {0} aliases parameter 1.
TEST_F(XlaBuilderTest, CheckInputOutputAlias) {
  XlaBuilder b(TestName());
  auto p0 = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {8, 4}), "p0");
  auto p1 = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {8, 4}), "p1");
  auto add = Add(p0, p1);
  auto sub = Sub(p0, p1);
  auto root = Tuple(&b, {add, sub});

  b.SetUpAlias({1}, 0, {});
  b.SetUpAlias({0}, 1, {});

  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b, root));

  const HloInputOutputAliasConfig& config = module->input_output_alias_config();
  EXPECT_TRUE(config.ParameterHasAlias(0, {}));
  EXPECT_TRUE(config.ParameterHasAlias(1, {}));

  auto alias_p0 = config.GetAliasedOutput(0, {});
  ASSERT_TRUE(alias_p0.has_value());
  EXPECT_EQ(*alias_p0, ShapeIndex({1}));

  auto alias_p1 = config.GetAliasedOutput(1, {});
  ASSERT_TRUE(alias_p1.has_value());
  EXPECT_EQ(*alias_p1, ShapeIndex({0}));
}

}  // namespace
}  // namespace xla