1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" 17 18 #include <string> 19 #include "tensorflow/core/lib/core/status_test_util.h" 20 #include "tensorflow/core/lib/core/stringpiece.h" 21 #include "tensorflow/core/platform/test.h" 22 23 namespace xla { 24 namespace tools { 25 namespace { 26 27 using tensorflow::StringPiece; 28 29 struct TestData { 30 string test_name; 31 string module_string; 32 }; 33 34 string TestDataToString(const ::testing::TestParamInfo<TestData>& data) { 35 return data.param.test_name; 36 } 37 38 // For each string below, we check that: 39 // - we parse it to an HloModule successfully, and 40 // - the stringification of the resulting HloModule is equal to our original 41 // string. 
42 std::vector<TestData> CreateTestCases() { 43 // clang-format off 44 return std::vector<TestData>({ 45 // ax + y 46 { 47 "AxpyParam", 48 R"(HloModule axpy_module 49 50 ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { 51 %alpha = f32[] parameter(0) 52 %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={} 53 %x = f32[2,4]{1,0} parameter(1) 54 %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x) 55 %y = f32[2,4]{1,0} parameter(2) 56 ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) 57 } 58 59 )" 60 }, 61 // pred constant 62 { 63 "ConstantPred", 64 R"(HloModule constant_pred_module 65 66 ENTRY %constant_pred () -> pred[] { 67 ROOT %constant = pred[] constant(true), metadata={op_type="const" op_name="\"it\'s not a problem\n" source_file="path/to/test.cc" source_line=68} 68 } 69 70 )" 71 }, 72 // s32 constant 73 { 74 "ConstantS32", 75 R"(HloModule constant_s32_module 76 77 ENTRY %constant_s32 () -> s32[] { 78 ROOT %constant = s32[] constant(-42) 79 } 80 81 )" 82 }, 83 // f32 constant, but the value is not a decimal 84 { 85 "ConstantF32", 86 R"(HloModule ConstantF32_module 87 88 ENTRY %ConstantF32.v4 () -> f32[] { 89 ROOT %constant = f32[] constant(42) 90 } 91 92 )" 93 }, 94 // f32 constant, rank 1 empty array. 95 { 96 "ConstantF32R1Empty", 97 R"(HloModule ConstantF32Empty_module 98 99 ENTRY %ConstantF32Empty.v4 () -> f32[0] { 100 ROOT %constant = f32[0]{0} constant({}) 101 } 102 103 )" 104 }, 105 // f32 constant, rank 4 empty array. 
106 { 107 "ConstantF32R4Empty", 108 R"(HloModule ConstantF32R4Empty_module 109 110 ENTRY %ConstantF32R4Empty.v4 () -> f32[2,0,4,3] { 111 ROOT %constant = f32[2,0,4,3]{3,2,1,0} constant(f32[2,0,4,3] { { /*i0=0*/ }, { /*i0=1*/ } }) 112 } 113 114 )" 115 }, 116 // constant 4D 117 { 118 "Constant4D", 119 R"(HloModule Small_3x2x1x1_module 120 121 ENTRY %Small_3x2x1x1.v1 () -> f32[3,2,1,1] { 122 ROOT %constant = f32[3,2,1,1]{3,2,1,0} constant(f32[3,2,1,1] { { /*i0=0*/ { /*i1=0*/ {-1} }, { /*i1=1*/ {4.1} } }, { /*i0=1*/ { /*i1=0*/ {2} }, { /*i1=1*/ {4.1} } }, { /*i0=2*/ { /*i1=0*/ {5} }, { /*i1=1*/ {4.4} } } }) 123 } 124 125 )" 126 }, 127 // non-finite constants: nan, inf, -inf 128 { 129 "ConstantNonFinite", 130 R"(HloModule IsFiniteR1F32s_module 131 132 ENTRY %IsFiniteR1F32s.v2 () -> pred[6] { 133 %constant = f32[6]{0} constant({nan, 7, nan, -1, inf, -inf}) 134 ROOT %is-finite = pred[6]{0} is-finite(f32[6]{0} %constant) 135 } 136 137 )" 138 }, 139 // constant f16 140 { 141 "ConstantF16", 142 R"(HloModule ConstantF16_module 143 144 ENTRY %ConstantF16.v4 () -> f16[] { 145 ROOT %constant = f16[] constant(500) 146 } 147 148 )" 149 }, 150 // bf16 151 { 152 "BF16", 153 R"(HloModule BF16 154 155 ENTRY %BF16.v4 () -> bf16[] { 156 ROOT %constant = bf16[] constant(500) 157 } 158 159 )" 160 }, 161 // constant + constant 162 { 163 "AddConstants", 164 R"(HloModule add_constants_module 165 166 ENTRY %add_constants () -> f32[] { 167 %constant = f32[] constant(3.14) 168 ROOT %add = f32[] add(f32[] %constant, f32[] %constant) 169 } 170 171 )" 172 }, 173 // tuple constant 174 { 175 "TupleConstant", 176 R"(HloModule TupleConstant_module 177 178 ENTRY %TupleConstant.v1 () -> (f32[2,1], f32[2]) { 179 ROOT %constant = (f32[2,1]{1,0}, f32[2]{0}) constant((f32[2,1], f32[2]) ( f32[2,1] { { 1 }, { 2 } }, {2, 42} )) 180 } 181 182 )" 183 }, 184 // v1 > v2 ? 
v1 : v2 185 { 186 "SelectR1F32", 187 R"(HloModule SelectR1F32WithCmpR1F32sFromParamsSmall_module 188 189 ENTRY %SelectR1F32WithCmpR1F32sFromParamsSmall.v4 (v1: f32[4], v2: f32[4]) -> f32[4] { 190 %v1 = f32[4]{0} parameter(0), sharding={maximal device=1} 191 %v2 = f32[4]{0} parameter(1), sharding={maximal device=1} 192 %greater-than = pred[4]{0} greater-than(f32[4]{0} %v1, f32[4]{0} %v2), sharding={replicated} 193 ROOT %select = f32[4]{0} select(pred[4]{0} %greater-than, f32[4]{0} %v1, f32[4]{0} %v2), sharding={} 194 } 195 196 )" 197 }, 198 // empty tuple 199 { 200 "EmptyTupleCreate", 201 R"(HloModule EmptyTupleCreate_module 202 203 ENTRY %EmptyTupleCreate.v1 () -> () { 204 ROOT %tuple = () tuple() 205 } 206 207 )" 208 }, 209 // tuple 210 { 211 "TupleCreate", 212 R"(HloModule TupleCreate_module 213 214 ENTRY %TupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f32[2,3]) { 215 %v1 = f32[] parameter(0) 216 %v2 = f32[3]{0} parameter(1) 217 %v3 = f32[2,3]{1,0} parameter(2) 218 ROOT %tuple = (f32[], f32[3]{0}, f32[2,3]{1,0}) tuple(f32[] %v1, f32[3]{0} %v2, f32[2,3]{1,0} %v3) 219 } 220 221 )" 222 }, 223 { 224 "ShardedTupleCreate", 225 R"(HloModule ShardedTupleCreate_module 226 227 ENTRY %ShardedTupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f32[2,3]) { 228 %v1 = f32[] parameter(0) 229 %v2 = f32[3]{0} parameter(1) 230 %v3 = f32[2,3]{1,0} parameter(2) 231 ROOT %tuple = (f32[], f32[3]{0}, f32[2,3]{1,0}) tuple(f32[] %v1, f32[3]{0} %v2, f32[2,3]{1,0} %v3), sharding={{replicated}, {maximal device=0}, {replicated}} 232 } 233 234 )" 235 }, 236 // int32 result = 0; 237 // while (result < 5) { result = result + 1; } 238 { 239 "WhileWithScalarS32Result", 240 R"(HloModule WhileWithScalarS32Result_module 241 242 %body.v3 (prev.1: s32[]) -> s32[] { 243 %constant = s32[] constant(1) 244 %prev.1 = s32[] parameter(0) 245 ROOT %add = s32[] add(s32[] %constant, s32[] %prev.1) 246 } 247 248 %condition.v3 (prev.2: s32[]) -> pred[] { 249 %constant.1 
= s32[] constant(5) 250 %prev.2 = s32[] parameter(0) 251 ROOT %greater-than = pred[] greater-than(s32[] %constant.1, s32[] %prev.2) 252 } 253 254 ENTRY %WhileWithScalarS32Result.v2 () -> s32[] { 255 %constant.2 = s32[] constant(0) 256 ROOT %while = s32[] while(s32[] %constant.2), condition=%condition.v3, body=%body.v3 257 } 258 259 )" 260 }, 261 // send and recv 262 { 263 "SendRecv", 264 R"(HloModule TwoSendRecvBothWayRecvFist_module 265 266 ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { 267 %recv = (f32[], u32[]) recv(), channel_id=15, sharding={maximal device=1} 268 ROOT %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15, sharding={maximal device=1} 269 %constant = f32[] constant(2.1), sharding={maximal device=0} 270 %send = (f32[], u32[]) send(f32[] %constant), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv} 271 %send-done = () send-done((f32[], u32[]) %send), channel_id=16, sharding={maximal device=0} 272 } 273 274 )" 275 }, 276 // get-tuple-element 277 { 278 "GetTupleElement", 279 R"(HloModule GetTupleElement_module 280 281 ENTRY %GetTupleElement.v4 () -> s32[2,3] { 282 %constant = f32[3]{0} constant({1, 2, 3}) 283 %constant.1 = s32[2,3]{1,0} constant(s32[2,3] { { 1, 2, 3 }, { 4, 5, 6 } }) 284 %tuple = (f32[3]{0}, s32[2,3]{1,0}) tuple(f32[3]{0} %constant, s32[2,3]{1,0} %constant.1) 285 ROOT %get-tuple-element = s32[2,3]{1,0} get-tuple-element((f32[3]{0}, s32[2,3]{1,0}) %tuple), index=1, sharding={maximal device=0} 286 } 287 288 )" 289 }, 290 // call 291 { 292 "Call", 293 R"(HloModule CallR0F32IdentityScalar_module 294 295 %Identity.v1 (x: f32[]) -> f32[] { 296 ROOT %x = f32[] parameter(0) 297 } 298 299 ENTRY %CallR0F32IdentityScalar.v2 () -> f32[] { 300 %constant = f32[] constant(42) 301 ROOT %call = f32[] call(f32[] %constant), to_apply=%Identity.v1 302 } 303 304 )" 305 }, 306 // reduce window 307 { 308 "ReduceWindow", 309 R"(HloModule R4UnitWindow_module 310 311 %add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] { 312 
%lhs = f32[] parameter(0) 313 %rhs = f32[] parameter(1) 314 ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs) 315 } 316 317 ENTRY %R4UnitWindow.v3 (operand: f32[13,12,8,15]) -> f32[13,3,8,15] { 318 %operand = f32[13,12,8,15]{0,3,2,1} parameter(0) 319 %constant = f32[] constant(0) 320 ROOT %reduce-window = f32[13,3,8,15]{0,3,2,1} reduce-window(f32[13,12,8,15]{0,3,2,1} %operand, f32[] %constant), window={size=1x1x7x1 stride=1x4x1x1 pad=0_0x0_0x3_3x0_0}, to_apply=%add_F32.v3 321 } 322 323 )" 324 }, 325 // reduce window on scalar 326 { 327 "ReduceWindowScalar", 328 R"(HloModule reduce_window_scalar 329 330 %add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] { 331 %lhs = f32[] parameter(0) 332 %rhs = f32[] parameter(1) 333 ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs) 334 } 335 336 ENTRY %R4UnitWindowScalar () -> f32[] { 337 %constant = f32[] constant(42) 338 %constant.1 = f32[] constant(1) 339 ROOT %reduce-window = f32[] reduce-window(f32[] %constant, f32[] %constant.1), to_apply=%add_F32.v3 340 } 341 342 )" 343 }, 344 // convolution 345 { 346 "Convolution", 347 R"(HloModule Convolve1D1Window_0_module 348 349 ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] { 350 %input = f32[1,2,1]{2,1,0} parameter(0) 351 %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input) 352 %filter = f32[1,1,1]{2,1,0} parameter(1) 353 ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f 354 } 355 356 )" 357 }, 358 // convolution rank 2 359 { 360 "ConvolutionR2", 361 R"(HloModule ConvolveR2_module 362 363 ENTRY %ConvolveR2.v3 (input: f32[1,2], filter: f32[1,1]) -> f32[1,2] { 364 %input = f32[1,2]{1,0} parameter(0) 365 %filter = f32[1,1]{1,0} parameter(1) 366 ROOT %convolution = f32[1,2]{0,1} convolution(f32[1,2]{1,0} %input, f32[1,1]{1,0} %filter), dim_labels=bf_io->bf 367 } 368 369 )" 370 }, 371 // convolution backward 372 { 373 "ConvolutionBackward", 374 R"(HloModule 
ConvolveBackward_module 375 376 ENTRY %ConvolveBackward (input: f32[128,7,7,512], filter: f32[3,3,512,512]) -> f32[128,14,14,512] { 377 %input = f32[128,7,7,512]{0,3,2,1} parameter(0) 378 %filter = f32[3,3,512,512]{3,2,1,0} parameter(1) 379 ROOT %convolution-base-dilated = f32[128,14,14,512]{0,3,2,1} convolution(f32[128,7,7,512]{0,3,2,1} %input, f32[3,3,512,512]{3,2,1,0} %filter), window={size=3x3 pad=1_2x1_2 lhs_dilate=2x2 rhs_reversal=1x1}, dim_labels=b01f_01oi->b01f 380 } 381 382 )" 383 }, 384 // reverse(constant) 385 { 386 "Reverse4D", 387 R"(HloModule Reverse4DFloatArrayOnDim01_module 388 389 ENTRY %Reverse4DFloatArrayOnDim01.v2 () -> f32[4,3,2,1] { 390 %constant = f32[4,3,2,1]{0,1,2,3} constant(f32[4,3,2,1] { { /*i0=0*/ { /*i1=0*/ {1}, {2} }, { /*i1=1*/ {3}, {4} }, { /*i1=2*/ {5}, {6} } }, { /*i0=1*/ { /*i1=0*/ {7}, {8} }, { /*i1=1*/ {9}, {10} }, { /*i1=2*/ {11}, {12} } }, { /*i0=2*/ { /*i1=0*/ {13}, {14} }, { /*i1=1*/ {15}, {16} }, { /*i1=2*/ {17}, {18} } }, { /*i0=3*/ { /*i1=0*/ {19}, {20} }, { /*i1=1*/ {21}, {22} }, { /*i1=2*/ {23}, {24} } } }) 391 ROOT %reverse = f32[4,3,2,1]{0,1,2,3} reverse(f32[4,3,2,1]{0,1,2,3} %constant), dimensions={0,1} 392 } 393 394 )" 395 }, 396 // concat 397 { 398 "Concat", 399 R"(HloModule Concat2x3With2x5_module 400 401 ENTRY %Concat2x3With2x5.v3 () -> f32[2,8] { 402 %constant = f32[2,3]{1,0} constant(f32[2,3] { { 0, 1, 2 }, { 1000, 1001, 1002 } }) 403 %constant.1 = f32[2,5]{1,0} constant(f32[2,5] { { 64, 65, 66, 67, 68 }, { 1064, 1065, 1066, 1067, 1068 } }) 404 ROOT %concatenate = f32[2,8]{1,0} concatenate(f32[2,3]{1,0} %constant, f32[2,5]{1,0} %constant.1), dimensions={1} 405 } 406 407 )" 408 }, 409 // select and scatter 410 { 411 "SelectAndScatter", 412 R"(HloModule R4F32OverlapSmall_module 413 414 %ge_F32.v3 (lhs: f32[], rhs: f32[]) -> pred[] { 415 %lhs = f32[] parameter(0) 416 %rhs = f32[] parameter(1) 417 ROOT %greater-than-or-equal-to = pred[] greater-than-or-equal-to(f32[] %lhs, f32[] %rhs) 418 } 419 420 %add_F32.v3 
(lhs.1: f32[], rhs.1: f32[]) -> f32[] { 421 %lhs.1 = f32[] parameter(0) 422 %rhs.1 = f32[] parameter(1) 423 ROOT %add = f32[] add(f32[] %lhs.1, f32[] %rhs.1) 424 } 425 426 ENTRY %R4F32OverlapSmall.v4 () -> f32[4,5,1,1] { 427 %constant = f32[4,5,1,1]{3,2,1,0} constant(f32[4,5,1,1] { { /*i0=0*/ { /*i1=0*/ {7} }, { /*i1=1*/ {2} }, { /*i1=2*/ {5} }, { /*i1=3*/ {3} }, { /*i1=4*/ {8} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {8} }, { /*i1=2*/ {9} }, { /*i1=3*/ {3} }, { /*i1=4*/ {4} } }, { /*i0=2*/ { /*i1=0*/ {1} }, { /*i1=1*/ {5} }, { /*i1=2*/ {7} }, { /*i1=3*/ {5} }, { /*i1=4*/ {6} } }, { /*i0=3*/ { /*i1=0*/ {0} }, { /*i1=1*/ {6} }, { /*i1=2*/ {2} }, { /*i1=3*/ {10} }, { /*i1=4*/ {2} } } }) 428 %constant.1 = f32[2,2,1,1]{3,2,1,0} constant(f32[2,2,1,1] { { /*i0=0*/ { /*i1=0*/ {2} }, { /*i1=1*/ {6} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {1} } } }) 429 %constant.2 = f32[] constant(0) 430 ROOT %select-and-scatter = f32[4,5,1,1]{3,2,1,0} select-and-scatter(f32[4,5,1,1]{3,2,1,0} %constant, f32[2,2,1,1]{3,2,1,0} %constant.1, f32[] %constant.2), window={size=2x3x1x1 stride=2x2x1x1}, select=%ge_F32.v3, scatter=%add_F32.v3 431 } 432 433 )" 434 }, 435 // select and scatter on scalar 436 { 437 "SelectAndScatterScalar", 438 R"(HloModule select_and_scatter_scalar 439 440 %ge_F32.v3 (lhs: f32[], rhs: f32[]) -> pred[] { 441 %lhs = f32[] parameter(0) 442 %rhs = f32[] parameter(1) 443 ROOT %greater-than-or-equal-to = pred[] greater-than-or-equal-to(f32[] %lhs, f32[] %rhs) 444 } 445 446 %add_F32.v3 (lhs.1: f32[], rhs.1: f32[]) -> f32[] { 447 %lhs.1 = f32[] parameter(0) 448 %rhs.1 = f32[] parameter(1) 449 ROOT %add = f32[] add(f32[] %lhs.1, f32[] %rhs.1) 450 } 451 452 ENTRY %SelectAndScatterScalar () -> f32[] { 453 %constant = f32[] constant(42) 454 %constant.1 = f32[] constant(1) 455 %constant.2 = f32[] constant(2) 456 ROOT %select-and-scatter = f32[] select-and-scatter(f32[] %constant, f32[] %constant.1, f32[] %constant.2), select=%ge_F32.v3, scatter=%add_F32.v3 457 } 458 459 )" 
460 }, 461 // slice 462 { 463 "Slice", 464 R"(HloModule slice_module 465 466 ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] { 467 %p0 = f32[3,3,4,4]{3,2,1,0} parameter(0) 468 ROOT %slice = f32[3,3,2,4]{3,2,1,0} slice(f32[3,3,4,4]{3,2,1,0} %p0), slice={[0:3:1], [0:3:1], [0:4:2], [0:4:1]} 469 } 470 471 )" 472 }, 473 // slice, no stride 474 { 475 "SliceNoStride", 476 R"(HloModule Slice3x3x3_To_1x3x3_F32_module 477 478 ENTRY %Slice3x3x3_To_1x3x3_F32.v2 () -> f32[1,3,3] { 479 %constant = f32[3,3,3]{2,1,0} constant(f32[3,3,3] { { { 0, 1, 2 }, { 3, 4, 5 }, { 6, 7, 8 } }, { { 9, 10, 11 }, { 12, 13, 14 }, { 15, 16, 17 } }, { { 18, 19, 20 }, { 21, 22, 23 }, { 24, 25, 26 } } }) 480 ROOT %slice = f32[1,3,3]{2,1,0} slice(f32[3,3,3]{2,1,0} %constant), slice={[0:1], [0:3], [0:3]} 481 } 482 483 )" 484 }, 485 // slice R0 486 { 487 "SliceR0", 488 R"(HloModule SliceR0_module 489 490 ENTRY %SliceR0.v2 () -> s32[] { 491 %constant = s32[] constant(1) 492 ROOT %slice = s32[] slice(s32[] %constant), slice={} 493 } 494 495 )" 496 }, 497 // transpose 498 { 499 "Transpose", 500 R"(HloModule Transpose_module 501 502 ENTRY %Transpose.v2 () -> s32[1,2,3] { 503 %constant = s32[1,2,3]{2,1,0} constant(s32[1,2,3] { { { 1, 2, 3 }, { 4, 5, 6 } } }) 504 ROOT %transpose = s32[1,2,3]{2,1,0} transpose(s32[1,2,3]{2,1,0} %constant), dimensions={0,1,2} 505 } 506 507 )" 508 }, 509 // Dynamic slice 510 { 511 "DynamicSlice", 512 R"(HloModule DynamicSlice_module 513 514 ENTRY %DynamicSlice.v5 (original_parameter: s32[2,2,258], start_index: s32[1]) -> s32[2,2,258] { 515 %original_parameter = s32[2,2,258]{2,1,0} parameter(0) 516 %constant = s32[1]{0} constant({0}) 517 %start_index = s32[1]{0} parameter(1) 518 %concatenate = s32[3]{0} concatenate(s32[1]{0} %constant, s32[1]{0} %constant, s32[1]{0} %start_index), dimensions={0} 519 ROOT %dynamic-slice = s32[2,2,258]{2,1,0} dynamic-slice(s32[2,2,258]{2,1,0} %original_parameter, s32[3]{0} %concatenate), dynamic_slice_sizes={2,2,258} 520 } 521 522 )" 523 }, 524 // 
Dynamic update slice 525 { 526 "DynamicUpdateSlice", 527 R"(HloModule DynamicUpdateSlice_module 528 529 ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_indices: s32[4]) -> s32[1,1,25,1] { 530 %input = s32[1,1,25,1]{3,2,1,0} parameter(0) 531 %update = s32[1,1,2,1]{3,2,1,0} parameter(1) 532 %start_indices = s32[4]{0} parameter(2) 533 ROOT %dynamic-update-slice = s32[1,1,25,1]{3,2,1,0} dynamic-update-slice(s32[1,1,25,1]{3,2,1,0} %input, s32[1,1,2,1]{3,2,1,0} %update, s32[4]{0} %start_indices) 534 } 535 536 )" 537 }, 538 // batch norm training 539 { 540 "BatchNormTraining", 541 R"(HloModule BasicTraining_module 542 543 ENTRY %BasicTraining.v4 () -> (f32[2,2,1,2], f32[2], f32[2]) { 544 %constant = f32[2,2,1,2]{3,2,1,0} constant(f32[2,2,1,2] { { /*i0=0*/ { /*i1=0*/ {1, 2} }, { /*i1=1*/ {3, 4} } }, { /*i0=1*/ { /*i1=0*/ {5, 6} }, { /*i1=1*/ {7, 8} } } }) 545 %constant.1 = f32[2]{0} constant({2, 3}) 546 %constant.2 = f32[2]{0} constant({1, 2}) 547 ROOT %batch-norm-training = (f32[2,2,1,2]{3,2,1,0}, f32[2]{0}, f32[2]{0}) batch-norm-training(f32[2,2,1,2]{3,2,1,0} %constant, f32[2]{0} %constant.1, f32[2]{0} %constant.2), epsilon=0.001, feature_index=3 548 } 549 550 )" 551 }, 552 // batch norm inference 553 { 554 "BatchNormInference", 555 R"(HloModule BatchNormInference_module 556 557 ENTRY %BatchNormInference.v6 (input: f32[2,2,2,2], offset: f32[2], scale: f32[2], mean: f32[2], variance: f32[2]) -> f32[2,2,2,2] { 558 %input = f32[2,2,2,2]{3,2,1,0} parameter(0) 559 %offset = f32[2]{0} parameter(1) 560 %scale = f32[2]{0} parameter(2) 561 %mean = f32[2]{0} parameter(3) 562 %variance = f32[2]{0} parameter(4) 563 ROOT %batch-norm-inference = f32[2,2,2,2]{3,2,1,0} batch-norm-inference(f32[2,2,2,2]{3,2,1,0} %input, f32[2]{0} %offset, f32[2]{0} %scale, f32[2]{0} %mean, f32[2]{0} %variance), epsilon=0.001, feature_index=0 564 } 565 566 )" 567 }, 568 // batch norm grad 569 { 570 "BatchNormGrad", 571 R"(HloModule BatchNormGrad_module 572 573 ENTRY 
%BatchNormGrad.v4 (input: f32[2,2,2,2], scale: f32[2], mean: f32[2], variance: f32[2], grad_output: f32[2,2,2,2]) -> (f32[2,2,2,2], f32[2], f32[2]) { 574 %input = f32[2,2,2,2]{3,2,1,0} parameter(0) 575 %scale = f32[2]{0} parameter(1) 576 %mean = f32[2]{0} parameter(2) 577 %variance = f32[2]{0} parameter(3) 578 %grad_output = f32[2,2,2,2]{3,2,1,0} parameter(4) 579 ROOT %batch-norm-grad = (f32[2,2,2,2]{3,2,1,0}, f32[2]{0}, f32[2]{0}) batch-norm-grad(f32[2,2,2,2]{3,2,1,0} %input, f32[2]{0} %scale, f32[2]{0} %mean, f32[2]{0} %variance, f32[2,2,2,2]{3,2,1,0} %grad_output), epsilon=0.001, feature_index=0 580 } 581 582 )" 583 }, 584 // fft 585 { 586 "Fft", 587 R"(HloModule Fft_module 588 589 ENTRY %Fft (input: c64[8,32]) -> c64[8,32] { 590 %input = c64[8,32]{1,0} parameter(0) 591 ROOT %fft = c64[8,32]{1,0} fft(c64[8,32]{1,0} %input), fft_type=FFT, fft_length={32} 592 } 593 594 )" 595 }, 596 // ifft 597 { 598 "Ifft2d", 599 R"(HloModule Ifft2d_module 600 601 ENTRY %Ifft2d (input: c64[5,8,32]) -> c64[5,8,32] { 602 %input = c64[5,8,32]{2,1,0} parameter(0) 603 ROOT %fft = c64[5,8,32]{2,1,0} fft(c64[5,8,32]{2,1,0} %input), fft_type=IFFT, fft_length={8,32} 604 } 605 606 )" 607 }, 608 // rfft2d 609 { 610 "Rfft2d", 611 R"(HloModule Rfft2d_module 612 613 ENTRY %Rfft2d (input: f32[5,64,32]) -> c64[5,64,17] { 614 %input = f32[5,64,32]{2,1,0} parameter(0) 615 ROOT %fft = c64[5,64,17]{2,1,0} fft(f32[5,64,32]{2,1,0} %input), fft_type=RFFT, fft_length={64,32} 616 } 617 618 )" 619 }, 620 // irfft3d 621 { 622 "Irfft3d", 623 R"(HloModule Irfft3d_module 624 625 ENTRY %Irfft3d (input: c64[5,64,128,33]) -> f32[5,64,128,64] { 626 %input = c64[5,64,128,33]{3,2,1,0} parameter(0) 627 ROOT %fft = f32[5,64,128,64]{3,2,1,0} fft(c64[5,64,128,33]{3,2,1,0} %input), fft_type=IRFFT, fft_length={64,128,64} 628 } 629 630 )" 631 }, 632 // pad 633 { 634 "Pad", 635 R"(HloModule Pad1DS3Array_module 636 637 ENTRY %Pad1DS3Array.v3 () -> f32[8] { 638 %constant = f32[3]{0} constant({1, 2, 3}) 639 %constant.1 = 
f32[] constant(0.1) 640 ROOT %pad = f32[8]{0} pad(f32[3]{0} %constant, f32[] %constant.1), padding=3_1 641 } 642 643 )" 644 }, 645 // pad has interior 646 { 647 "PadHasInterior", 648 R"(HloModule PadHasInterior_module 649 650 ENTRY %PadHasInterior.v3 (input: f32[1,25,7,7]) -> f32[1,25,17,11] { 651 %input = f32[1,25,7,7]{3,2,1,0} parameter(0) 652 %constant = f32[] constant(-5.123) 653 ROOT %pad = f32[1,25,17,11]{3,2,1,0} pad(f32[1,25,7,7]{3,2,1,0} %input, f32[] %constant), padding=0_0_0x0_0_0x2_2_1x2_2_0 654 } 655 656 )" 657 }, 658 // Negative padding 659 { 660 "PadHasNegativePadding", 661 R"(HloModule PadHasNegativePadding_module 662 663 ENTRY %PadHasNegativePadding (input: f32[1,25,7,7,10]) -> f32[1,15,6,3,29] { 664 %input = f32[1,25,7,7,10]{4,3,2,1,0} parameter(0) 665 %constant = f32[] constant(-5.123) 666 ROOT %pad = f32[1,15,6,3,29]{4,3,2,1,0} pad(f32[1,25,7,7,10]{4,3,2,1,0} %input, f32[] %constant), padding=0_0_0x0_-10_0x0_-1_0x-2_-2_0x-1_-1_3 667 } 668 669 )" 670 }, 671 // fusion 672 { 673 "Fusion", 674 R"(HloModule fusion_module 675 676 %fused_computation (constant.param_0: f32[3,2,1,1], constant.1.param_1: f32[2]) -> f32[3,2,1,1] { 677 %constant.param_0 = f32[3,2,1,1]{3,2,1,0} parameter(0) 678 %constant.1.param_1 = f32[2]{0} parameter(1) 679 %broadcast = f32[3,2,1,1]{3,2,1,0} broadcast(f32[2]{0} %constant.1.param_1), dimensions={1} 680 ROOT %subtract = f32[3,2,1,1]{3,2,1,0} subtract(f32[3,2,1,1]{3,2,1,0} %constant.param_0, f32[3,2,1,1]{3,2,1,0} %broadcast) 681 } 682 683 ENTRY %fusion.v3 () -> f32[3,2,1,1] { 684 %constant = f32[3,2,1,1]{3,2,1,0} constant(f32[3,2,1,1] { { /*i0=0*/ { /*i1=0*/ {-1} }, { /*i1=1*/ {4.1} } }, { /*i0=1*/ { /*i1=0*/ {2} }, { /*i1=1*/ {4.1} } }, { /*i0=2*/ { /*i1=0*/ {5} }, { /*i1=1*/ {4.4} } } }) 685 %constant.1 = f32[2]{0} constant({3.14, 4.25}) 686 ROOT %fusion = f32[3,2,1,1]{3,2,1,0} fusion(f32[3,2,1,1]{3,2,1,0} %constant, f32[2]{0} %constant.1), kind=kLoop, calls=%fused_computation 687 } 688 689 )" 690 }, 691 { 692 "Sparse", 693 
R"(HloModule sparse_f32 694 695 ENTRY %sparse () -> f32[2,3,4] { 696 ROOT %foo = f32[2,3,4]sparse{10} constant(f32[2,3,4]{[0, 1, 2]: 1, [1, 2, 3]: 2, [2, 3, 4]: 3}) 697 } 698 699 )" 700 }, 701 { 702 "SparseEmpty", 703 R"(HloModule sparse_f32_empty 704 705 ENTRY %sparse_f32_empty () -> f32[2,3,4] { 706 ROOT %foo = f32[2,3,4]sparse{10} constant(f32[2,3,4]{}) 707 } 708 709 )" 710 }, 711 { 712 "SparseR1", 713 R"(HloModule sparse_f32_r1 714 715 ENTRY %sparse_f32_r1 () -> f32[9] { 716 ROOT %foo = f32[9]sparse{10} constant(f32[9]{1: 2, 3: 4, 5: 6}) 717 } 718 719 )" 720 }, 721 }); 722 // clang-format on 723 } 724 725 std::vector<TestData> CreateShortTestCases() { 726 // clang-format off 727 return std::vector<TestData>({ 728 // map 729 { 730 "Map", 731 R"(HloModule MapBinaryAdder_module 732 733 add_F32.v3 { 734 lhs = f32[] parameter(0) 735 rhs = f32[] parameter(1) 736 ROOT add = f32[] add(lhs, rhs) 737 } 738 739 ENTRY MapBinaryAdder.v3 { 740 param0 = f32[4]{0} parameter(0) 741 param1 = f32[4]{0} parameter(1) 742 ROOT map = f32[4]{0} map(param0, param1), to_apply=add_F32.v3 743 } 744 745 )" 746 }, 747 // reduce 748 { 749 "Reduce", 750 R"(HloModule ReduceR3ToR2_module 751 752 add_F32.v3 { 753 lhs = f32[] parameter(0) 754 rhs = f32[] parameter(1) 755 ROOT add = f32[] add(lhs, rhs) 756 } 757 758 ENTRY ReduceR3ToR2.v3 { 759 input = f32[8,16,256]{2,1,0} parameter(0) 760 constant = f32[] constant(0) 761 ROOT reduce = f32[8,16]{1,0} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3 762 } 763 764 )" 765 }, 766 // infeed/outfeed 767 { 768 "InfeedOutfeed", 769 R"(HloModule outfeed_module 770 771 ENTRY InfeedToOutfeed { 772 infeed = (u32[3]{0}, pred[]) infeed() 773 outfeed = () outfeed(infeed) 774 ROOT infeed.1 = (u32[3]{0}, pred[]) infeed() 775 outfeed.1 = () outfeed(infeed.1) 776 } 777 778 )" 779 }, 780 // Rng 781 { 782 "Rng", 783 R"(HloModule rng_module 784 785 ENTRY Rng { 786 constant = f32[] constant(0) 787 constant.1 = f32[] constant(1) 788 ROOT rng = f32[8]{0} 
rng(constant, constant.1), distribution=rng_uniform 789 } 790 791 )" 792 }, 793 // Reduce precision 794 { 795 "ReducePrevison", 796 R"(HloModule reduce_precision 797 798 ENTRY ReducePrecision { 799 constant = f32[1]{0} constant({3.14159}) 800 ROOT reduce-precision = f32[1]{0} reduce-precision(constant), exponent_bits=8, mantissa_bits=10 801 } 802 803 )" 804 }, 805 // Conditional 806 { 807 "Conditional", 808 R"(HloModule conditional 809 810 Negate { 811 x = f32[] parameter(0) 812 ROOT negate = f32[] negate(x) 813 } 814 815 Identity { 816 y = f32[] parameter(0) 817 ROOT copy = f32[] copy(y) 818 } 819 820 ENTRY Parameters1.v4 { 821 constant = pred[] constant(true) 822 constant.1 = f32[] constant(56) 823 constant.2 = f32[] constant(12) 824 ROOT conditional = f32[] conditional(constant, constant.1, constant.2), true_computation=Negate, false_computation=Identity 825 } 826 827 )" 828 }, 829 // CustomCall 830 { 831 "CustomCall", 832 R"(HloModule custom_call 833 834 ENTRY CustomCall { 835 constant = f32[1]{0} constant({12345}) 836 ROOT custom-call = f32[1,2,3]{0,2,1} custom-call(constant), custom_call_target="foo\"bar" 837 } 838 839 )" 840 }, 841 // Variables with non-default names 842 { 843 "NonDefaultNames", 844 R"(HloModule add_constants_module 845 846 ENTRY add_constants { 847 foo = f32[] constant(3.14) 848 ROOT bar = f32[] add(foo, foo) 849 } 850 851 )" 852 }, 853 { 854 "Dot", 855 R"(HloModule dot 856 857 ENTRY dot { 858 a = f32[2,10]{1,0} parameter(0) 859 b = f32[10,3]{1,0} parameter(1) 860 ROOT dot = f32[2,3]{1,0} dot(a, b), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_contracting_dims={0} 861 } 862 863 )" 864 }, 865 }); 866 // clang-format on 867 } 868 869 class HloParserTest : public ::testing::Test, 870 public ::testing::WithParamInterface<TestData> { 871 protected: 872 static void ExpectHasSubstr(StringPiece s, StringPiece expected) { 873 EXPECT_TRUE(StringPiece(s).contains(expected)) 874 << "'" << s << "' does not contain '" << expected << "'"; 875 } 876 
877 // Expects "ToString(Parse(string)) == string", that is, parses the string, 878 // asserts that it succeeded, stringifies the parsed module, and checks that 879 // the it equals the original string. 880 void ExpectEqual() { 881 const string& original = GetParam().module_string; 882 auto result = Parse(original); 883 TF_ASSERT_OK(result.status()); 884 EXPECT_EQ(original, result.ValueOrDie()->ToString( 885 HloPrintOptions().set_print_large_constants(true))); 886 } 887 }; 888 889 class HloParserShortTest : public HloParserTest { 890 protected: 891 void ExpectEqualShort() { 892 const string& original = GetParam().module_string; 893 auto result = Parse(original); 894 TF_ASSERT_OK(result.status()); 895 EXPECT_EQ(original, 896 result.ValueOrDie()->ToString(HloPrintOptions::ShortParsable())); 897 } 898 }; 899 900 TEST_P(HloParserTest, Run) { ExpectEqual(); } 901 902 TEST_P(HloParserShortTest, Run) { ExpectEqualShort(); } 903 904 INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserTest, 905 ::testing::ValuesIn(CreateTestCases()), 906 TestDataToString); 907 908 INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserShortTest, 909 ::testing::ValuesIn(CreateShortTestCases()), 910 TestDataToString); 911 912 TEST_F(HloParserTest, Empty) { 913 const string original = ""; 914 auto result = Parse(original); 915 EXPECT_NE(tensorflow::Status::OK(), result.status()); 916 } 917 918 TEST_F(HloParserTest, Garbage) { 919 const string original = "HloModule thi$ str1ng makes# N0 sen$e @all!*&^%$"; 920 auto result = Parse(original); 921 EXPECT_NE(tensorflow::Status::OK(), result.status()); 922 } 923 924 TEST_F(HloParserTest, WrongOpcode) { 925 const string original = R"(HloModule wrong_opcode: 926 927 ENTRY %blabla (x: f32[], y: f32[]) -> f32[] { 928 %x = f32[]{} parameter(0) 929 %y = f32[]{} parameter(1) 930 %le = pred[]{} le(f32[]{} %x, f32[]{} %y) 931 } 932 933 )"; 934 auto result = Parse(original); 935 EXPECT_NE(tensorflow::Status::OK(), result.status()); 
936 } 937 938 TEST_F(HloParserTest, WrongShape) { 939 const string original = R"(HloModule wrong_opcode: 940 941 ENTRY %blabla (x: g32[]) -> g32[] { 942 %x = g32[]{} parameter(0) 943 } 944 945 )"; 946 auto result = Parse(original); 947 EXPECT_NE(tensorflow::Status::OK(), result.status()); 948 } 949 950 TEST_F(HloParserTest, WrongOperandsSize) { 951 const string original = R"(HloModule wrong_opcode: 952 953 ENTRY %blabla (x: f32[]) -> pred[] { 954 %x = f32[]{} parameter(0) 955 %eq = pred[]{} equal-to(f32[]{} %x) 956 } 957 958 )"; 959 auto result = Parse(original); 960 EXPECT_NE(tensorflow::Status::OK(), result.status()); 961 } 962 963 TEST_F(HloParserTest, OperandNotFound) { 964 const string original = R"(HloModule operand_not_found: 965 ENTRY %blabla (x: f32[]) -> pred[] { 966 %x = f32[]{} parameter(0) 967 %eq = pred[]{} equal-to(f32[]{} %x, f32[]{} %y) 968 } 969 )"; 970 auto result = Parse(original); 971 EXPECT_NE(tensorflow::Status::OK(), result.status()); 972 } 973 974 TEST_F(HloParserTest, MoreConstants) { 975 const string original = R"(HloModule SelectScalarS32True_module 976 977 ENTRY %SelectScalarS32True.v4 () -> s32[] { 978 %constant.2 = pred[] constant(true) 979 %constant.1 = s32[] constant(-42), sharding={s32[5,6] devices=[2,3]1,2,3,4} 980 %constant = s32[] constant(42) 981 %select = s32[] select(pred[] %constant.2, s32[] %constant.1, s32[] %constant) 982 } 983 984 )"; 985 auto result = Parse(original); 986 TF_EXPECT_OK(result.status()); 987 // Constant instructions have no name. The string will be parsed successfully 988 // but the constant names will not be exactly the same. 
989 } 990 991 TEST_F(HloParserTest, LiteralDimensionsMismatch_1) { 992 const string original = R"(HloModule some_2_module 993 994 ENTRY %some_2 () -> f32[2] { 995 ROOT %constant = f32[2]{0} constant({1,{2}}) 996 } 997 998 )"; 999 auto result = Parse(original); 1000 EXPECT_NE(tensorflow::Status::OK(), result.status()); 1001 ExpectHasSubstr(result.status().error_message(), 1002 "expects nested array in rank 1, but sees larger"); 1003 } 1004 1005 TEST_F(HloParserTest, LiteralDimensionsMismatch_2) { 1006 const string original = R"(HloModule some_2x3_module 1007 1008 ENTRY %some_2x3 () -> f32[2,3] { 1009 ROOT %constant = f32[2,3]{1,0} constant(f32[2,3] {1, 2, 3, 4, 5, 6}) 1010 } 1011 1012 )"; 1013 auto result = Parse(original); 1014 EXPECT_NE(tensorflow::Status::OK(), result.status()); 1015 ExpectHasSubstr(result.status().error_message(), 1016 "expects nested array in rank 2, but sees 1"); 1017 } 1018 1019 TEST_F(HloParserTest, LiteralDimensionsMismatch_3) { 1020 const string original = R"(HloModule some_2x3x2_module 1021 1022 ENTRY %some_2x3x2 () -> f32[2,3,2] { 1023 ROOT %constant = f32[2,3,2]{2,1,0} constant(f32[2,3,2] {{{1, 2}, {3, 4}, {5, 6}, {7, 8}, {9, 10}, {11, 12}}}) 1024 } 1025 1026 )"; 1027 auto result = Parse(original); 1028 EXPECT_NE(tensorflow::Status::OK(), result.status()); 1029 ExpectHasSubstr(result.status().error_message(), 1030 "expects 3 elements in the [0]th element"); 1031 } 1032 1033 TEST_F(HloParserTest, ConstantF16Overflow) { 1034 const string original = 1035 R"(HloModule ConstantF16Overflow_module 1036 1037 ENTRY %ConstantF16Overflow.v4 () -> f16[] { 1038 ROOT %constant = f16[] constant(-65505) 1039 } 1040 1041 )"; 1042 auto result = Parse(original); 1043 EXPECT_NE(tensorflow::Status::OK(), result.status()); 1044 ExpectHasSubstr(result.status().error_message(), 1045 "is out of range for literal's primitive type F16"); 1046 } 1047 1048 TEST_F(HloParserTest, ConstantWithExp) { 1049 const string original = R"(HloModule ConstantWithExp_module 1050 
1051 ENTRY %ConstantWithExp.v4 () -> f32[] { 1052 %constant.1 = f32[] constant(3e+2) 1053 } 1054 1055 )"; 1056 auto result = Parse(original); 1057 TF_EXPECT_OK(result.status()); 1058 // The string will be parsed successfully but the output strings are not 1059 // exactly the same, because "3e2" is parsed into value 300 and will be 1060 // printed as "300". 1061 } 1062 1063 TEST_F(HloParserTest, AttibutesAnyOrder) { 1064 const string original = R"(HloModule any_order_module 1065 1066 ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] { 1067 %input = f32[1,2,1]{2,1,0} parameter(0) 1068 %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input) 1069 %filter = f32[1,1,1]{2,1,0} parameter(1) 1070 ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), sharding={maximal device=1}, dim_labels=b0f_0io->b0f, window={pad=1_1 size=2} 1071 } 1072 1073 )"; 1074 TF_EXPECT_OK(Parse(original).status()); 1075 } 1076 1077 TEST_F(HloParserTest, InvalidDimLabels) { 1078 string prefix = R"(HloModule invalid_dim_labels_module 1079 1080 ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] { 1081 %input = f32[1,2,1]{2,1,0} parameter(0) 1082 %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input) 1083 %filter = f32[1,1,1]{2,1,0} parameter(1) 1084 ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1} )"; 1085 string suffix = R"( 1086 } 1087 1088 )"; 1089 1090 ExpectHasSubstr( 1091 Parse(tensorflow::strings::StrCat(prefix, ",dim_labels=00_01_10", suffix)) 1092 .status() 1093 .error_message(), 1094 "expects dim labels pattern"); 1095 1096 ExpectHasSubstr(Parse(tensorflow::strings::StrCat( 1097 prefix, ",dim_labels=010_1100->010", suffix)) 1098 .status() 1099 .error_message(), 1100 "must have the same rank"); 1101 } 1102 1103 TEST_F(HloParserTest, UnexpectedAttribute) { 1104 const string original = R"(HloModule 
unexpected_attr_module 1105 1106 ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { 1107 %recv = (f32[], u32[]) recv(), channel_id=15 1108 %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15 1109 ROOT %constant = f32[] constant(2.1) 1110 %send = (f32[], u32[]) send(f32[] %constant), channel_id=16, calls=%recv 1111 %send-done = () send-done((f32[], u32[]) %send), channel_id=16 1112 } 1113 1114 )"; 1115 ExpectHasSubstr(Parse(original).status().error_message(), 1116 "unexpected attribute calls"); 1117 } 1118 1119 TEST_F(HloParserTest, MissingAttribute) { 1120 const string original = R"(HloModule missing_attr_module 1121 1122 ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { 1123 %recv = (f32[], u32[]) recv(), channel_id=15 1124 %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15 1125 ROOT %constant = f32[] constant(-2.1) 1126 %send = (f32[], u32[]) send(f32[] %constant) 1127 %send-done = () send-done((f32[], u32[]) %send), channel_id=16 1128 } 1129 1130 )"; 1131 ExpectHasSubstr(Parse(original).status().error_message(), 1132 "attribute channel_id is expected but not seen"); 1133 } 1134 1135 TEST_F(HloParserTest, PredecessorUndefined) { 1136 const string original = R"(HloModule pre_not_found_module 1137 1138 ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { 1139 %recv = (f32[], u32[]) recv(), channel_id=15 1140 %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15 1141 ROOT %constant = f32[] constant(2.1) 1142 %send = (f32[], u32[]) send(f32[] %constant), channel_id=16, control-predecessors={%done} 1143 %send-done = () send-done((f32[], u32[]) %send), channel_id=16 1144 } 1145 1146 )"; 1147 ExpectHasSubstr(Parse(original).status().error_message(), 1148 "'done' is not defined"); 1149 } 1150 1151 TEST_F(HloParserTest, SliceAllowOmitStride1) { 1152 const string original = R"(HloModule slice_module 1153 1154 ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] { 1155 %p0 = f32[3,3,4,4]{3,2,1,0} parameter(0) 1156 ROOT %slice = 
f32[3,3,2,4]{3,2,1,0} slice(f32[3,3,4,4]{3,2,1,0} %p0), slice={[0:3], [0:3], [0:4:2], [0:4]} 1157 } 1158 1159 )"; 1160 TF_EXPECT_OK(Parse(original).status()); 1161 } 1162 1163 TEST_F(HloParserTest, PaddingConfigIsNotWindowPad) { 1164 const string original = R"(HloModule window_pad_module 1165 1166 ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] { 1167 %input = f32[1,2,1]{2,1,0} parameter(0) 1168 %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input) 1169 %filter = f32[1,1,1]{2,1,0} parameter(1) 1170 ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), dim_labels=b0f_0io->b0f, window={pad=1_1_0 size=1} 1171 } 1172 1173 )"; 1174 ExpectHasSubstr(Parse(original).status().error_message(), 1175 "expects padding_low and padding_high separated by '_'"); 1176 } 1177 1178 TEST_F(HloParserTest, CommaBetweenSubAttributes) { 1179 const string original = R"(HloModule test_comma_module 1180 1181 ENTRY %test_comma.v4 () -> f32[] { 1182 ROOT %constant = f32[] constant(-4.2), metadata={source_line=5, op_type="::const"} 1183 } 1184 1185 )"; 1186 TF_EXPECT_OK(Parse(original).status()); 1187 } 1188 1189 TEST_F(HloParserTest, ComputationShapeDoesNotMatchRootShape) { 1190 const string original = R"(HloModule custom_call: 1191 1192 ENTRY %CustomCall () -> f32[1] { 1193 %constant = f32[1]{0} constant({12345}) 1194 ROOT %foo = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar" 1195 })"; 1196 ExpectHasSubstr(Parse(original).status().error_message(), 1197 "Shape of computation CustomCall, f32[1], is not compatible " 1198 "with that of its root instruction foo, f32[1,2,3]"); 1199 } 1200 1201 TEST_F(HloParserTest, EntryComputationWithLayout) { 1202 const string original = R"(HloModule layout: 1203 add_F32.v3 { 1204 lhs = f32[] parameter(0) 1205 rhs = f32[] parameter(1) 1206 ROOT add = f32[] add(lhs, rhs) 1207 } 1208 1209 ENTRY %Reduce (input: f32[8,16,256]) -> f32[8,16] { 1210 
input = f32[8,16,256]{0,1,2} parameter(0) 1211 constant = f32[] constant(0) 1212 ROOT reduce = f32[8,16]{0,1} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3 1213 })"; 1214 1215 auto module = Parse(original); 1216 TF_ASSERT_OK(module.status()); 1217 auto program_layout = module.ValueOrDie()->entry_computation_layout(); 1218 ASSERT_EQ(program_layout.parameter_count(), 1); 1219 auto param_layout = program_layout.parameter_layout(0).layout(); 1220 auto result_layout = program_layout.result_layout().layout(); 1221 EXPECT_TRUE( 1222 LayoutUtil::Equal(LayoutUtil::MakeLayout({0, 1, 2}), param_layout)) 1223 << "actual layout of parameter(0) is " 1224 << LayoutUtil::HumanString(param_layout); 1225 EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({0, 1}), result_layout)) 1226 << "actual layout of result is " 1227 << LayoutUtil::HumanString(result_layout); 1228 } 1229 1230 TEST_F(HloParserTest, NoEntry) { 1231 const string original = R"(HloModule no_entry: 1232 c1 { 1233 const1 = f32[1]{0} constant({12345}) 1234 } 1235 c2 { 1236 const2 = f32[1]{0} constant({67890}) 1237 })"; 1238 auto module = Parse(original); 1239 TF_ASSERT_OK(module.status()); 1240 EXPECT_EQ(module.ValueOrDie()->entry_computation()->name(), "c2"); 1241 } 1242 1243 TEST_F(HloParserTest, NoRoot) { 1244 const string original = R"(HloModule no_root: 1245 ENTRY consts { 1246 first = f32[1]{0} constant({12345}) 1247 last = f32[1]{0} constant({67890}) 1248 })"; 1249 auto module = Parse(original); 1250 TF_ASSERT_OK(module.status()); 1251 EXPECT_EQ( 1252 module.ValueOrDie()->entry_computation()->root_instruction()->name(), 1253 "last"); 1254 } 1255 1256 TEST_F(HloParserTest, MultipleEntries) { 1257 const string original = R"(HloModule multiple_entries: 1258 ENTRY c1 { 1259 const1 = f32[1]{0} constant({12345}) 1260 } 1261 ENTRY c2 { 1262 const2 = f32[1]{0} constant({67890}) 1263 })"; 1264 ExpectHasSubstr(Parse(original).status().error_message(), 1265 "expects only one ENTRY"); 1266 } 1267 1268 
// Declaring two ROOT instructions in one computation is a parse error.
TEST_F(HloParserTest, MultipleRoots) {
  const string hlo_text = R"(HloModule multiple_roots:
ENTRY consts {
  ROOT const1 = f32[1]{0} constant({12345})
  ROOT const2 = f32[1]{0} constant({12345})
})";
  const auto parse_status = Parse(hlo_text).status();
  ExpectHasSubstr(parse_status.error_message(),
                  "one computation should have only one ROOT");
}

