Home | History | Annotate | Download | only in parser
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
     17 
     18 #include <string>
     19 #include "tensorflow/core/lib/core/status_test_util.h"
     20 #include "tensorflow/core/lib/core/stringpiece.h"
     21 #include "tensorflow/core/platform/test.h"
     22 
     23 namespace xla {
     24 namespace tools {
     25 namespace {
     26 
     27 using tensorflow::StringPiece;
     28 
// A single round-trip test case: a human-readable name plus the HLO module
// text that is parsed and printed back.
struct TestData {
  // Used as the gtest parameter name (returned by TestDataToString).
  string test_name;
  // HLO module text; the tests expect printing the parsed module to
  // reproduce this string exactly.
  string module_string;
};
     33 
     34 string TestDataToString(const ::testing::TestParamInfo<TestData>& data) {
     35   return data.param.test_name;
     36 }
     37 
// Test cases for the full-format round-trip check (HloParserTest::ExpectEqual).
// For each string below, we check that:
//  - we parse it to an HloModule successfully, and
//  - the stringification of the resulting HloModule, printed with
//    HloPrintOptions().set_print_large_constants(true), is equal to our
//    original string.
// Each module string must therefore be written exactly as the printer emits
// it: layouts, operand shapes, '%'-prefixed names, attribute order, and the
// trailing blank line all matter.
std::vector<TestData> CreateTestCases() {
  // clang-format off
  return std::vector<TestData>({
// ax + y
{
"AxpyParam",
R"(HloModule axpy_module

ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
  %alpha = f32[] parameter(0)
  %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
  %x = f32[2,4]{1,0} parameter(1)
  %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
  %y = f32[2,4]{1,0} parameter(2)
  ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
}

)"
},
// pred constant, carrying metadata whose op_name contains escaped characters
{
"ConstantPred",
R"(HloModule constant_pred_module

ENTRY %constant_pred () -> pred[] {
  ROOT %constant = pred[] constant(true), metadata={op_type="const" op_name="\"it\'s not a problem\n" source_file="path/to/test.cc" source_line=68}
}

)"
},
// s32 constant
{
"ConstantS32",
R"(HloModule constant_s32_module

ENTRY %constant_s32 () -> s32[] {
  ROOT %constant = s32[] constant(-42)
}

)"
},
// f32 constant whose value is written as an integer (no decimal point)
{
"ConstantF32",
R"(HloModule ConstantF32_module

ENTRY %ConstantF32.v4 () -> f32[] {
  ROOT %constant = f32[] constant(42)
}

)"
},
// f32 constant, rank 1 empty array.
{
"ConstantF32R1Empty",
R"(HloModule ConstantF32Empty_module

ENTRY %ConstantF32Empty.v4 () -> f32[0] {
  ROOT %constant = f32[0]{0} constant({})
}

)"
},
// f32 constant, rank 4 empty array.
{
"ConstantF32R4Empty",
R"(HloModule ConstantF32R4Empty_module

ENTRY %ConstantF32R4Empty.v4 () -> f32[2,0,4,3] {
  ROOT %constant = f32[2,0,4,3]{3,2,1,0} constant(f32[2,0,4,3] { { /*i0=0*/ }, { /*i0=1*/ } })
}

)"
},
// constant 4D
{
"Constant4D",
R"(HloModule Small_3x2x1x1_module

ENTRY %Small_3x2x1x1.v1 () -> f32[3,2,1,1] {
  ROOT %constant = f32[3,2,1,1]{3,2,1,0} constant(f32[3,2,1,1] { { /*i0=0*/ { /*i1=0*/ {-1} }, { /*i1=1*/ {4.1} } }, { /*i0=1*/ { /*i1=0*/ {2} }, { /*i1=1*/ {4.1} } }, { /*i0=2*/ { /*i1=0*/ {5} }, { /*i1=1*/ {4.4} } } })
}

)"
},
// non-finite constants: nan, inf, -inf
{
"ConstantNonFinite",
R"(HloModule IsFiniteR1F32s_module

ENTRY %IsFiniteR1F32s.v2 () -> pred[6] {
  %constant = f32[6]{0} constant({nan, 7, nan, -1, inf, -inf})
  ROOT %is-finite = pred[6]{0} is-finite(f32[6]{0} %constant)
}

)"
},
// constant f16
{
"ConstantF16",
R"(HloModule ConstantF16_module

ENTRY %ConstantF16.v4 () -> f16[] {
  ROOT %constant = f16[] constant(500)
}

)"
},
// bf16
{
"BF16",
R"(HloModule BF16

ENTRY %BF16.v4 () -> bf16[] {
  ROOT %constant = bf16[] constant(500)
}

)"
},
// constant + constant
{
"AddConstants",
R"(HloModule add_constants_module

ENTRY %add_constants () -> f32[] {
  %constant = f32[] constant(3.14)
  ROOT %add = f32[] add(f32[] %constant, f32[] %constant)
}

)"
},
// tuple constant
{
"TupleConstant",
R"(HloModule TupleConstant_module

ENTRY %TupleConstant.v1 () -> (f32[2,1], f32[2]) {
  ROOT %constant = (f32[2,1]{1,0}, f32[2]{0}) constant((f32[2,1], f32[2]) ( f32[2,1] { { 1 }, { 2 } }, {2, 42} ))
}

)"
},
// v1 > v2 ? v1 : v2, also exercising maximal, replicated and empty shardings
{
"SelectR1F32",
R"(HloModule SelectR1F32WithCmpR1F32sFromParamsSmall_module

ENTRY %SelectR1F32WithCmpR1F32sFromParamsSmall.v4 (v1: f32[4], v2: f32[4]) -> f32[4] {
  %v1 = f32[4]{0} parameter(0), sharding={maximal device=1}
  %v2 = f32[4]{0} parameter(1), sharding={maximal device=1}
  %greater-than = pred[4]{0} greater-than(f32[4]{0} %v1, f32[4]{0} %v2), sharding={replicated}
  ROOT %select = f32[4]{0} select(pred[4]{0} %greater-than, f32[4]{0} %v1, f32[4]{0} %v2), sharding={}
}

)"
},
// empty tuple
{
"EmptyTupleCreate",
R"(HloModule EmptyTupleCreate_module

ENTRY %EmptyTupleCreate.v1 () -> () {
  ROOT %tuple = () tuple()
}

)"
},
// tuple
{
"TupleCreate",
R"(HloModule TupleCreate_module

ENTRY %TupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f32[2,3]) {
  %v1 = f32[] parameter(0)
  %v2 = f32[3]{0} parameter(1)
  %v3 = f32[2,3]{1,0} parameter(2)
  ROOT %tuple = (f32[], f32[3]{0}, f32[2,3]{1,0}) tuple(f32[] %v1, f32[3]{0} %v2, f32[2,3]{1,0} %v3)
}

)"
},
// tuple with a per-element (nested) sharding annotation
{
"ShardedTupleCreate",
R"(HloModule ShardedTupleCreate_module

ENTRY %ShardedTupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f32[2,3]) {
  %v1 = f32[] parameter(0)
  %v2 = f32[3]{0} parameter(1)
  %v3 = f32[2,3]{1,0} parameter(2)
  ROOT %tuple = (f32[], f32[3]{0}, f32[2,3]{1,0}) tuple(f32[] %v1, f32[3]{0} %v2, f32[2,3]{1,0} %v3), sharding={{replicated}, {maximal device=0}, {replicated}}
}

)"
},
// int32 result = 0;
// while (result < 5) { result = result + 1; }
{
"WhileWithScalarS32Result",
R"(HloModule WhileWithScalarS32Result_module

%body.v3 (prev.1: s32[]) -> s32[] {
  %constant = s32[] constant(1)
  %prev.1 = s32[] parameter(0)
  ROOT %add = s32[] add(s32[] %constant, s32[] %prev.1)
}

%condition.v3 (prev.2: s32[]) -> pred[] {
  %constant.1 = s32[] constant(5)
  %prev.2 = s32[] parameter(0)
  ROOT %greater-than = pred[] greater-than(s32[] %constant.1, s32[] %prev.2)
}

ENTRY %WhileWithScalarS32Result.v2 () -> s32[] {
  %constant.2 = s32[] constant(0)
  ROOT %while = s32[] while(s32[] %constant.2), condition=%condition.v3, body=%body.v3
}

)"
},
// send and recv
{
"SendRecv",
R"(HloModule TwoSendRecvBothWayRecvFist_module

ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
  %recv = (f32[], u32[]) recv(), channel_id=15, sharding={maximal device=1}
  ROOT %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15, sharding={maximal device=1}
  %constant = f32[] constant(2.1), sharding={maximal device=0}
  %send = (f32[], u32[]) send(f32[] %constant), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv}
  %send-done = () send-done((f32[], u32[]) %send), channel_id=16, sharding={maximal device=0}
}

)"
},
// get-tuple-element
{
"GetTupleElement",
R"(HloModule GetTupleElement_module

ENTRY %GetTupleElement.v4 () -> s32[2,3] {
  %constant = f32[3]{0} constant({1, 2, 3})
  %constant.1 = s32[2,3]{1,0} constant(s32[2,3] { { 1, 2, 3 }, { 4, 5, 6 } })
  %tuple = (f32[3]{0}, s32[2,3]{1,0}) tuple(f32[3]{0} %constant, s32[2,3]{1,0} %constant.1)
  ROOT %get-tuple-element = s32[2,3]{1,0} get-tuple-element((f32[3]{0}, s32[2,3]{1,0}) %tuple), index=1, sharding={maximal device=0}
}

)"
},
// call
{
"Call",
R"(HloModule CallR0F32IdentityScalar_module

%Identity.v1 (x: f32[]) -> f32[] {
  ROOT %x = f32[] parameter(0)
}

ENTRY %CallR0F32IdentityScalar.v2 () -> f32[] {
  %constant = f32[] constant(42)
  ROOT %call = f32[] call(f32[] %constant), to_apply=%Identity.v1
}

)"
},
// reduce window
{
"ReduceWindow",
R"(HloModule R4UnitWindow_module

%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] {
  %lhs = f32[] parameter(0)
  %rhs = f32[] parameter(1)
  ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
}

ENTRY %R4UnitWindow.v3 (operand: f32[13,12,8,15]) -> f32[13,3,8,15] {
  %operand = f32[13,12,8,15]{0,3,2,1} parameter(0)
  %constant = f32[] constant(0)
  ROOT %reduce-window = f32[13,3,8,15]{0,3,2,1} reduce-window(f32[13,12,8,15]{0,3,2,1} %operand, f32[] %constant), window={size=1x1x7x1 stride=1x4x1x1 pad=0_0x0_0x3_3x0_0}, to_apply=%add_F32.v3
}

)"
},
// reduce window on scalar
{
"ReduceWindowScalar",
R"(HloModule reduce_window_scalar

%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] {
  %lhs = f32[] parameter(0)
  %rhs = f32[] parameter(1)
  ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
}

ENTRY %R4UnitWindowScalar () -> f32[] {
  %constant = f32[] constant(42)
  %constant.1 = f32[] constant(1)
  ROOT %reduce-window = f32[] reduce-window(f32[] %constant, f32[] %constant.1), to_apply=%add_F32.v3
}

)"
},
// convolution
{
"Convolution",
R"(HloModule Convolve1D1Window_0_module

ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] {
  %input = f32[1,2,1]{2,1,0} parameter(0)
  %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
  %filter = f32[1,1,1]{2,1,0} parameter(1)
  ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f
}

)"
},
// convolution rank 2
{
"ConvolutionR2",
R"(HloModule ConvolveR2_module

ENTRY %ConvolveR2.v3 (input: f32[1,2], filter: f32[1,1]) -> f32[1,2] {
  %input = f32[1,2]{1,0} parameter(0)
  %filter = f32[1,1]{1,0} parameter(1)
  ROOT %convolution = f32[1,2]{0,1} convolution(f32[1,2]{1,0} %input, f32[1,1]{1,0} %filter), dim_labels=bf_io->bf
}

)"
},
// convolution backward
{
"ConvolutionBackward",
R"(HloModule ConvolveBackward_module

ENTRY %ConvolveBackward (input: f32[128,7,7,512], filter: f32[3,3,512,512]) -> f32[128,14,14,512] {
  %input = f32[128,7,7,512]{0,3,2,1} parameter(0)
  %filter = f32[3,3,512,512]{3,2,1,0} parameter(1)
  ROOT %convolution-base-dilated = f32[128,14,14,512]{0,3,2,1} convolution(f32[128,7,7,512]{0,3,2,1} %input, f32[3,3,512,512]{3,2,1,0} %filter), window={size=3x3 pad=1_2x1_2 lhs_dilate=2x2 rhs_reversal=1x1}, dim_labels=b01f_01oi->b01f
}

)"
},
// reverse(constant)
{
"Reverse4D",
R"(HloModule Reverse4DFloatArrayOnDim01_module

ENTRY %Reverse4DFloatArrayOnDim01.v2 () -> f32[4,3,2,1] {
  %constant = f32[4,3,2,1]{0,1,2,3} constant(f32[4,3,2,1] { { /*i0=0*/ { /*i1=0*/ {1}, {2} }, { /*i1=1*/ {3}, {4} }, { /*i1=2*/ {5}, {6} } }, { /*i0=1*/ { /*i1=0*/ {7}, {8} }, { /*i1=1*/ {9}, {10} }, { /*i1=2*/ {11}, {12} } }, { /*i0=2*/ { /*i1=0*/ {13}, {14} }, { /*i1=1*/ {15}, {16} }, { /*i1=2*/ {17}, {18} } }, { /*i0=3*/ { /*i1=0*/ {19}, {20} }, { /*i1=1*/ {21}, {22} }, { /*i1=2*/ {23}, {24} } } })
  ROOT %reverse = f32[4,3,2,1]{0,1,2,3} reverse(f32[4,3,2,1]{0,1,2,3} %constant), dimensions={0,1}
}

)"
},
// concat
{
"Concat",
R"(HloModule Concat2x3With2x5_module

ENTRY %Concat2x3With2x5.v3 () -> f32[2,8] {
  %constant = f32[2,3]{1,0} constant(f32[2,3] { { 0, 1, 2 }, { 1000, 1001, 1002 } })
  %constant.1 = f32[2,5]{1,0} constant(f32[2,5] { { 64, 65, 66, 67, 68 }, { 1064, 1065, 1066, 1067, 1068 } })
  ROOT %concatenate = f32[2,8]{1,0} concatenate(f32[2,3]{1,0} %constant, f32[2,5]{1,0} %constant.1), dimensions={1}
}

)"
},
// select and scatter
{
"SelectAndScatter",
R"(HloModule R4F32OverlapSmall_module

%ge_F32.v3 (lhs: f32[], rhs: f32[]) -> pred[] {
  %lhs = f32[] parameter(0)
  %rhs = f32[] parameter(1)
  ROOT %greater-than-or-equal-to = pred[] greater-than-or-equal-to(f32[] %lhs, f32[] %rhs)
}

%add_F32.v3 (lhs.1: f32[], rhs.1: f32[]) -> f32[] {
  %lhs.1 = f32[] parameter(0)
  %rhs.1 = f32[] parameter(1)
  ROOT %add = f32[] add(f32[] %lhs.1, f32[] %rhs.1)
}

ENTRY %R4F32OverlapSmall.v4 () -> f32[4,5,1,1] {
  %constant = f32[4,5,1,1]{3,2,1,0} constant(f32[4,5,1,1] { { /*i0=0*/ { /*i1=0*/ {7} }, { /*i1=1*/ {2} }, { /*i1=2*/ {5} }, { /*i1=3*/ {3} }, { /*i1=4*/ {8} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {8} }, { /*i1=2*/ {9} }, { /*i1=3*/ {3} }, { /*i1=4*/ {4} } }, { /*i0=2*/ { /*i1=0*/ {1} }, { /*i1=1*/ {5} }, { /*i1=2*/ {7} }, { /*i1=3*/ {5} }, { /*i1=4*/ {6} } }, { /*i0=3*/ { /*i1=0*/ {0} }, { /*i1=1*/ {6} }, { /*i1=2*/ {2} }, { /*i1=3*/ {10} }, { /*i1=4*/ {2} } } })
  %constant.1 = f32[2,2,1,1]{3,2,1,0} constant(f32[2,2,1,1] { { /*i0=0*/ { /*i1=0*/ {2} }, { /*i1=1*/ {6} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {1} } } })
  %constant.2 = f32[] constant(0)
  ROOT %select-and-scatter = f32[4,5,1,1]{3,2,1,0} select-and-scatter(f32[4,5,1,1]{3,2,1,0} %constant, f32[2,2,1,1]{3,2,1,0} %constant.1, f32[] %constant.2), window={size=2x3x1x1 stride=2x2x1x1}, select=%ge_F32.v3, scatter=%add_F32.v3
}

)"
},
// select and scatter on scalar
{
"SelectAndScatterScalar",
R"(HloModule select_and_scatter_scalar

%ge_F32.v3 (lhs: f32[], rhs: f32[]) -> pred[] {
  %lhs = f32[] parameter(0)
  %rhs = f32[] parameter(1)
  ROOT %greater-than-or-equal-to = pred[] greater-than-or-equal-to(f32[] %lhs, f32[] %rhs)
}

%add_F32.v3 (lhs.1: f32[], rhs.1: f32[]) -> f32[] {
  %lhs.1 = f32[] parameter(0)
  %rhs.1 = f32[] parameter(1)
  ROOT %add = f32[] add(f32[] %lhs.1, f32[] %rhs.1)
}

ENTRY %SelectAndScatterScalar () -> f32[] {
  %constant = f32[] constant(42)
  %constant.1 = f32[] constant(1)
  %constant.2 = f32[] constant(2)
  ROOT %select-and-scatter = f32[] select-and-scatter(f32[] %constant, f32[] %constant.1, f32[] %constant.2), select=%ge_F32.v3, scatter=%add_F32.v3
}

)"
},
// slice
{
"Slice",
R"(HloModule slice_module

ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] {
  %p0 = f32[3,3,4,4]{3,2,1,0} parameter(0)
  ROOT %slice = f32[3,3,2,4]{3,2,1,0} slice(f32[3,3,4,4]{3,2,1,0} %p0), slice={[0:3:1], [0:3:1], [0:4:2], [0:4:1]}
}

)"
},
// slice, no stride
{
"SliceNoStride",
R"(HloModule Slice3x3x3_To_1x3x3_F32_module

ENTRY %Slice3x3x3_To_1x3x3_F32.v2 () -> f32[1,3,3] {
  %constant = f32[3,3,3]{2,1,0} constant(f32[3,3,3] { { { 0, 1, 2 }, { 3, 4, 5 }, { 6, 7, 8 } }, { { 9, 10, 11 }, { 12, 13, 14 }, { 15, 16, 17 } }, { { 18, 19, 20 }, { 21, 22, 23 }, { 24, 25, 26 } } })
  ROOT %slice = f32[1,3,3]{2,1,0} slice(f32[3,3,3]{2,1,0} %constant), slice={[0:1], [0:3], [0:3]}
}

)"
},
// slice R0
{
"SliceR0",
R"(HloModule SliceR0_module

ENTRY %SliceR0.v2 () -> s32[] {
  %constant = s32[] constant(1)
  ROOT %slice = s32[] slice(s32[] %constant), slice={}
}

)"
},
// transpose
{
"Transpose",
R"(HloModule Transpose_module

ENTRY %Transpose.v2 () -> s32[1,2,3] {
  %constant = s32[1,2,3]{2,1,0} constant(s32[1,2,3] { { { 1, 2, 3 }, { 4, 5, 6 } } })
  ROOT %transpose = s32[1,2,3]{2,1,0} transpose(s32[1,2,3]{2,1,0} %constant), dimensions={0,1,2}
}

)"
},
// Dynamic slice
{
"DynamicSlice",
R"(HloModule DynamicSlice_module

ENTRY %DynamicSlice.v5 (original_parameter: s32[2,2,258], start_index: s32[1]) -> s32[2,2,258] {
  %original_parameter = s32[2,2,258]{2,1,0} parameter(0)
  %constant = s32[1]{0} constant({0})
  %start_index = s32[1]{0} parameter(1)
  %concatenate = s32[3]{0} concatenate(s32[1]{0} %constant, s32[1]{0} %constant, s32[1]{0} %start_index), dimensions={0}
  ROOT %dynamic-slice = s32[2,2,258]{2,1,0} dynamic-slice(s32[2,2,258]{2,1,0} %original_parameter, s32[3]{0} %concatenate), dynamic_slice_sizes={2,2,258}
}

)"
},
// Dynamic update slice
{
"DynamicUpdateSlice",
R"(HloModule DynamicUpdateSlice_module

ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_indices: s32[4]) -> s32[1,1,25,1] {
  %input = s32[1,1,25,1]{3,2,1,0} parameter(0)
  %update = s32[1,1,2,1]{3,2,1,0} parameter(1)
  %start_indices = s32[4]{0} parameter(2)
  ROOT %dynamic-update-slice = s32[1,1,25,1]{3,2,1,0} dynamic-update-slice(s32[1,1,25,1]{3,2,1,0} %input, s32[1,1,2,1]{3,2,1,0} %update, s32[4]{0} %start_indices)
}

)"
},
// batch norm training
{
"BatchNormTraining",
R"(HloModule BasicTraining_module

ENTRY %BasicTraining.v4 () -> (f32[2,2,1,2], f32[2], f32[2]) {
  %constant = f32[2,2,1,2]{3,2,1,0} constant(f32[2,2,1,2] { { /*i0=0*/ { /*i1=0*/ {1, 2} }, { /*i1=1*/ {3, 4} } }, { /*i0=1*/ { /*i1=0*/ {5, 6} }, { /*i1=1*/ {7, 8} } } })
  %constant.1 = f32[2]{0} constant({2, 3})
  %constant.2 = f32[2]{0} constant({1, 2})
  ROOT %batch-norm-training = (f32[2,2,1,2]{3,2,1,0}, f32[2]{0}, f32[2]{0}) batch-norm-training(f32[2,2,1,2]{3,2,1,0} %constant, f32[2]{0} %constant.1, f32[2]{0} %constant.2), epsilon=0.001, feature_index=3
}

)"
},
// batch norm inference
{
"BatchNormInference",
R"(HloModule BatchNormInference_module

ENTRY %BatchNormInference.v6 (input: f32[2,2,2,2], offset: f32[2], scale: f32[2], mean: f32[2], variance: f32[2]) -> f32[2,2,2,2] {
  %input = f32[2,2,2,2]{3,2,1,0} parameter(0)
  %offset = f32[2]{0} parameter(1)
  %scale = f32[2]{0} parameter(2)
  %mean = f32[2]{0} parameter(3)
  %variance = f32[2]{0} parameter(4)
  ROOT %batch-norm-inference = f32[2,2,2,2]{3,2,1,0} batch-norm-inference(f32[2,2,2,2]{3,2,1,0} %input, f32[2]{0} %offset, f32[2]{0} %scale, f32[2]{0} %mean, f32[2]{0} %variance), epsilon=0.001, feature_index=0
}

)"
},
// batch norm grad
{
"BatchNormGrad",
R"(HloModule BatchNormGrad_module

ENTRY %BatchNormGrad.v4 (input: f32[2,2,2,2], scale: f32[2], mean: f32[2], variance: f32[2], grad_output: f32[2,2,2,2]) -> (f32[2,2,2,2], f32[2], f32[2]) {
  %input = f32[2,2,2,2]{3,2,1,0} parameter(0)
  %scale = f32[2]{0} parameter(1)
  %mean = f32[2]{0} parameter(2)
  %variance = f32[2]{0} parameter(3)
  %grad_output = f32[2,2,2,2]{3,2,1,0} parameter(4)
  ROOT %batch-norm-grad = (f32[2,2,2,2]{3,2,1,0}, f32[2]{0}, f32[2]{0}) batch-norm-grad(f32[2,2,2,2]{3,2,1,0} %input, f32[2]{0} %scale, f32[2]{0} %mean, f32[2]{0} %variance, f32[2,2,2,2]{3,2,1,0} %grad_output), epsilon=0.001, feature_index=0
}

)"
},
// fft
{
"Fft",
R"(HloModule Fft_module

ENTRY %Fft (input: c64[8,32]) -> c64[8,32] {
  %input = c64[8,32]{1,0} parameter(0)
  ROOT %fft = c64[8,32]{1,0} fft(c64[8,32]{1,0} %input), fft_type=FFT, fft_length={32}
}

)"
},
// ifft
{
"Ifft2d",
R"(HloModule Ifft2d_module

ENTRY %Ifft2d (input: c64[5,8,32]) -> c64[5,8,32] {
  %input = c64[5,8,32]{2,1,0} parameter(0)
  ROOT %fft = c64[5,8,32]{2,1,0} fft(c64[5,8,32]{2,1,0} %input), fft_type=IFFT, fft_length={8,32}
}

)"
},
// rfft2d
{
"Rfft2d",
R"(HloModule Rfft2d_module

ENTRY %Rfft2d (input: f32[5,64,32]) -> c64[5,64,17] {
  %input = f32[5,64,32]{2,1,0} parameter(0)
  ROOT %fft = c64[5,64,17]{2,1,0} fft(f32[5,64,32]{2,1,0} %input), fft_type=RFFT, fft_length={64,32}
}

)"
},
// irfft3d
{
"Irfft3d",
R"(HloModule Irfft3d_module

ENTRY %Irfft3d (input: c64[5,64,128,33]) -> f32[5,64,128,64] {
  %input = c64[5,64,128,33]{3,2,1,0} parameter(0)
  ROOT %fft = f32[5,64,128,64]{3,2,1,0} fft(c64[5,64,128,33]{3,2,1,0} %input), fft_type=IRFFT, fft_length={64,128,64}
}

)"
},
// pad
{
"Pad",
R"(HloModule Pad1DS3Array_module

ENTRY %Pad1DS3Array.v3 () -> f32[8] {
  %constant = f32[3]{0} constant({1, 2, 3})
  %constant.1 = f32[] constant(0.1)
  ROOT %pad = f32[8]{0} pad(f32[3]{0} %constant, f32[] %constant.1), padding=3_1
}

)"
},
// pad has interior
{
"PadHasInterior",
R"(HloModule PadHasInterior_module

ENTRY %PadHasInterior.v3 (input: f32[1,25,7,7]) -> f32[1,25,17,11] {
  %input = f32[1,25,7,7]{3,2,1,0} parameter(0)
  %constant = f32[] constant(-5.123)
  ROOT %pad = f32[1,25,17,11]{3,2,1,0} pad(f32[1,25,7,7]{3,2,1,0} %input, f32[] %constant), padding=0_0_0x0_0_0x2_2_1x2_2_0
}

)"
},
// Negative padding
{
"PadHasNegativePadding",
R"(HloModule PadHasNegativePadding_module

ENTRY %PadHasNegativePadding (input: f32[1,25,7,7,10]) -> f32[1,15,6,3,29] {
  %input = f32[1,25,7,7,10]{4,3,2,1,0} parameter(0)
  %constant = f32[] constant(-5.123)
  ROOT %pad = f32[1,15,6,3,29]{4,3,2,1,0} pad(f32[1,25,7,7,10]{4,3,2,1,0} %input, f32[] %constant), padding=0_0_0x0_-10_0x0_-1_0x-2_-2_0x-1_-1_3
}

)"
},
// fusion
{
"Fusion",
R"(HloModule fusion_module

%fused_computation (constant.param_0: f32[3,2,1,1], constant.1.param_1: f32[2]) -> f32[3,2,1,1] {
  %constant.param_0 = f32[3,2,1,1]{3,2,1,0} parameter(0)
  %constant.1.param_1 = f32[2]{0} parameter(1)
  %broadcast = f32[3,2,1,1]{3,2,1,0} broadcast(f32[2]{0} %constant.1.param_1), dimensions={1}
  ROOT %subtract = f32[3,2,1,1]{3,2,1,0} subtract(f32[3,2,1,1]{3,2,1,0} %constant.param_0, f32[3,2,1,1]{3,2,1,0} %broadcast)
}

ENTRY %fusion.v3 () -> f32[3,2,1,1] {
  %constant = f32[3,2,1,1]{3,2,1,0} constant(f32[3,2,1,1] { { /*i0=0*/ { /*i1=0*/ {-1} }, { /*i1=1*/ {4.1} } }, { /*i0=1*/ { /*i1=0*/ {2} }, { /*i1=1*/ {4.1} } }, { /*i0=2*/ { /*i1=0*/ {5} }, { /*i1=1*/ {4.4} } } })
  %constant.1 = f32[2]{0} constant({3.14, 4.25})
  ROOT %fusion = f32[3,2,1,1]{3,2,1,0} fusion(f32[3,2,1,1]{3,2,1,0} %constant, f32[2]{0} %constant.1), kind=kLoop, calls=%fused_computation
}

)"
},
// rank-3 constant with a sparse{10} layout and three explicit index: value entries
{
"Sparse",
R"(HloModule sparse_f32

ENTRY %sparse () -> f32[2,3,4] {
  ROOT %foo = f32[2,3,4]sparse{10} constant(f32[2,3,4]{[0, 1, 2]: 1, [1, 2, 3]: 2, [2, 3, 4]: 3})
}

)"
},
// sparse-layout constant with no entries
{
"SparseEmpty",
R"(HloModule sparse_f32_empty

ENTRY %sparse_f32_empty () -> f32[2,3,4] {
  ROOT %foo = f32[2,3,4]sparse{10} constant(f32[2,3,4]{})
}

)"
},
// rank-1 sparse-layout constant; indices are written as plain integers
{
"SparseR1",
R"(HloModule sparse_f32_r1

ENTRY %sparse_f32_r1 () -> f32[9] {
  ROOT %foo = f32[9]sparse{10} constant(f32[9]{1: 2, 3: 4, 5: 6})
}

)"
},
  });
  // clang-format on
}
    724 
    725 std::vector<TestData> CreateShortTestCases() {
    726   // clang-format off
    727   return std::vector<TestData>({
    728 // map
    729 {
    730 "Map",
    731 R"(HloModule MapBinaryAdder_module
    732 
    733 add_F32.v3 {
    734   lhs = f32[] parameter(0)
    735   rhs = f32[] parameter(1)
    736   ROOT add = f32[] add(lhs, rhs)
    737 }
    738 
    739 ENTRY MapBinaryAdder.v3 {
    740   param0 = f32[4]{0} parameter(0)
    741   param1 = f32[4]{0} parameter(1)
    742   ROOT map = f32[4]{0} map(param0, param1), to_apply=add_F32.v3
    743 }
    744 
    745 )"
    746 },
    747 // reduce
    748 {
    749 "Reduce",
    750 R"(HloModule ReduceR3ToR2_module
    751 
    752 add_F32.v3 {
    753   lhs = f32[] parameter(0)
    754   rhs = f32[] parameter(1)
    755   ROOT add = f32[] add(lhs, rhs)
    756 }
    757 
    758 ENTRY ReduceR3ToR2.v3 {
    759   input = f32[8,16,256]{2,1,0} parameter(0)
    760   constant = f32[] constant(0)
    761   ROOT reduce = f32[8,16]{1,0} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3
    762 }
    763 
    764 )"
    765 },
    766 // infeed/outfeed
    767 {
    768 "InfeedOutfeed",
    769 R"(HloModule outfeed_module
    770 
    771 ENTRY InfeedToOutfeed {
    772   infeed = (u32[3]{0}, pred[]) infeed()
    773   outfeed = () outfeed(infeed)
    774   ROOT infeed.1 = (u32[3]{0}, pred[]) infeed()
    775   outfeed.1 = () outfeed(infeed.1)
    776 }
    777 
    778 )"
    779 },
    780 // Rng
    781 {
    782 "Rng",
    783 R"(HloModule rng_module
    784 
    785 ENTRY Rng {
    786   constant = f32[] constant(0)
    787   constant.1 = f32[] constant(1)
    788   ROOT rng = f32[8]{0} rng(constant, constant.1), distribution=rng_uniform
    789 }
    790 
    791 )"
    792 },
    793 // Reduce precision
    794 {
    795 "ReducePrevison",
    796 R"(HloModule reduce_precision
    797 
    798 ENTRY ReducePrecision {
    799   constant = f32[1]{0} constant({3.14159})
    800   ROOT reduce-precision = f32[1]{0} reduce-precision(constant), exponent_bits=8, mantissa_bits=10
    801 }
    802 
    803 )"
    804 },
    805 // Conditional
    806 {
    807 "Conditional",
    808 R"(HloModule conditional
    809 
    810 Negate {
    811   x = f32[] parameter(0)
    812   ROOT negate = f32[] negate(x)
    813 }
    814 
    815 Identity {
    816   y = f32[] parameter(0)
    817   ROOT copy = f32[] copy(y)
    818 }
    819 
    820 ENTRY Parameters1.v4 {
    821   constant = pred[] constant(true)
    822   constant.1 = f32[] constant(56)
    823   constant.2 = f32[] constant(12)
    824   ROOT conditional = f32[] conditional(constant, constant.1, constant.2), true_computation=Negate, false_computation=Identity
    825 }
    826 
    827 )"
    828 },
    829 // CustomCall
    830 {
    831 "CustomCall",
    832 R"(HloModule custom_call
    833 
    834 ENTRY CustomCall {
    835   constant = f32[1]{0} constant({12345})
    836   ROOT custom-call = f32[1,2,3]{0,2,1} custom-call(constant), custom_call_target="foo\"bar"
    837 }
    838 
    839 )"
    840 },
    841 // Variables with non-default names
    842 {
    843 "NonDefaultNames",
    844 R"(HloModule add_constants_module
    845 
    846 ENTRY add_constants {
    847   foo = f32[] constant(3.14)
    848   ROOT bar = f32[] add(foo, foo)
    849 }
    850 
    851 )"
    852 },
    853 {
    854 "Dot",
    855 R"(HloModule dot
    856 
    857 ENTRY dot {
    858   a = f32[2,10]{1,0} parameter(0)
    859   b = f32[10,3]{1,0} parameter(1)
    860   ROOT dot = f32[2,3]{1,0} dot(a, b), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_contracting_dims={0}
    861 }
    862 
    863 )"
    864 },
    865   });
    866   // clang-format on
    867 }
    868 
    869 class HloParserTest : public ::testing::Test,
    870                       public ::testing::WithParamInterface<TestData> {
    871  protected:
    872   static void ExpectHasSubstr(StringPiece s, StringPiece expected) {
    873     EXPECT_TRUE(StringPiece(s).contains(expected))
    874         << "'" << s << "' does not contain '" << expected << "'";
    875   }
    876 
    877   // Expects "ToString(Parse(string)) == string", that is, parses the string,
    878   // asserts that it succeeded, stringifies the parsed module, and checks that
    879   // the it equals the original string.
    880   void ExpectEqual() {
    881     const string& original = GetParam().module_string;
    882     auto result = Parse(original);
    883     TF_ASSERT_OK(result.status());
    884     EXPECT_EQ(original, result.ValueOrDie()->ToString(
    885                             HloPrintOptions().set_print_large_constants(true)));
    886   }
    887 };
    888 
    889 class HloParserShortTest : public HloParserTest {
    890  protected:
    891   void ExpectEqualShort() {
    892     const string& original = GetParam().module_string;
    893     auto result = Parse(original);
    894     TF_ASSERT_OK(result.status());
    895     EXPECT_EQ(original,
    896               result.ValueOrDie()->ToString(HloPrintOptions::ShortParsable()));
    897   }
    898 };
    899 
    900 TEST_P(HloParserTest, Run) { ExpectEqual(); }
    901 
    902 TEST_P(HloParserShortTest, Run) { ExpectEqualShort(); }
    903 
    904 INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserTest,
    905                         ::testing::ValuesIn(CreateTestCases()),
    906                         TestDataToString);
    907 
    908 INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserShortTest,
    909                         ::testing::ValuesIn(CreateShortTestCases()),
    910                         TestDataToString);
    911 
    912 TEST_F(HloParserTest, Empty) {
    913   const string original = "";
    914   auto result = Parse(original);
    915   EXPECT_NE(tensorflow::Status::OK(), result.status());
    916 }
    917 
    918 TEST_F(HloParserTest, Garbage) {
    919   const string original = "HloModule thi$ str1ng makes# N0 sen$e @all!*&^%$";
    920   auto result = Parse(original);
    921   EXPECT_NE(tensorflow::Status::OK(), result.status());
    922 }
    923 
    924 TEST_F(HloParserTest, WrongOpcode) {
    925   const string original = R"(HloModule wrong_opcode:
    926 
    927 ENTRY %blabla (x: f32[], y: f32[]) -> f32[] {
    928   %x = f32[]{} parameter(0)
    929   %y = f32[]{} parameter(1)
    930   %le = pred[]{} le(f32[]{} %x, f32[]{} %y)
    931 }
    932 
    933 )";
    934   auto result = Parse(original);
    935   EXPECT_NE(tensorflow::Status::OK(), result.status());
    936 }
    937 
    938 TEST_F(HloParserTest, WrongShape) {
    939   const string original = R"(HloModule wrong_opcode:
    940 
    941 ENTRY %blabla (x: g32[]) -> g32[] {
    942   %x = g32[]{} parameter(0)
    943 }
    944 
    945 )";
    946   auto result = Parse(original);
    947   EXPECT_NE(tensorflow::Status::OK(), result.status());
    948 }
    949 
    950 TEST_F(HloParserTest, WrongOperandsSize) {
    951   const string original = R"(HloModule wrong_opcode:
    952 
    953 ENTRY %blabla (x: f32[]) -> pred[] {
    954   %x = f32[]{} parameter(0)
    955   %eq = pred[]{} equal-to(f32[]{} %x)
    956 }
    957 
    958 )";
    959   auto result = Parse(original);
    960   EXPECT_NE(tensorflow::Status::OK(), result.status());
    961 }
    962 
    963 TEST_F(HloParserTest, OperandNotFound) {
    964   const string original = R"(HloModule operand_not_found:
    965 ENTRY %blabla (x: f32[]) -> pred[] {
    966   %x = f32[]{} parameter(0)
    967   %eq = pred[]{} equal-to(f32[]{} %x, f32[]{} %y)
    968 }
    969 )";
    970   auto result = Parse(original);
    971   EXPECT_NE(tensorflow::Status::OK(), result.status());
    972 }
    973 
    974 TEST_F(HloParserTest, MoreConstants) {
    975   const string original = R"(HloModule SelectScalarS32True_module
    976 
    977 ENTRY %SelectScalarS32True.v4 () -> s32[] {
    978   %constant.2 = pred[] constant(true)
    979   %constant.1 = s32[] constant(-42), sharding={s32[5,6] devices=[2,3]1,2,3,4}
    980   %constant = s32[] constant(42)
    981   %select = s32[] select(pred[] %constant.2, s32[] %constant.1, s32[] %constant)
    982 }
    983 
    984 )";
    985   auto result = Parse(original);
    986   TF_EXPECT_OK(result.status());
    987   // Constant instructions have no name. The string will be parsed successfully
    988   // but the constant names will not be exactly the same.
    989 }
    990 
    991 TEST_F(HloParserTest, LiteralDimensionsMismatch_1) {
    992   const string original = R"(HloModule some_2_module
    993 
    994 ENTRY %some_2 () -> f32[2] {
    995   ROOT %constant = f32[2]{0} constant({1,{2}})
    996 }
    997 
    998 )";
    999   auto result = Parse(original);
   1000   EXPECT_NE(tensorflow::Status::OK(), result.status());
   1001   ExpectHasSubstr(result.status().error_message(),
   1002                   "expects nested array in rank 1, but sees larger");
   1003 }
   1004 
   1005 TEST_F(HloParserTest, LiteralDimensionsMismatch_2) {
   1006   const string original = R"(HloModule some_2x3_module
   1007 
   1008 ENTRY %some_2x3 () -> f32[2,3] {
   1009   ROOT %constant = f32[2,3]{1,0} constant(f32[2,3] {1, 2, 3, 4, 5, 6})
   1010 }
   1011 
   1012 )";
   1013   auto result = Parse(original);
   1014   EXPECT_NE(tensorflow::Status::OK(), result.status());
   1015   ExpectHasSubstr(result.status().error_message(),
   1016                   "expects nested array in rank 2, but sees 1");
   1017 }
   1018 
   1019 TEST_F(HloParserTest, LiteralDimensionsMismatch_3) {
   1020   const string original = R"(HloModule some_2x3x2_module
   1021 
   1022 ENTRY %some_2x3x2 () -> f32[2,3,2] {
   1023   ROOT %constant = f32[2,3,2]{2,1,0} constant(f32[2,3,2] {{{1, 2}, {3, 4}, {5, 6}, {7, 8}, {9, 10}, {11, 12}}})
   1024 }
   1025 
   1026 )";
   1027   auto result = Parse(original);
   1028   EXPECT_NE(tensorflow::Status::OK(), result.status());
   1029   ExpectHasSubstr(result.status().error_message(),
   1030                   "expects 3 elements in the [0]th element");
   1031 }
   1032 
   1033 TEST_F(HloParserTest, ConstantF16Overflow) {
   1034   const string original =
   1035       R"(HloModule ConstantF16Overflow_module
   1036 
   1037 ENTRY %ConstantF16Overflow.v4 () -> f16[] {
   1038   ROOT %constant = f16[] constant(-65505)
   1039 }
   1040 
   1041 )";
   1042   auto result = Parse(original);
   1043   EXPECT_NE(tensorflow::Status::OK(), result.status());
   1044   ExpectHasSubstr(result.status().error_message(),
   1045                   "is out of range for literal's primitive type F16");
   1046 }
   1047 
   1048 TEST_F(HloParserTest, ConstantWithExp) {
   1049   const string original = R"(HloModule ConstantWithExp_module
   1050 
   1051 ENTRY %ConstantWithExp.v4 () -> f32[] {
   1052   %constant.1 = f32[] constant(3e+2)
   1053 }
   1054 
   1055 )";
   1056   auto result = Parse(original);
   1057   TF_EXPECT_OK(result.status());
   1058   // The string will be parsed successfully but the output strings are not
   1059   // exactly the same, because "3e2" is parsed into value 300 and will be
   1060   // printed as "300".
   1061 }
   1062 
   1063 TEST_F(HloParserTest, AttibutesAnyOrder) {
   1064   const string original = R"(HloModule any_order_module
   1065 
   1066 ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] {
   1067   %input = f32[1,2,1]{2,1,0} parameter(0)
   1068   %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
   1069   %filter = f32[1,1,1]{2,1,0} parameter(1)
   1070   ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), sharding={maximal device=1}, dim_labels=b0f_0io->b0f, window={pad=1_1 size=2}
   1071 }
   1072 
   1073 )";
   1074   TF_EXPECT_OK(Parse(original).status());
   1075 }
   1076 
   1077 TEST_F(HloParserTest, InvalidDimLabels) {
   1078   string prefix = R"(HloModule invalid_dim_labels_module
   1079 
   1080 ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] {
   1081   %input = f32[1,2,1]{2,1,0} parameter(0)
   1082   %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
   1083   %filter = f32[1,1,1]{2,1,0} parameter(1)
   1084   ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1} )";
   1085   string suffix = R"(
   1086 }
   1087 
   1088 )";
   1089 
   1090   ExpectHasSubstr(
   1091       Parse(tensorflow::strings::StrCat(prefix, ",dim_labels=00_01_10", suffix))
   1092           .status()
   1093           .error_message(),
   1094       "expects dim labels pattern");
   1095 
   1096   ExpectHasSubstr(Parse(tensorflow::strings::StrCat(
   1097                             prefix, ",dim_labels=010_1100->010", suffix))
   1098                       .status()
   1099                       .error_message(),
   1100                   "must have the same rank");
   1101 }
   1102 
   1103 TEST_F(HloParserTest, UnexpectedAttribute) {
   1104   const string original = R"(HloModule unexpected_attr_module
   1105 
   1106 ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
   1107   %recv = (f32[], u32[]) recv(), channel_id=15
   1108   %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15
   1109   ROOT %constant = f32[] constant(2.1)
   1110   %send = (f32[], u32[]) send(f32[] %constant), channel_id=16, calls=%recv
   1111   %send-done = () send-done((f32[], u32[]) %send), channel_id=16
   1112 }
   1113 
   1114 )";
   1115   ExpectHasSubstr(Parse(original).status().error_message(),
   1116                   "unexpected attribute calls");
   1117 }
   1118 
   1119 TEST_F(HloParserTest, MissingAttribute) {
   1120   const string original = R"(HloModule missing_attr_module
   1121 
   1122 ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
   1123   %recv = (f32[], u32[]) recv(), channel_id=15
   1124   %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15
   1125   ROOT %constant = f32[] constant(-2.1)
   1126   %send = (f32[], u32[]) send(f32[] %constant)
   1127   %send-done = () send-done((f32[], u32[]) %send), channel_id=16
   1128 }
   1129 
   1130 )";
   1131   ExpectHasSubstr(Parse(original).status().error_message(),
   1132                   "attribute channel_id is expected but not seen");
   1133 }
   1134 
   1135 TEST_F(HloParserTest, PredecessorUndefined) {
   1136   const string original = R"(HloModule pre_not_found_module
   1137 
   1138 ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
   1139   %recv = (f32[], u32[]) recv(), channel_id=15
   1140   %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15
   1141   ROOT %constant = f32[] constant(2.1)
   1142   %send = (f32[], u32[]) send(f32[] %constant), channel_id=16, control-predecessors={%done}
   1143   %send-done = () send-done((f32[], u32[]) %send), channel_id=16
   1144 }
   1145 
   1146 )";
   1147   ExpectHasSubstr(Parse(original).status().error_message(),
   1148                   "'done' is not defined");
   1149 }
   1150 
   1151 TEST_F(HloParserTest, SliceAllowOmitStride1) {
   1152   const string original = R"(HloModule slice_module
   1153 
   1154 ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] {
   1155   %p0 = f32[3,3,4,4]{3,2,1,0} parameter(0)
   1156   ROOT %slice = f32[3,3,2,4]{3,2,1,0} slice(f32[3,3,4,4]{3,2,1,0} %p0), slice={[0:3], [0:3], [0:4:2], [0:4]}
   1157 }
   1158 
   1159 )";
   1160   TF_EXPECT_OK(Parse(original).status());
   1161 }
   1162 
   1163 TEST_F(HloParserTest, PaddingConfigIsNotWindowPad) {
   1164   const string original = R"(HloModule window_pad_module
   1165 
   1166 ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] {
   1167   %input = f32[1,2,1]{2,1,0} parameter(0)
   1168   %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
   1169   %filter = f32[1,1,1]{2,1,0} parameter(1)
   1170   ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), dim_labels=b0f_0io->b0f, window={pad=1_1_0 size=1}
   1171 }
   1172 
   1173 )";
   1174   ExpectHasSubstr(Parse(original).status().error_message(),
   1175                   "expects padding_low and padding_high separated by '_'");
   1176 }
   1177 
   1178 TEST_F(HloParserTest, CommaBetweenSubAttributes) {
   1179   const string original = R"(HloModule test_comma_module
   1180 
   1181 ENTRY %test_comma.v4 () -> f32[] {
   1182   ROOT %constant = f32[] constant(-4.2), metadata={source_line=5, op_type="::const"}
   1183 }
   1184 
   1185 )";
   1186   TF_EXPECT_OK(Parse(original).status());
   1187 }
   1188 
   1189 TEST_F(HloParserTest, ComputationShapeDoesNotMatchRootShape) {
   1190   const string original = R"(HloModule custom_call:
   1191 
   1192 ENTRY %CustomCall () -> f32[1] {
   1193   %constant = f32[1]{0} constant({12345})
   1194   ROOT %foo = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar"
   1195 })";
   1196   ExpectHasSubstr(Parse(original).status().error_message(),
   1197                   "Shape of computation CustomCall, f32[1], is not compatible "
   1198                   "with that of its root instruction foo, f32[1,2,3]");
   1199 }
   1200 
   1201 TEST_F(HloParserTest, EntryComputationWithLayout) {
   1202   const string original = R"(HloModule layout:
   1203 add_F32.v3 {
   1204   lhs = f32[] parameter(0)
   1205   rhs = f32[] parameter(1)
   1206   ROOT add = f32[] add(lhs, rhs)
   1207 }
   1208 
   1209 ENTRY %Reduce (input: f32[8,16,256]) -> f32[8,16] {
   1210   input = f32[8,16,256]{0,1,2} parameter(0)
   1211   constant = f32[] constant(0)
   1212   ROOT reduce = f32[8,16]{0,1} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3
   1213 })";
   1214 
   1215   auto module = Parse(original);
   1216   TF_ASSERT_OK(module.status());
   1217   auto program_layout = module.ValueOrDie()->entry_computation_layout();
   1218   ASSERT_EQ(program_layout.parameter_count(), 1);
   1219   auto param_layout = program_layout.parameter_layout(0).layout();
   1220   auto result_layout = program_layout.result_layout().layout();
   1221   EXPECT_TRUE(
   1222       LayoutUtil::Equal(LayoutUtil::MakeLayout({0, 1, 2}), param_layout))
   1223       << "actual layout of parameter(0) is "
   1224       << LayoutUtil::HumanString(param_layout);
   1225   EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({0, 1}), result_layout))
   1226       << "actual layout of result is "
   1227       << LayoutUtil::HumanString(result_layout);
   1228 }
   1229 
   1230 TEST_F(HloParserTest, NoEntry) {
   1231   const string original = R"(HloModule no_entry:
   1232 c1 {
   1233   const1 = f32[1]{0} constant({12345})
   1234 }
   1235 c2 {
   1236   const2 = f32[1]{0} constant({67890})
   1237 })";
   1238   auto module = Parse(original);
   1239   TF_ASSERT_OK(module.status());
   1240   EXPECT_EQ(module.ValueOrDie()->entry_computation()->name(), "c2");
   1241 }
   1242 
   1243 TEST_F(HloParserTest, NoRoot) {
   1244   const string original = R"(HloModule no_root:
   1245 ENTRY consts {
   1246   first = f32[1]{0} constant({12345})
   1247   last = f32[1]{0} constant({67890})
   1248 })";
   1249   auto module = Parse(original);
   1250   TF_ASSERT_OK(module.status());
   1251   EXPECT_EQ(
   1252       module.ValueOrDie()->entry_computation()->root_instruction()->name(),
   1253       "last");
   1254 }
   1255 
   1256 TEST_F(HloParserTest, MultipleEntries) {
   1257   const string original = R"(HloModule multiple_entries:
   1258 ENTRY c1 {
   1259   const1 = f32[1]{0} constant({12345})
   1260 }
   1261 ENTRY c2 {
   1262   const2 = f32[1]{0} constant({67890})
   1263 })";
   1264   ExpectHasSubstr(Parse(original).status().error_message(),
   1265                   "expects only one ENTRY");
   1266 }
   1267 
   1268 TEST_F(HloParserTest, MultipleRoots) {
   1269   const string original = R"(HloModule multiple_roots:
   1270 ENTRY consts {
   1271   ROOT const1 = f32[1]{0} constant({12345})
   1272   ROOT const2 = f32[1]{0} constant({12345})
   1273 })";
   1274   ExpectHasSubstr(Parse(original).status().error_message(),
   1275                   "one computation should have only one ROOT");
   1276 }
   1277 
   1278 TEST_F(HloParserTest, InstructionExists) {
   1279   const string original = R"(HloModule comp_exists
   1280 c1 {
   1281   instr = f32[1]{0} constant({12345})
   1282 }
   1283 c2 {
   1284   instr = f32[1]{0} constant({67890})
   1285 })";
   1286 
   1287   ExpectHasSubstr(Parse(original).status().error_message(),
   1288                   R"(was parsing 3:3: error: instruction previously defined here
   1289   instr = f32[1]{0} constant({12345})
   1290   ^)");
   1291 }
   1292 
   1293 TEST_F(HloParserTest, ComputationExists) {
   1294   const string original = R"(HloModule comp_exists
   1295 comp {
   1296   const1 = f32[1]{0} constant({12345})
   1297 }
   1298 comp {
   1299   const2 = f32[1]{0} constant({67890})
   1300 })";
   1301   ExpectHasSubstr(Parse(original).status().error_message(),
   1302                   R"(was parsing 2:1: error: computation previously defined here
   1303 comp {
   1304 ^)");
   1305 }
   1306 
   1307 }  // namespace
   1308 }  // namespace tools
   1309 }  // namespace xla
   1310