// Reusing an instruction name across computations is reported, and the error
// points back (line 3, column 3) at the first definition.
TEST_F(HloParserTest, InstructionExists) {
  const string hlo_text = R"(HloModule comp_exists
c1 {
  instr = f32[1]{0} constant({12345})
}
c2 {
  instr = f32[1]{0} constant({67890})
})";
  // The expected text is column-sensitive: keep the two-space indentation.
  const string expected_error =
      R"(was parsing 3:3: error: instruction previously defined here
  instr = f32[1]{0} constant({12345})
  ^)";

  const auto parse_status = Parse(hlo_text).status();
  ExpectHasSubstr(parse_status.error_message(), expected_error);
}

// Defining two computations with the same name is reported, and the error
// points back (line 2, column 1) at the first definition.
TEST_F(HloParserTest, ComputationExists) {
  const string hlo_text = R"(HloModule comp_exists
comp {
  const1 = f32[1]{0} constant({12345})
}
comp {
  const2 = f32[1]{0} constant({67890})
})";
  const string expected_error =
      R"(was parsing 2:1: error: computation previously defined here
comp {
^)";

  const auto parse_status = Parse(hlo_text).status();
  ExpectHasSubstr(parse_status.error_message(), expected_error);
}

}  // namespace
}  // namespace tools
}  // namespace xla