Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/shape_inference.h"
     17 
     18 #include <string>
     19 
     20 #include "tensorflow/compiler/xla/shape_util.h"
     21 #include "tensorflow/compiler/xla/test.h"
     22 #include "tensorflow/compiler/xla/test_helpers.h"
     23 #include "tensorflow/compiler/xla/types.h"
     24 #include "tensorflow/compiler/xla/xla_data.pb.h"
     25 #include "tensorflow/core/lib/gtl/array_slice.h"
     26 
     27 namespace xla {
     28 namespace {
     29 
     30 using ::tensorflow::gtl::ArraySlice;
     31 using ::testing::ContainsRegex;
     32 using ::testing::HasSubstr;
     33 
// Common fixture providing a set of frequently used scalar, vector, and
// matrix shapes shared by the shape-inference tests below.
class ShapeInferenceTest : public ::testing::Test {
 protected:
  // Some handy scalar shapes.
  const Shape s32_ = ShapeUtil::MakeShape(S32, {});
  const Shape f32_ = ShapeUtil::MakeShape(F32, {});
  const Shape f64_ = ShapeUtil::MakeShape(F64, {});
  const Shape pred_ = ShapeUtil::MakeShape(PRED, {});

  // Some handy vector and matrix shapes of F32 type.
  // Suffix: vector_length_, matrix_rows_cols_
  const Shape vector_32_ = ShapeUtil::MakeShape(F32, {32});
  const Shape vector_64_ = ShapeUtil::MakeShape(F32, {64});
  const Shape matrix_32_48_ = ShapeUtil::MakeShape(F32, {32, 48});
  const Shape matrix_32_64_ = ShapeUtil::MakeShape(F32, {32, 64});
  const Shape matrix_64_48_ = ShapeUtil::MakeShape(F32, {64, 48});

  // Some handy S32 arrays.
  const Shape s32matrix_64_64_ = ShapeUtil::MakeShape(S32, {64, 64});
};
     53 
// Subclass for testing InferReduceShape.
class ReduceShapeInferenceTest : public ShapeInferenceTest {
 protected:
  // Helper that runs reduce shape inference with the input 'arg' and given
  // dimensions to reduce, and checks the inferred shape is as expected. The
  // element type here is hard-coded to F32.
  void ExpectInferredReduceShape(
      const Shape& expected_inferred_shape, const Shape& arg,
      tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) {
    // The reduction computation is a binary scalar function (f32, f32) -> f32.
    ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_);
    auto inferred_status = ShapeInference::InferReduceShape(
        arg, f32_, dimensions_to_reduce, to_apply);
    EXPECT_IS_OK(inferred_status.status());
    EXPECT_TRUE(ShapeUtil::Equal(expected_inferred_shape,
                                 inferred_status.ValueOrDie()));
  }
};
     71 
// Subclass for testing InferSelectAndScatterShape.
class SelectAndScatterShapeInferenceTest : public ShapeInferenceTest {
 protected:
  SelectAndScatterShapeInferenceTest() {
    operand_shape_ = ShapeUtil::MakeShape(F32, {8, 16});
    source_shape_ = ShapeUtil::MakeShape(F32, {4, 8});
    // A 2x2 window with stride 2 and no padding; the same configuration is
    // used for both dimensions, so the 8x16 operand maps to a 4x8 source.
    WindowDimension dim;
    dim.set_size(2);
    dim.set_stride(2);
    dim.set_padding_low(0);
    dim.set_padding_high(0);
    dim.set_window_dilation(1);
    dim.set_base_dilation(1);
    *window_.add_dimensions() = dim;
    *window_.add_dimensions() = dim;
    init_value_shape_ = ShapeUtil::MakeShape(F32, {});
    // Select is a binary predicate over scalars; scatter is a binary scalar
    // combiner.
    select_program_shape_ = ShapeUtil::MakeProgramShape(
        {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, pred_);
    scatter_program_shape_ = ShapeUtil::MakeProgramShape(
        {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, f32_);
  }

  Shape operand_shape_;
  Shape source_shape_;
  Window window_;
  Shape init_value_shape_;
  ProgramShape select_program_shape_;
  ProgramShape scatter_program_shape_;
};
    101 
    102 TEST_F(ShapeInferenceTest, UnaryNegateMatrix) {
    103   Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
    104   auto inferred_status = ShapeInference::InferUnaryOpShape(
    105       UnaryOperation::UNOP_NEGATE, matrix_shape);
    106   ASSERT_IS_OK(inferred_status.status());
    107   ASSERT_TRUE(ShapeUtil::Equal(matrix_shape, inferred_status.ValueOrDie()));
    108 }
    109 
    110 TEST_F(ShapeInferenceTest, SelectScalarPredBetweenTuples) {
    111   Shape tuple = ShapeUtil::MakeTupleShape({s32_, f32_});
    112   auto inferred_status = ShapeInference::InferTernaryOpShape(
    113       TernaryOperation::TRIOP_SELECT, pred_, tuple, tuple);
    114   ASSERT_IS_OK(inferred_status.status());
    115   ASSERT_TRUE(ShapeUtil::Equal(tuple, inferred_status.ValueOrDie()));
    116 }
    117 
    118 TEST_F(ShapeInferenceTest, SelectScalarPredBetweenArrays) {
    119   auto inferred_status = ShapeInference::InferTernaryOpShape(
    120       TernaryOperation::TRIOP_SELECT, pred_, matrix_64_48_, matrix_64_48_);
    121   ASSERT_IS_OK(inferred_status.status());
    122   ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
    123 }
    124 
    125 TEST_F(ShapeInferenceTest, SelectArrayPredBetweenArrays) {
    126   auto predarray = ShapeUtil::MakeShape(PRED, {64, 48});
    127   auto inferred_status = ShapeInference::InferTernaryOpShape(
    128       TernaryOperation::TRIOP_SELECT, predarray, matrix_64_48_, matrix_64_48_);
    129   ASSERT_IS_OK(inferred_status.status());
    130   ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
    131 }
    132 
    133 TEST_F(ShapeInferenceTest, SelectBadShapes) {
    134   auto inferred_status_error1 = ShapeInference::InferTernaryOpShape(
    135       TernaryOperation::TRIOP_SELECT, pred_, matrix_64_48_, matrix_32_64_);
    136   ASSERT_FALSE(inferred_status_error1.ok());
    137   ASSERT_THAT(inferred_status_error1.status().error_message(),
    138               HasSubstr("operands to select must be the same shape"));
    139 
    140   auto inferred_status_error2 = ShapeInference::InferTernaryOpShape(
    141       TernaryOperation::TRIOP_SELECT, s32_, matrix_64_48_, matrix_64_48_);
    142   ASSERT_FALSE(inferred_status_error2.ok());
    143   ASSERT_THAT(inferred_status_error2.status().error_message(),
    144               HasSubstr("pred operand must have PRED"));
    145 
    146   auto inferred_status_error3 = ShapeInference::InferTernaryOpShape(
    147       TernaryOperation::TRIOP_SELECT, ShapeUtil::MakeShape(PRED, {64}),
    148       matrix_64_48_, matrix_64_48_);
    149   ASSERT_FALSE(inferred_status_error3.ok());
    150   ASSERT_THAT(inferred_status_error3.status().error_message(),
    151               HasSubstr("with non-scalar predicate with dimensionality"));
    152 
    153   // Tuples have a TUPLE element type and cannot be the pred of a select.
    154   auto inferred_status_error4 = ShapeInference::InferTernaryOpShape(
    155       TernaryOperation::TRIOP_SELECT, ShapeUtil::MakeTupleShape({pred_, pred_}),
    156       ShapeUtil::MakeTupleShape({f32_, f32_}),
    157       ShapeUtil::MakeTupleShape({f32_, f32_}));
    158   ASSERT_FALSE(inferred_status_error4.ok());
    159   ASSERT_THAT(inferred_status_error4.status().error_message(),
    160               HasSubstr("pred operand must have PRED element type"));
    161 }
    162 
    163 TEST_F(ShapeInferenceTest, ClampAllMatrix) {
    164   auto inferred_status = ShapeInference::InferTernaryOpShape(
    165       TernaryOperation::TRIOP_CLAMP, matrix_64_48_, matrix_64_48_,
    166       matrix_64_48_);
    167   ASSERT_IS_OK(inferred_status.status());
    168   ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
    169 }
    170 
    171 TEST_F(ShapeInferenceTest, ClampAllScalar) {
    172   auto inferred_status = ShapeInference::InferTernaryOpShape(
    173       TernaryOperation::TRIOP_CLAMP, f32_, f32_, f32_);
    174   ASSERT_IS_OK(inferred_status.status());
    175   ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred_status.ValueOrDie()));
    176 }
    177 
    178 TEST_F(ShapeInferenceTest, ClampMinScalar) {
    179   auto inferred_status = ShapeInference::InferTernaryOpShape(
    180       TernaryOperation::TRIOP_CLAMP, f32_, matrix_64_48_, matrix_64_48_);
    181   ASSERT_IS_OK(inferred_status.status());
    182   ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
    183 }
    184 
    185 TEST_F(ShapeInferenceTest, ClampMaxScalar) {
    186   auto inferred_status = ShapeInference::InferTernaryOpShape(
    187       TernaryOperation::TRIOP_CLAMP, matrix_64_48_, matrix_64_48_, f32_);
    188   ASSERT_IS_OK(inferred_status.status());
    189   ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
    190 }
    191 
    192 TEST_F(ShapeInferenceTest, ClampOperandScalar) {
    193   auto inferred_status = ShapeInference::InferTernaryOpShape(
    194       TernaryOperation::TRIOP_CLAMP, matrix_64_48_, f32_, matrix_64_48_);
    195   ASSERT_IS_OK(inferred_status.status());
    196   ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
    197 }
    198 
    199 TEST_F(ShapeInferenceTest, ClampMinMatrix) {
    200   auto inferred_status = ShapeInference::InferTernaryOpShape(
    201       TernaryOperation::TRIOP_CLAMP, matrix_64_48_, f32_, f32_);
    202   ASSERT_IS_OK(inferred_status.status());
    203   ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
    204 }
    205 
    206 TEST_F(ShapeInferenceTest, ClampMaxMatrix) {
    207   auto inferred_status = ShapeInference::InferTernaryOpShape(
    208       TernaryOperation::TRIOP_CLAMP, f32_, f32_, matrix_64_48_);
    209   ASSERT_IS_OK(inferred_status.status());
    210   ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
    211 }
    212 
    213 TEST_F(ShapeInferenceTest, ClampOperandMatrix) {
    214   auto inferred_status = ShapeInference::InferTernaryOpShape(
    215       TernaryOperation::TRIOP_CLAMP, f32_, matrix_64_48_, f32_);
    216   ASSERT_IS_OK(inferred_status.status());
    217   ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
    218 }
    219 
    220 TEST_F(ShapeInferenceTest, ClampBadShapes) {
    221   // Type mismatch
    222   ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
    223                    TernaryOperation::TRIOP_CLAMP, s32_, f32_, f32_)
    224                    .ok());
    225   ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
    226                    TernaryOperation::TRIOP_CLAMP, f32_, s32_, f32_)
    227                    .ok());
    228   ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
    229                    TernaryOperation::TRIOP_CLAMP, f32_, f32_, s32_)
    230                    .ok());
    231   // Dimension mismatch
    232   ASSERT_FALSE(
    233       ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP,
    234                                           vector_64_, vector_32_, vector_32_)
    235           .ok());
    236   ASSERT_FALSE(
    237       ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP,
    238                                           vector_32_, vector_64_, vector_32_)
    239           .ok());
    240   ASSERT_FALSE(
    241       ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP,
    242                                           vector_32_, vector_32_, vector_64_)
    243           .ok());
    244   // Dimension mismatch, where one operand is a scalar
    245   ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
    246                    TernaryOperation::TRIOP_CLAMP, vector_64_, vector_32_, f32_)
    247                    .ok());
    248   ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
    249                    TernaryOperation::TRIOP_CLAMP, vector_64_, f32_, vector_32_)
    250                    .ok());
    251   ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
    252                    TernaryOperation::TRIOP_CLAMP, f32_, vector_64_, vector_32_)
    253                    .ok());
    254 }
    255 
    256 TEST_F(ShapeInferenceTest, Complex) {
    257   auto complex_shape = [&](const Shape& lhs, const Shape& rhs,
    258                            const tensorflow::gtl::ArraySlice<int64>& bcast) {
    259     return ShapeInference::InferBinaryOpShape(BinaryOperation::BINOP_COMPLEX,
    260                                               lhs, rhs, bcast);
    261   };
    262   // Inputs must be FP.
    263   ASSERT_FALSE(complex_shape(s32_, s32_, {}).ok());
    264   ASSERT_FALSE(complex_shape(pred_, pred_, {}).ok());
    265   // Component types must match.
    266   ASSERT_FALSE(complex_shape(f32_, f64_, {}).ok());
    267   // Only F32->C64 supported.
    268   ASSERT_FALSE(complex_shape(f64_, f64_, {}).ok());
    269   // Validate correct uses.
    270   Shape c64_32 = ShapeUtil::MakeShape(C64, {32});
    271   TF_ASSERT_OK_AND_ASSIGN(Shape result, complex_shape(f32_, f32_, {}));
    272   ASSERT_TRUE(ShapeUtil::Equal(result, ShapeUtil::MakeShape(C64, {})));
    273   TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(vector_32_, f32_, {}));
    274   ASSERT_TRUE(ShapeUtil::Equal(result, c64_32));
    275   TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(f32_, vector_32_, {}));
    276   ASSERT_TRUE(ShapeUtil::Equal(result, c64_32));
    277   TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(vector_32_, f32_, {}));
    278   ASSERT_TRUE(ShapeUtil::Equal(result, c64_32));
    279 
    280   Shape c64_32_64 = ShapeUtil::MakeShape(C64, {32, 64});
    281   TF_ASSERT_OK_AND_ASSIGN(result,
    282                           complex_shape(vector_64_, matrix_32_64_, {1}));
    283   ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));
    284   TF_ASSERT_OK_AND_ASSIGN(result,
    285                           complex_shape(matrix_32_64_, vector_64_, {1}));
    286   ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));
    287   TF_ASSERT_OK_AND_ASSIGN(result,
    288                           complex_shape(matrix_32_64_, matrix_32_64_, {}));
    289   ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));
    290   TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(matrix_32_64_, f32_, {}));
    291   ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));
    292 }
    293 
    294 TEST_F(ShapeInferenceTest, VariadicOpTuplify) {
    295   StatusOr<Shape> result = ShapeInference::InferVariadicOpShape(
    296       VariadicOperation::VAROP_TUPLE, {&s32_, &f32_});
    297   ASSERT_IS_OK(result.status());
    298   ASSERT_TRUE(ShapeUtil::Equal(result.ValueOrDie(),
    299                                ShapeUtil::MakeTupleShape({s32_, f32_})));
    300 }
    301 
    302 TEST_F(ShapeInferenceTest, ReduceWindowInHalf) {
    303   Shape matrix_shape = ShapeUtil::MakeShape(F32, {8, 8});
    304   Window window;
    305   WindowDimension dim;
    306   dim.set_size(2);
    307   dim.set_stride(2);
    308   dim.set_padding_low(0);
    309   dim.set_padding_high(0);
    310   dim.set_window_dilation(1);
    311   dim.set_base_dilation(1);
    312   *window.add_dimensions() = dim;
    313   *window.add_dimensions() = dim;
    314   Shape window_shape = ShapeUtil::MakeShape(F32, {2, 2});
    315   Shape init_value_shape = ShapeUtil::MakeShape(F32, {});
    316   Shape float_scalar = ShapeUtil::MakeShape(F32, {});
    317   ProgramShape to_apply = ShapeUtil::MakeProgramShape(
    318       {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, f32_);
    319   auto inferred_status = ShapeInference::InferReduceWindowShape(
    320       matrix_shape, init_value_shape, window, to_apply);
    321 
    322   ASSERT_IS_OK(inferred_status.status());
    323   Shape inferred = inferred_status.ValueOrDie();
    324   ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {4, 4}), inferred));
    325 }
    326 
    327 TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterProperShapes) {
    328   auto inferred_status_ok = ShapeInference::InferSelectAndScatterShape(
    329       operand_shape_, select_program_shape_, window_, source_shape_,
    330       init_value_shape_, scatter_program_shape_);
    331   ASSERT_IS_OK(inferred_status_ok.status());
    332   Shape inferred = inferred_status_ok.ValueOrDie();
    333   ASSERT_TRUE(ShapeUtil::Equal(operand_shape_, inferred));
    334 }
    335 
    336 TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSourceShape) {
    337   Shape source_shape_fail = ShapeUtil::MakeShape(F32, {4, 6});
    338   auto inferred_status_fail = ShapeInference::InferSelectAndScatterShape(
    339       operand_shape_, select_program_shape_, window_, source_shape_fail,
    340       init_value_shape_, scatter_program_shape_);
    341   ASSERT_FALSE(inferred_status_fail.ok());
    342   ASSERT_THAT(inferred_status_fail.status().error_message(),
    343               HasSubstr("source shape does not match"));
    344 }
    345 
    346 TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape1) {
    347   ProgramShape select_program_shape_fail =
    348       ShapeUtil::MakeProgramShape({ShapeUtil::MakeShape(F32, {})}, pred_);
    349   auto inferred_status_fail = ShapeInference::InferSelectAndScatterShape(
    350       operand_shape_, select_program_shape_fail, window_, source_shape_,
    351       init_value_shape_, scatter_program_shape_);
    352   ASSERT_FALSE(inferred_status_fail.ok());
    353   ASSERT_THAT(inferred_status_fail.status().error_message(),
    354               HasSubstr("select function must take 2 parameters"));
    355 }
    356 
    357 TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape2) {
    358   ProgramShape select_program_shape_fail = ShapeUtil::MakeProgramShape(
    359       {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, f32_);
    360   auto inferred_status_fail = ShapeInference::InferSelectAndScatterShape(
    361       operand_shape_, select_program_shape_fail, window_, source_shape_,
    362       init_value_shape_, scatter_program_shape_);
    363   ASSERT_FALSE(inferred_status_fail.ok());
    364   ASSERT_THAT(inferred_status_fail.status().error_message(),
    365               HasSubstr("select function must have rank-0 PRED"));
    366 }
    367 
    368 TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape3) {
    369   ProgramShape select_program_shape_fail = ShapeUtil::MakeProgramShape(
    370       {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {})}, pred_);
    371   auto inferred_status_fail = ShapeInference::InferSelectAndScatterShape(
    372       operand_shape_, select_program_shape_fail, window_, source_shape_,
    373       init_value_shape_, scatter_program_shape_);
    374   ASSERT_FALSE(inferred_status_fail.ok());
    375   ASSERT_THAT(inferred_status_fail.status().error_message(),
    376               HasSubstr("select function's first parameter"));
    377 }
    378 
    379 TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape4) {
    380   ProgramShape select_program_shape_fail = ShapeUtil::MakeProgramShape(
    381       {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(U32, {})}, pred_);
    382   auto inferred_status_fail = ShapeInference::InferSelectAndScatterShape(
    383       operand_shape_, select_program_shape_fail, window_, source_shape_,
    384       init_value_shape_, scatter_program_shape_);
    385   ASSERT_FALSE(inferred_status_fail.ok());
    386   ASSERT_THAT(inferred_status_fail.status().error_message(),
    387               HasSubstr("select function's second parameter"));
    388 }
    389 
// Basic convolution shape inference with permuted kernel dimensions and a
// window with stride/padding. Expected output: batch 10, output feature 12,
// spatial dims computed per dimension below.
TEST_F(ShapeInferenceTest, Convolve) {
  ConvolutionDimensionNumbers dnums;

  // Dimension order: batch, feature, x0, x1
  Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 3, 4});
  dnums.set_input_batch_dimension(0);
  dnums.set_output_batch_dimension(0);
  dnums.set_input_feature_dimension(1);
  dnums.set_output_feature_dimension(1);
  dnums.add_input_spatial_dimensions(2);
  dnums.add_output_spatial_dimensions(2);
  dnums.add_input_spatial_dimensions(3);
  dnums.add_output_spatial_dimensions(3);

  // Dimension order: x1, batch, feature, x0
  Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 3});
  dnums.set_kernel_input_feature_dimension(2);
  dnums.set_kernel_output_feature_dimension(1);
  dnums.add_kernel_spatial_dimensions(3);
  dnums.add_kernel_spatial_dimensions(0);

  Window window;
  auto dim0 = window.add_dimensions();
  auto dim1 = window.add_dimensions();
  // x0: input extent 3, window 3, stride 2, padding 1+1
  // -> floor((3 + 2 - 3) / 2) + 1 = 2 output elements.
  dim0->set_size(3);
  dim0->set_stride(2);
  dim0->set_padding_low(1);
  dim0->set_padding_high(1);
  dim0->set_window_dilation(1);
  dim0->set_base_dilation(1);
  // x1: input extent 4, window 2, stride 1, no padding
  // -> (4 - 2) / 1 + 1 = 3 output elements.
  dim1->set_size(2);
  dim1->set_stride(1);
  dim1->set_padding_low(0);
  dim1->set_padding_high(0);
  dim1->set_window_dilation(1);
  dim1->set_base_dilation(1);
  auto inferred_status =
      ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums);
  ASSERT_IS_OK(inferred_status.status());
  Shape inferred_shape = inferred_status.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 2, 3}),
                               inferred_shape));
}
    433 
// Convolution shape inference with window (kernel) dilation: the effective
// window size becomes (size - 1) * dilation + 1 in each dimension.
TEST_F(ShapeInferenceTest, ConvolveWithWindowDilation) {
  ConvolutionDimensionNumbers dnums;

  // Dimension order: batch, feature, x0, x1
  Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 103, 4});
  dnums.set_input_batch_dimension(0);
  dnums.set_output_batch_dimension(0);
  dnums.set_input_feature_dimension(1);
  dnums.set_output_feature_dimension(1);
  dnums.add_input_spatial_dimensions(2);
  dnums.add_output_spatial_dimensions(2);
  dnums.add_input_spatial_dimensions(3);
  dnums.add_output_spatial_dimensions(3);

  // Dimension order: x1, batch, feature, x0
  Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 3});
  dnums.set_kernel_input_feature_dimension(2);
  dnums.set_kernel_output_feature_dimension(1);
  dnums.add_kernel_spatial_dimensions(3);
  dnums.add_kernel_spatial_dimensions(0);

  Window window;
  // x0: effective window (3 - 1) * 6 + 1 = 13 over input extent 103,
  // stride 3, no padding -> (103 - 13) / 3 + 1 = 31 output elements.
  auto dim0 = window.add_dimensions();
  dim0->set_size(3);
  dim0->set_stride(3);
  dim0->set_padding_low(0);
  dim0->set_padding_high(0);
  dim0->set_window_dilation(6);
  dim0->set_base_dilation(1);

  // x1: effective window (2 - 1) * 2 + 1 = 3 over input extent 4 with
  // padding 2 + 1 -> (4 + 3 - 3) / 1 + 1 = 5 output elements.
  auto dim1 = window.add_dimensions();
  dim1->set_size(2);
  dim1->set_stride(1);
  dim1->set_padding_low(2);
  dim1->set_padding_high(1);
  dim1->set_window_dilation(2);
  dim1->set_base_dilation(1);
  auto inferred_status =
      ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums);
  ASSERT_IS_OK(inferred_status.status());
  Shape inferred_shape = inferred_status.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 31, 5}),
                               inferred_shape));
}
    478 
// Convolution shape inference with base (input) dilation: the effective
// input extent becomes (extent - 1) * dilation + 1 in each dimension.
TEST_F(ShapeInferenceTest, ConvolveWithBaseDilation) {
  ConvolutionDimensionNumbers dnums;

  // Dimension order: batch, feature, x0, x1
  Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 3, 4});
  dnums.set_input_batch_dimension(0);
  dnums.set_output_batch_dimension(0);
  dnums.set_input_feature_dimension(1);
  dnums.set_output_feature_dimension(1);
  dnums.add_input_spatial_dimensions(2);
  dnums.add_output_spatial_dimensions(2);
  dnums.add_input_spatial_dimensions(3);
  dnums.add_output_spatial_dimensions(3);

  // Dimension order: x1, batch, feature, x0
  Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 4});
  dnums.set_kernel_input_feature_dimension(2);
  dnums.set_kernel_output_feature_dimension(1);
  dnums.add_kernel_spatial_dimensions(3);
  dnums.add_kernel_spatial_dimensions(0);

  Window window;
  // x0: base-dilated input (3 - 1) * 6 + 1 = 13, window 4, stride 3,
  // no padding -> (13 - 4) / 3 + 1 = 4 output elements.
  auto dim0 = window.add_dimensions();
  dim0->set_size(4);
  dim0->set_stride(3);
  dim0->set_padding_low(0);
  dim0->set_padding_high(0);
  dim0->set_window_dilation(1);
  dim0->set_base_dilation(6);

  // x1: base-dilated input (4 - 1) * 2 + 1 = 7 plus padding 2 + 1 = 10,
  // window 2, stride 1 -> (10 - 2) / 1 + 1 = 9 output elements.
  auto dim1 = window.add_dimensions();
  dim1->set_size(2);
  dim1->set_stride(1);
  dim1->set_padding_low(2);
  dim1->set_padding_high(1);
  dim1->set_window_dilation(1);
  dim1->set_base_dilation(2);
  auto inferred_status =
      ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums);
  ASSERT_IS_OK(inferred_status.status());
  Shape inferred_shape = inferred_status.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 4, 9}),
                               inferred_shape));
}
    523 
// A kernel dimension number used twice (input-feature and a spatial
// dimension both mapped to 0) must be diagnosed as an error.
TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) {
  // Dimension order for this test: batch, feature, x0, x1
  Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 3, 4});
  Shape rhs_shape = ShapeUtil::MakeShape(F32, {12, 11, 3, 2});

  ConvolutionDimensionNumbers dnums;
  dnums.set_input_batch_dimension(3);
  dnums.set_output_batch_dimension(3);
  dnums.set_input_feature_dimension(2);
  dnums.set_output_feature_dimension(2);
  dnums.add_input_spatial_dimensions(0);
  dnums.add_output_spatial_dimensions(0);
  dnums.add_input_spatial_dimensions(1);
  dnums.add_output_spatial_dimensions(1);
  dnums.set_kernel_input_feature_dimension(0);  // duplicated with kernel_x0
  dnums.set_kernel_output_feature_dimension(3);
  dnums.add_kernel_spatial_dimensions(0);
  dnums.add_kernel_spatial_dimensions(1);

  Window window;
  auto dim0 = window.add_dimensions();
  auto dim1 = window.add_dimensions();
  dim0->set_size(2);
  dim0->set_stride(1);
  dim0->set_padding_low(0);
  dim0->set_padding_high(0);
  dim1->set_size(3);
  dim1->set_stride(2);
  dim1->set_padding_low(1);
  dim1->set_padding_high(1);
  auto inferred_status =
      ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums);
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(inferred_status.status().error_message(),
              HasSubstr("each dimension exactly once"));
}
    560 
    561 TEST_F(ShapeInferenceTest, MapThatChangesElementType) {
    562   Shape arg = ShapeUtil::MakeShape(F32, {20});
    563   ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_}, s32_);
    564   auto inferred_status = ShapeInference::InferMapShape({&arg}, to_apply, {0});
    565   EXPECT_IS_OK(inferred_status.status());
    566   Shape expected = ShapeUtil::MakeShape(S32, {20});
    567   EXPECT_TRUE(ShapeUtil::Equal(expected, inferred_status.ValueOrDie()));
    568 }
    569 
// Exhaustive coverage of InferMapShape: the valid cases (matching operand
// shapes, matching computation arity and scalar parameters/result) and each
// individual failure mode with its expected error message.
TEST_F(ShapeInferenceTest, Map) {
  // Two rank-1 F32 operands with a binary scalar computation: the inferred
  // shape matches the operands.
  auto inferred_status_r1f32 = ShapeInference::InferMapShape(
      {&vector_32_, &vector_32_},
      ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0});
  EXPECT_IS_OK(inferred_status_r1f32.status());
  EXPECT_TRUE(ShapeUtil::Equal(vector_32_, inferred_status_r1f32.ValueOrDie()));

  // It's OK to provide a single argument, as long as the applied arity matches
  // (this degenerates to a Map).
  auto inferred_status_r1f32_one = ShapeInference::InferMapShape(
      {&vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_), {0});
  EXPECT_IS_OK(inferred_status_r1f32_one.status());
  EXPECT_TRUE(
      ShapeUtil::Equal(vector_32_, inferred_status_r1f32_one.ValueOrDie()));

  // Three rank-2 S32 operands with a ternary scalar computation.
  auto inferred_status_r2s32 = ShapeInference::InferMapShape(
      {&s32matrix_64_64_, &s32matrix_64_64_, &s32matrix_64_64_},
      ShapeUtil::MakeProgramShape({s32_, s32_, s32_}, s32_), {0, 1});
  EXPECT_IS_OK(inferred_status_r2s32.status());
  EXPECT_TRUE(
      ShapeUtil::Equal(s32matrix_64_64_, inferred_status_r2s32.ValueOrDie()));

  // Map requires at least one operand.
  auto no_args_error = ShapeInference::InferMapShape(
      {}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {});
  ASSERT_FALSE(no_args_error.ok());
  ASSERT_THAT(no_args_error.status().error_message(),
              HasSubstr("expects at least one argument"));

  // All operands must have identical shapes.
  auto args_diff_shapes_error = ShapeInference::InferMapShape(
      {&vector_32_, &vector_64_},
      ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0});
  ASSERT_FALSE(args_diff_shapes_error.ok());
  ASSERT_THAT(args_diff_shapes_error.status().error_message(),
              HasSubstr("requires all operands to have the same shape"));

  // The computation's arity must equal the number of operands.
  auto arity_error = ShapeInference::InferMapShape(
      {&vector_32_, &vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_),
      {0});
  ASSERT_FALSE(arity_error.ok());
  ASSERT_THAT(arity_error.status().error_message(),
              HasSubstr("function arity must match"));

  // The computation must return a scalar.
  auto output_shape_error = ShapeInference::InferMapShape(
      {&vector_32_, &vector_32_},
      ShapeUtil::MakeProgramShape({f32_, f32_}, vector_32_), {0});
  ASSERT_FALSE(output_shape_error.ok());
  ASSERT_THAT(output_shape_error.status().error_message(),
              HasSubstr("result has to be a scalar"));

  // Every computation parameter must be a scalar.
  auto param_shape_error = ShapeInference::InferMapShape(
      {&vector_32_, &vector_32_},
      ShapeUtil::MakeProgramShape({vector_32_, f32_}, f32_), {0});
  ASSERT_FALSE(param_shape_error.ok());
  ASSERT_THAT(param_shape_error.status().error_message(),
              HasSubstr("parameter has to be a scalar"));

  // Parameter element types must match the corresponding arguments.
  auto param_element_type_error = ShapeInference::InferMapShape(
      {&vector_32_, &vector_32_},
      ShapeUtil::MakeProgramShape({f32_, s32_}, f32_), {0});
  ASSERT_FALSE(param_element_type_error.ok());
  ASSERT_THAT(param_element_type_error.status().error_message(),
              HasSubstr("parameter type has to match argument"));

  // Single-operand happy path: the inferred shape matches the argument.
  Shape arg = ShapeUtil::MakeShape(F32, {20});
  ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_}, f32_);
  auto inferred_status = ShapeInference::InferMapShape({&arg}, to_apply, {0});
  EXPECT_IS_OK(inferred_status.status());
  EXPECT_TRUE(ShapeUtil::Equal(arg, inferred_status.ValueOrDie()));

  // One operand with a binary computation: arity mismatch.
  auto inferred_status_error1 = ShapeInference::InferMapShape(
      {&arg}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0});
  ASSERT_FALSE(inferred_status_error1.ok());
  ASSERT_THAT(inferred_status_error1.status().error_message(),
              HasSubstr("arity must match number of arguments"));

  // Non-scalar parameter in the single-operand case.
  auto inferred_status_error2 = ShapeInference::InferMapShape(
      {&arg}, ShapeUtil::MakeProgramShape({vector_32_}, f32_), {0});
  ASSERT_FALSE(inferred_status_error2.ok());
  ASSERT_THAT(inferred_status_error2.status().error_message(),
              HasSubstr("has to be a scalar"));

  // Non-scalar result in the single-operand case.
  auto inferred_status_error3 = ShapeInference::InferMapShape(
      {&arg}, ShapeUtil::MakeProgramShape({f32_}, vector_32_), {0});
  ASSERT_FALSE(inferred_status_error3.ok());
  ASSERT_THAT(inferred_status_error3.status().error_message(),
              HasSubstr("has to be a scalar"));

  // S32 parameter against an F32 argument: element-type mismatch.
  auto inferred_status_error5 = ShapeInference::InferMapShape(
      {&arg}, ShapeUtil::MakeProgramShape({s32_}, s32_), {0});
  ASSERT_FALSE(inferred_status_error5.ok());
  ASSERT_THAT(inferred_status_error5.status().error_message(),
              HasSubstr("parameter type has to match argument"));
}
    663 
    664 TEST_F(ReduceShapeInferenceTest, ReduceVectorToScalar) {
    665   ExpectInferredReduceShape(f32_, ShapeUtil::MakeShape(F32, {128}),
    666                             /*dimensions_to_reduce=*/{0});
    667 }
    668 
    669 TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongFirstDimension) {
    670   ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {3, 4}),
    671                             ShapeUtil::MakeShape(F32, {2, 3, 4}),
    672                             /*dimensions_to_reduce=*/{0});
    673 }
    674 
    675 TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongMiddleDimension) {
    676   ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {2, 4}),
    677                             ShapeUtil::MakeShape(F32, {2, 3, 4}),
    678                             /*dimensions_to_reduce=*/{1});
    679 }
    680 
    681 TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongFirstTwoDimensions) {
    682   ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {4}),
    683                             ShapeUtil::MakeShape(F32, {2, 3, 4}),
    684                             /*dimensions_to_reduce=*/{0, 1});
    685 }
    686 
    687 TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongLastTwoDimensions) {
    688   ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {2}),
    689                             ShapeUtil::MakeShape(F32, {2, 3, 4}),
    690                             /*dimensions_to_reduce=*/{1, 2});
    691 }
    692 
    693 TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongFirstAndLastDimensions) {
    694   ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {3}),
    695                             ShapeUtil::MakeShape(F32, {2, 3, 4}),
    696                             /*dimensions_to_reduce=*/{0, 2});
    697 
    698   // Check that the order of dimensions_to_reduce doesn't matter.
    699   ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {3}),
    700                             ShapeUtil::MakeShape(F32, {2, 3, 4}),
    701                             /*dimensions_to_reduce=*/{2, 0});
    702 }
    703 
    704 TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongAllDimensions) {
    705   ExpectInferredReduceShape(f32_, ShapeUtil::MakeShape(F32, {2, 3, 4}),
    706                             /*dimensions_to_reduce=*/{0, 1, 2});
    707 }
    708 
    709 TEST_F(ReduceShapeInferenceTest, ErrorOutOfBoundsDimension) {
    710   ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_);
    711   auto inferred_status = ShapeInference::InferReduceShape(
    712       ShapeUtil::MakeShape(F32, {5, 3}), f32_, /*dimensions_to_reduce=*/{3, 4},
    713       to_apply);
    714   EXPECT_FALSE(inferred_status.ok());
    715   EXPECT_THAT(inferred_status.status().error_message(),
    716               HasSubstr("out-of-bounds dimension"));
    717 }
    718 
    719 TEST_F(ReduceShapeInferenceTest, ErrorToApplyArity) {
    720   ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_, f32_}, f32_);
    721   auto inferred_status =
    722       ShapeInference::InferReduceShape(ShapeUtil::MakeShape(F32, {5, 3}), f32_,
    723                                        /*dimensions_to_reduce=*/{0}, to_apply);
    724   EXPECT_FALSE(inferred_status.ok());
    725   EXPECT_THAT(inferred_status.status().error_message(),
    726               HasSubstr("take 2 parameters"));
    727 }
    728 
    729 TEST_F(ReduceShapeInferenceTest, ErrorElementTypeVsApplyType) {
    730   ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, s32_);
    731   auto inferred_status =
    732       ShapeInference::InferReduceShape(ShapeUtil::MakeShape(F32, {5, 3}), f32_,
    733                                        /*dimensions_to_reduce=*/{0}, to_apply);
    734   EXPECT_FALSE(inferred_status.ok());
    735   EXPECT_THAT(inferred_status.status().error_message(),
    736               HasSubstr("first parameter shape differs"));
    737 }
    738 
    739 TEST_F(ShapeInferenceTest, InferSliceShapeRank2) {
    740   Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
    741   auto inferred_status =
    742       ShapeInference::InferSliceShape(matrix_shape, {32, 0}, {64, 64}, {1, 1});
    743   ASSERT_IS_OK(inferred_status.status());
    744   Shape inferred = inferred_status.ValueOrDie();
    745   ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {32, 64}), inferred));
    746 }
    747 
    748 TEST_F(ShapeInferenceTest, InferSliceShapeRank2WithStrides) {
    749   Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
    750   auto inferred_status =
    751       ShapeInference::InferSliceShape(matrix_shape, {32, 0}, {64, 64}, {2, 4});
    752   ASSERT_IS_OK(inferred_status.status());
    753   Shape inferred = inferred_status.ValueOrDie();
    754   ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {16, 16}), inferred));
    755 }
    756 
    757 TEST_F(ShapeInferenceTest, InferSliceShapeRank2WithStridesNotIntegral) {
    758   Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
    759   auto inferred_status =
    760       ShapeInference::InferSliceShape(matrix_shape, {15, 0}, {20, 13}, {2, 4});
    761   ASSERT_IS_OK(inferred_status.status());
    762   Shape inferred = inferred_status.ValueOrDie();
    763   ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {3, 4}), inferred));
    764 }
    765 
    766 TEST_F(ShapeInferenceTest, InferInvalidStride) {
    767   Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
    768   auto inferred_status =
    769       ShapeInference::InferSliceShape(matrix_shape, {127, 0}, {129, 2}, {0, 1});
    770   ASSERT_FALSE(inferred_status.ok());
    771   ASSERT_EQ(tensorflow::error::INVALID_ARGUMENT,
    772             inferred_status.status().code());
    773 }
    774 
    775 TEST_F(ShapeInferenceTest, InferOobSliceShapeRank2) {
    776   Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
    777   auto inferred_status =
    778       ShapeInference::InferSliceShape(matrix_shape, {127, 0}, {129, 2}, {1, 1});
    779   ASSERT_FALSE(inferred_status.ok());
    780   ASSERT_EQ(tensorflow::error::INVALID_ARGUMENT,
    781             inferred_status.status().code());
    782 }
    783 
    784 TEST_F(ShapeInferenceTest, InferSliceShapeRank1) {
    785   Shape vector_shape = ShapeUtil::MakeShape(F32, {17});
    786   auto inferred_status =
    787       ShapeInference::InferSliceShape(vector_shape, {2}, {4}, {1});
    788   ASSERT_TRUE(inferred_status.ok());
    789   Shape inferred = inferred_status.ValueOrDie();
    790   ASSERT_TRUE(ShapeUtil::Equal(inferred, ShapeUtil::MakeShape(F32, {2})));
    791 }
    792 
    793 TEST_F(ShapeInferenceTest, InferConstIndexShape) {
    794   Shape tuple_shape = ShapeUtil::MakeTupleShape({f32_, s32_});
    795   auto inferred0_status =
    796       ShapeInference::InferGetTupleElementShape(tuple_shape, 0);
    797   auto inferred1_status =
    798       ShapeInference::InferGetTupleElementShape(tuple_shape, 1);
    799   ASSERT_IS_OK(inferred0_status.status());
    800   ASSERT_IS_OK(inferred1_status.status());
    801   ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred0_status.ValueOrDie()));
    802   ASSERT_TRUE(ShapeUtil::Equal(s32_, inferred1_status.ValueOrDie()));
    803 }
    804 
    805 TEST_F(ShapeInferenceTest, InferPowShape) {
    806   auto ten_floats = ShapeUtil::MakeShape(F32, {10});
    807   auto inferred_status =
    808       ShapeInference::InferBinaryOpShape(BINOP_POW, ten_floats, f32_, {});
    809   ASSERT_IS_OK(inferred_status.status());
    810   ASSERT_TRUE(ShapeUtil::Equal(ten_floats, inferred_status.ValueOrDie()));
    811 }
    812 
    813 TEST_F(ShapeInferenceTest, InferCompareShapeEq) {
    814   auto ten_floats = ShapeUtil::MakeShape(F32, {10});
    815   auto inferred_status =
    816       ShapeInference::InferBinaryOpShape(BINOP_EQ, ten_floats, f32_, {});
    817   ASSERT_IS_OK(inferred_status.status());
    818   ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
    819                                inferred_status.ValueOrDie()));
    820 }
    821 
    822 TEST_F(ShapeInferenceTest, InferCompareShapeGe) {
    823   auto ten_floats = ShapeUtil::MakeShape(F32, {10});
    824   auto inferred_status =
    825       ShapeInference::InferBinaryOpShape(BINOP_GE, ten_floats, f32_, {});
    826   ASSERT_IS_OK(inferred_status.status());
    827   ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
    828                                inferred_status.ValueOrDie()));
    829 }
    830 
    831 TEST_F(ShapeInferenceTest, InferCompareShapeGt) {
    832   auto ten_floats = ShapeUtil::MakeShape(F32, {10});
    833   auto inferred_status =
    834       ShapeInference::InferBinaryOpShape(BINOP_GT, ten_floats, f32_, {});
    835   ASSERT_IS_OK(inferred_status.status());
    836   ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
    837                                inferred_status.ValueOrDie()));
    838 }
    839 
    840 TEST_F(ShapeInferenceTest, InferCompareShapeLe) {
    841   auto ten_floats = ShapeUtil::MakeShape(F32, {10});
    842   auto inferred_status =
    843       ShapeInference::InferBinaryOpShape(BINOP_LE, ten_floats, f32_, {});
    844   ASSERT_IS_OK(inferred_status.status());
    845   ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
    846                                inferred_status.ValueOrDie()));
    847 }
    848 
    849 TEST_F(ShapeInferenceTest, InferCompareShapeLt) {
    850   auto ten_floats = ShapeUtil::MakeShape(F32, {10});
    851   auto inferred_status =
    852       ShapeInference::InferBinaryOpShape(BINOP_LT, ten_floats, f32_, {});
    853   ASSERT_IS_OK(inferred_status.status());
    854   ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
    855                                inferred_status.ValueOrDie()));
    856 }
    857 
    858 TEST_F(ShapeInferenceTest, InferCompareShapeNe) {
    859   auto ten_floats = ShapeUtil::MakeShape(F32, {10});
    860   auto inferred_status =
    861       ShapeInference::InferBinaryOpShape(BINOP_NE, ten_floats, f32_, {});
    862   ASSERT_IS_OK(inferred_status.status());
    863   ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
    864                                inferred_status.ValueOrDie()));
    865 }
    866 
    867 TEST_F(ShapeInferenceTest, BroadcastScalar) {
    868   for (auto element_type : {F32, U32, S8}) {
    869     const Shape scalar_shape = ShapeUtil::MakeShape(element_type, {});
    870     {  // no-op scalar broadcast
    871       auto status = ShapeInference::InferBroadcastShape(scalar_shape, {});
    872       ASSERT_IS_OK(status.status());
    873       ASSERT_TRUE(ShapeUtil::Equal(scalar_shape, status.ValueOrDie()));
    874     }
    875     const Shape oned_shape = ShapeUtil::MakeShape(element_type, {3});
    876     {  // scalar -> 1d broadcast
    877       auto status = ShapeInference::InferBroadcastShape(scalar_shape, {3});
    878       ASSERT_IS_OK(status.status());
    879       ASSERT_TRUE(ShapeUtil::Equal(oned_shape, status.ValueOrDie()));
    880     }
    881     {  // no-op 1d broadcast
    882       auto status = ShapeInference::InferBroadcastShape(oned_shape, {});
    883       ASSERT_IS_OK(status.status());
    884       ASSERT_TRUE(ShapeUtil::Equal(oned_shape, status.ValueOrDie()));
    885     }
    886     const Shape twod_shape = ShapeUtil::MakeShape(element_type, {2, 3});
    887     {  // scalar -> 2d broadcast
    888       auto status = ShapeInference::InferBroadcastShape(scalar_shape, {2, 3});
    889       ASSERT_IS_OK(status.status());
    890       ASSERT_TRUE(ShapeUtil::Equal(twod_shape, status.ValueOrDie()));
    891     }
    892     {  // 1d -> 2d broadcast
    893       auto status = ShapeInference::InferBroadcastShape(oned_shape, {2});
    894       ASSERT_IS_OK(status.status());
    895       ASSERT_TRUE(ShapeUtil::Equal(twod_shape, status.ValueOrDie()));
    896     }
    897   }
    898 }
    899 
    900 // scalar <dot> vector: error
    901 TEST_F(ShapeInferenceTest, ScalarDotVector) {
    902   DotDimensionNumbers dot_dnums;
    903   dot_dnums.add_lhs_contracting_dimensions(1);
    904   dot_dnums.add_rhs_contracting_dimensions(0);
    905   auto inferred_status =
    906       ShapeInference::InferDotOpShape(f32_, vector_32_, dot_dnums);
    907   ASSERT_FALSE(inferred_status.ok());
    908   ASSERT_THAT(inferred_status.status().error_message(),
    909               HasSubstr("dot only supports rank"));
    910 }
    911 
    912 // 3D <dot> 2D: error
    913 TEST_F(ShapeInferenceTest, DotWithRankHigherThanTwo) {
    914   DotDimensionNumbers dot_dnums;
    915   dot_dnums.add_lhs_contracting_dimensions(1);
    916   dot_dnums.add_rhs_contracting_dimensions(0);
    917   auto inferred_status = ShapeInference::InferDotOpShape(
    918       ShapeUtil::MakeShape(F32, {32, 32, 32}), matrix_32_64_, dot_dnums);
    919   ASSERT_FALSE(inferred_status.ok());
    920   ASSERT_THAT(inferred_status.status().error_message(),
    921               HasSubstr("batch and contracting dimension number mismatch"));
    922 }
    923 
    924 // vector <dot> vector -> scalar
    925 TEST_F(ShapeInferenceTest, VectorDotVector) {
    926   DotDimensionNumbers dot_dnums;
    927   dot_dnums.add_lhs_contracting_dimensions(0);
    928   dot_dnums.add_rhs_contracting_dimensions(0);
    929   auto inferred_status =
    930       ShapeInference::InferDotOpShape(vector_64_, vector_64_, dot_dnums);
    931   ASSERT_IS_OK(inferred_status.status());
    932   ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred_status.ValueOrDie()));
    933   auto inferred_status_mismatch =
    934       ShapeInference::InferDotOpShape(vector_64_, vector_32_, dot_dnums);
    935   ASSERT_FALSE(inferred_status_mismatch.ok());
    936 }
    937 
    938 // matrix <dot> vector -> vector
    939 TEST_F(ShapeInferenceTest, MatrixDotVector) {
    940   DotDimensionNumbers dot_dnums;
    941   dot_dnums.add_lhs_contracting_dimensions(1);
    942   dot_dnums.add_rhs_contracting_dimensions(0);
    943   auto inferred_status =
    944       ShapeInference::InferDotOpShape(matrix_32_64_, vector_64_, dot_dnums);
    945   ASSERT_IS_OK(inferred_status.status());
    946   ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), vector_32_));
    947   auto inferred_status_mismatch =
    948       ShapeInference::InferDotOpShape(matrix_32_64_, vector_32_, dot_dnums);
    949   ASSERT_FALSE(inferred_status_mismatch.ok());
    950 }
    951 
    952 // vector <dot> matrix -> vector
    953 TEST_F(ShapeInferenceTest, VectorDotMatrix) {
    954   DotDimensionNumbers dot_dnums;
    955   dot_dnums.add_lhs_contracting_dimensions(0);
    956   dot_dnums.add_rhs_contracting_dimensions(0);
    957   auto inferred_status =
    958       ShapeInference::InferDotOpShape(vector_32_, matrix_32_64_, dot_dnums);
    959   ASSERT_IS_OK(inferred_status.status());
    960   ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), vector_64_));
    961   auto inferred_status_mismatch =
    962       ShapeInference::InferDotOpShape(vector_64_, matrix_32_64_, dot_dnums);
    963   ASSERT_FALSE(inferred_status_mismatch.ok());
    964 }
    965 
    966 // matrix <dot> matrix -> matrix
    967 TEST_F(ShapeInferenceTest, MatrixDotMatrix) {
    968   DotDimensionNumbers dot_dnums;
    969   dot_dnums.add_lhs_contracting_dimensions(1);
    970   dot_dnums.add_rhs_contracting_dimensions(0);
    971   auto inferred_status_match =
    972       ShapeInference::InferDotOpShape(matrix_32_64_, matrix_64_48_, dot_dnums);
    973   ASSERT_IS_OK(inferred_status_match.status());
    974   ASSERT_TRUE(
    975       ShapeUtil::Equal(inferred_status_match.ValueOrDie(), matrix_32_48_))
    976       << "inferred: "
    977       << ShapeUtil::HumanString(inferred_status_match.ValueOrDie())
    978       << " expected: " << ShapeUtil::HumanString(matrix_64_48_);
    979   auto inferred_status_mismatch =
    980       ShapeInference::InferDotOpShape(matrix_32_64_, matrix_32_64_, dot_dnums);
    981   ASSERT_FALSE(inferred_status_mismatch.ok());
    982 }
    983 
    984 // BatchMatMul with two batch dimensions and one contracting dimension.
    985 TEST_F(ShapeInferenceTest, DotGeneral) {
    986   Shape lhs_shape = ShapeUtil::MakeShape(F32, {5, 2, 11, 3});
    987   Shape rhs_shape = ShapeUtil::MakeShape(F32, {5, 2, 3, 14});
    988   Shape output_shape = ShapeUtil::MakeShape(F32, {5, 2, 11, 14});
    989 
    990   DotDimensionNumbers dot_dnums;
    991   dot_dnums.add_lhs_contracting_dimensions(3);
    992   dot_dnums.add_lhs_batch_dimensions(0);
    993   dot_dnums.add_lhs_batch_dimensions(1);
    994 
    995   dot_dnums.add_rhs_contracting_dimensions(2);
    996   dot_dnums.add_rhs_batch_dimensions(0);
    997   dot_dnums.add_rhs_batch_dimensions(1);
    998 
    999   auto inferred_status_match =
   1000       ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums);
   1001   ASSERT_IS_OK(inferred_status_match.status());
   1002   ASSERT_TRUE(
   1003       ShapeUtil::Equal(inferred_status_match.ValueOrDie(), output_shape))
   1004       << "inferred: "
   1005       << ShapeUtil::HumanString(inferred_status_match.ValueOrDie())
   1006       << " expected: " << ShapeUtil::HumanString(output_shape);
   1007 }
   1008 
   1009 // BatchMatMul with two contracting dimensions fails.
   1010 TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsFails) {
   1011   Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3, 2});
   1012   Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14});
   1013   Shape output_shape = ShapeUtil::MakeShape(F32, {2, 11, 14});
   1014 
   1015   DotDimensionNumbers dot_dnums;
   1016   dot_dnums.add_lhs_contracting_dimensions(2);
   1017   dot_dnums.add_lhs_contracting_dimensions(3);
   1018   dot_dnums.add_lhs_batch_dimensions(0);
   1019 
   1020   dot_dnums.add_rhs_contracting_dimensions(1);
   1021   dot_dnums.add_rhs_batch_dimensions(0);
   1022 
   1023   auto inferred_status =
   1024       ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums);
   1025   ASSERT_FALSE(inferred_status.ok());
   1026   ASSERT_THAT(inferred_status.status().error_message(),
   1027               HasSubstr("must specify one contracting dimension for both "
   1028                         "lhs and rhs"));
   1029 }
   1030 
   1031 // BatchMatMul with different batch dimension sizes fails.
   1032 TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimSizesFails) {
   1033   Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3});
   1034   Shape rhs_shape = ShapeUtil::MakeShape(F32, {3, 3, 14});
   1035 
   1036   DotDimensionNumbers dot_dnums;
   1037   dot_dnums.add_lhs_contracting_dimensions(2);
   1038   dot_dnums.add_lhs_batch_dimensions(0);
   1039 
   1040   dot_dnums.add_rhs_contracting_dimensions(1);
   1041   dot_dnums.add_rhs_batch_dimensions(0);
   1042 
   1043   auto inferred_status =
   1044       ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums);
   1045   ASSERT_FALSE(inferred_status.ok());
   1046   ASSERT_THAT(inferred_status.status().error_message(),
   1047               HasSubstr("batch dimension numbers and sizes must match"));
   1048 }
   1049 
   1050 // BatchMatMul with different batch dimension numbers fails.
   1051 TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimNumbersFails) {
   1052   Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3});
   1053   Shape rhs_shape = ShapeUtil::MakeShape(F32, {3, 2, 14});
   1054 
   1055   DotDimensionNumbers dot_dnums;
   1056   dot_dnums.add_lhs_contracting_dimensions(2);
   1057   dot_dnums.add_lhs_batch_dimensions(0);
   1058 
   1059   dot_dnums.add_rhs_contracting_dimensions(0);
   1060   dot_dnums.add_rhs_batch_dimensions(1);
   1061 
   1062   auto inferred_status =
   1063       ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums);
   1064   ASSERT_FALSE(inferred_status.ok());
   1065   ASSERT_THAT(inferred_status.status().error_message(),
   1066               HasSubstr("batch dimension numbers must precede non-batch"));
   1067 }
   1068 
   1069 // BatchMatMul with out-of-range dimension numbers fails.
   1070 TEST_F(ShapeInferenceTest, DotWithContractingDimNumberOutOfRange) {
   1071   Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3});
   1072   Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14});
   1073 
   1074   DotDimensionNumbers dot_dnums;
   1075   dot_dnums.add_lhs_contracting_dimensions(3);
   1076   dot_dnums.add_lhs_batch_dimensions(0);
   1077 
   1078   dot_dnums.add_rhs_contracting_dimensions(0);
   1079   dot_dnums.add_rhs_batch_dimensions(1);
   1080 
   1081   auto inferred_status =
   1082       ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums);
   1083   ASSERT_FALSE(inferred_status.ok());
   1084   ASSERT_THAT(inferred_status.status().error_message(),
   1085               HasSubstr("A dimension number is out of range"));
   1086 }
   1087 
   1088 // BatchMatMul with non-unique dimension numbers fails.
   1089 TEST_F(ShapeInferenceTest, DotWithContractingNonUniqueDimNumber) {
   1090   Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3});
   1091   Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14});
   1092 
   1093   DotDimensionNumbers dot_dnums;
   1094   dot_dnums.add_lhs_contracting_dimensions(0);
   1095   dot_dnums.add_lhs_batch_dimensions(0);
   1096 
   1097   dot_dnums.add_rhs_contracting_dimensions(0);
   1098   dot_dnums.add_rhs_batch_dimensions(1);
   1099 
   1100   auto inferred_status =
   1101       ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums);
   1102   ASSERT_FALSE(inferred_status.ok());
   1103   ASSERT_THAT(inferred_status.status().error_message(),
   1104               HasSubstr("A dimension number is not unique"));
   1105 }
   1106 
   1107 TEST_F(ShapeInferenceTest, BinOpBroadcastMatrixVector) {
   1108   // Test variations of broadcasting a vector for a binary add with a
   1109   // matrix.
   1110   const Shape mat = ShapeUtil::MakeShape(F32, {16, 8});
   1111   const Shape vec8 = ShapeUtil::MakeShape(F32, {8});
   1112   const Shape vec16 = ShapeUtil::MakeShape(F32, {16});
   1113 
   1114   auto inferred_status_match = ShapeInference::InferBinaryOpShape(
   1115       BinaryOperation::BINOP_ADD, mat, vec8, {1});
   1116   ASSERT_IS_OK(inferred_status_match.status());
   1117   ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), mat));
   1118 
   1119   auto inferred_status_mismatch = ShapeInference::InferBinaryOpShape(
   1120       BinaryOperation::BINOP_ADD, mat, vec8, {0});
   1121   ASSERT_FALSE(inferred_status_mismatch.ok());
   1122 
   1123   inferred_status_match = ShapeInference::InferBinaryOpShape(
   1124       BinaryOperation::BINOP_ADD, mat, vec16, {0});
   1125   ASSERT_IS_OK(inferred_status_match.status());
   1126   ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), mat));
   1127 
   1128   inferred_status_mismatch = ShapeInference::InferBinaryOpShape(
   1129       BinaryOperation::BINOP_ADD, mat, vec16, {1});
   1130   ASSERT_FALSE(inferred_status_mismatch.ok());
   1131 }
   1132 
   1133 TEST_F(ShapeInferenceTest, BinOpBroadcastCubeMatrix) {
   1134   // Test variations of broadcasting a matrix for a binary add with a cube.
   1135   const Shape cube = ShapeUtil::MakeShape(F32, {16, 8, 4});
   1136   const Shape matrix8_4 = ShapeUtil::MakeShape(F32, {8, 4});
   1137   const Shape matrix16_4 = ShapeUtil::MakeShape(F32, {16, 4});
   1138   const Shape matrix16_8 = ShapeUtil::MakeShape(F32, {16, 8});
   1139 
   1140   auto inferred_status_match = ShapeInference::InferBinaryOpShape(
   1141       BinaryOperation::BINOP_ADD, cube, matrix8_4, {1, 2});
   1142   ASSERT_IS_OK(inferred_status_match.status());
   1143   ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), cube));
   1144 
   1145   inferred_status_match = ShapeInference::InferBinaryOpShape(
   1146       BinaryOperation::BINOP_ADD, cube, matrix16_4, {0, 2});
   1147   ASSERT_IS_OK(inferred_status_match.status());
   1148   ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), cube));
   1149 
   1150   inferred_status_match = ShapeInference::InferBinaryOpShape(
   1151       BinaryOperation::BINOP_ADD, cube, matrix16_8, {0, 1});
   1152   ASSERT_IS_OK(inferred_status_match.status());
   1153   ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), cube));
   1154 }
   1155 
   1156 TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) {
   1157   // Test various errors with the broadcast argument.
   1158   const Shape tensor = ShapeUtil::MakeShape(F32, {16, 8, 4});
   1159   const Shape tensor8_8_8 = ShapeUtil::MakeShape(F32, {8, 8, 8});
   1160   const Shape vec8 = ShapeUtil::MakeShape(F32, {8});
   1161   const Shape matrix8_4 = ShapeUtil::MakeShape(F32, {8, 4});
   1162   const Shape matrix8_8 = ShapeUtil::MakeShape(F32, {8, 8});
   1163 
   1164   // "magical" broadcast rejected
   1165   auto inferred_status_error1 = ShapeInference::InferBinaryOpShape(
   1166       BinaryOperation::BINOP_ADD, tensor, vec8, {});
   1167   ASSERT_FALSE(inferred_status_error1.ok());
   1168   ASSERT_THAT(inferred_status_error1.status().error_message(),
   1169               HasSubstr("automatic"));
   1170 
   1171   // broadcast_dimension out of bounds for tensor's rank
   1172   auto inferred_status_error2 = ShapeInference::InferBinaryOpShape(
   1173       BinaryOperation::BINOP_ADD, tensor, vec8, {3});
   1174   ASSERT_FALSE(inferred_status_error2.ok());
   1175   ASSERT_THAT(inferred_status_error2.status().error_message(),
   1176               ContainsRegex("broadcast dimension number .* too large"));
   1177 
   1178   // broadcast_dimension doesn't match corresponding dimension
   1179   auto inferred_status_error3 = ShapeInference::InferBinaryOpShape(
   1180       BinaryOperation::BINOP_ADD, tensor, vec8, {0});
   1181   ASSERT_FALSE(inferred_status_error3.ok());
   1182   ASSERT_THAT(inferred_status_error3.status().error_message(),
   1183               HasSubstr("broadcast dimension 0 mismatch"));
   1184 
   1185   // broadcast_dimensions list too long
   1186   auto inferred_status_error4 = ShapeInference::InferBinaryOpShape(
   1187       BinaryOperation::BINOP_ADD, tensor, matrix8_4, {0, 1, 2});
   1188   ASSERT_FALSE(inferred_status_error4.ok());
   1189   ASSERT_THAT(inferred_status_error4.status().error_message(),
   1190               HasSubstr("size of broadcast_dimensions has to match"));
   1191 
   1192   // there's a dimension above the rank of the tensor
   1193   auto inferred_status_error5 = ShapeInference::InferBinaryOpShape(
   1194       BinaryOperation::BINOP_ADD, tensor, matrix8_4, {3, 0});
   1195   ASSERT_FALSE(inferred_status_error5.ok());
   1196   ASSERT_THAT(inferred_status_error5.status().error_message(),
   1197               ContainsRegex("broadcast dimension number .* too large"));
   1198 
   1199   // broadcasting dimensions don't match in this order
   1200   auto inferred_status_error6 = ShapeInference::InferBinaryOpShape(
   1201       BinaryOperation::BINOP_ADD, tensor, matrix8_4, {2, 1});
   1202   ASSERT_FALSE(inferred_status_error6.ok());
   1203   ASSERT_THAT(inferred_status_error6.status().error_message(),
   1204               HasSubstr("broadcast dimension 0 mismatch"));
   1205 
   1206   // The following two tests make sure that broadcasting dimensions are listed
   1207   // in a proper (strictly increasing) order, even if the lower-rank array
   1208   // matches the higher-rank array in many different ways.
   1209   auto inferred_status_error7 = ShapeInference::InferBinaryOpShape(
   1210       BinaryOperation::BINOP_ADD, tensor8_8_8, matrix8_8, {0, 0});
   1211   ASSERT_FALSE(inferred_status_error7.ok());
   1212   ASSERT_THAT(inferred_status_error7.status().error_message(),
   1213               HasSubstr("broadcast dimensions order is wrong"));
   1214 
   1215   auto inferred_status_error8 = ShapeInference::InferBinaryOpShape(
   1216       BinaryOperation::BINOP_ADD, tensor8_8_8, matrix8_8, {1, 0});
   1217   ASSERT_FALSE(inferred_status_error8.ok());
   1218   ASSERT_THAT(inferred_status_error8.status().error_message(),
   1219               HasSubstr("broadcast dimensions order is wrong"));
   1220 }
   1221 
   1222 // Tests for the while instruction with proper shapes.
   1223 TEST_F(ShapeInferenceTest, WhileWithCorrectShapes) {
   1224   Shape result_shape = ShapeUtil::MakeTupleShape({s32_, vector_32_});
   1225   ProgramShape cond = ShapeUtil::MakeProgramShape({result_shape}, pred_);
   1226   ProgramShape body = ShapeUtil::MakeProgramShape({result_shape}, result_shape);
   1227   auto inferred_status =
   1228       ShapeInference::InferWhileShape(cond, body, result_shape);
   1229   ASSERT_IS_OK(inferred_status.status());
   1230   Shape inferred = inferred_status.ValueOrDie();
   1231   ASSERT_TRUE(ShapeUtil::Equal(result_shape, inferred));
   1232 }
   1233 
   1234 // Tests for the while instruction with wrong shapes.
   1235 TEST_F(ShapeInferenceTest, WhileWithBadShapes) {
   1236   Shape result_shape = ShapeUtil::MakeTupleShape({s32_, vector_32_});
   1237   ProgramShape cond = ShapeUtil::MakeProgramShape({result_shape}, pred_);
   1238   ProgramShape body = ShapeUtil::MakeProgramShape({result_shape}, result_shape);
   1239 
   1240   auto bad_shape_1 = ShapeUtil::MakeProgramShape({s32_, result_shape}, pred_);
   1241   auto inferred_status_error1 =
   1242       ShapeInference::InferWhileShape(bad_shape_1, body, result_shape);
   1243   ASSERT_FALSE(inferred_status_error1.ok());
   1244   ASSERT_THAT(inferred_status_error1.status().error_message(),
   1245               HasSubstr("condition must take 1 arguments"));
   1246 
   1247   auto bad_shape_2 =
   1248       ShapeUtil::MakeProgramShape({s32_, result_shape}, result_shape);
   1249   auto inferred_status_error2 =
   1250       ShapeInference::InferWhileShape(cond, bad_shape_2, result_shape);
   1251   ASSERT_FALSE(inferred_status_error2.ok());
   1252   ASSERT_THAT(inferred_status_error2.status().error_message(),
   1253               HasSubstr("body must take 1 arguments"));
   1254 
   1255   auto bad_shape_3 = ShapeUtil::MakeProgramShape({result_shape}, s32_);
   1256   auto inferred_status_error3 =
   1257       ShapeInference::InferWhileShape(bad_shape_3, body, result_shape);
   1258   ASSERT_FALSE(inferred_status_error3.ok());
   1259   ASSERT_THAT(inferred_status_error3.status().error_message(),
   1260               HasSubstr("condition must return a boolean"));
   1261 
   1262   auto bad_shape_4 = ShapeUtil::MakeProgramShape({result_shape}, vector_32_);
   1263   auto inferred_status_error4 =
   1264       ShapeInference::InferWhileShape(cond, bad_shape_4, result_shape);
   1265   ASSERT_FALSE(inferred_status_error4.ok());
   1266   ASSERT_THAT(inferred_status_error4.status().error_message(),
   1267               HasSubstr("parameter of condition and body"));
   1268 }
   1269 
   1270 // Tests for the concatenate instruction with proper shapes.
   1271 TEST_F(ShapeInferenceTest, ConcatenateWithCorrectShapes) {
   1272   auto inferred_status_1 = ShapeInference::InferConcatOpShape(
   1273       {&vector_32_, &vector_64_}, /*dimension=*/0);
   1274   ASSERT_IS_OK(inferred_status_1.status());
   1275   Shape inferred_1 = inferred_status_1.ValueOrDie();
   1276   ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {96}), inferred_1));
   1277 
   1278   auto inferred_status_2 = ShapeInference::InferConcatOpShape(
   1279       {&vector_32_, &vector_64_, &vector_32_}, /*dimension=*/0);
   1280   ASSERT_IS_OK(inferred_status_2.status());
   1281   Shape inferred_2 = inferred_status_2.ValueOrDie();
   1282   ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {128}), inferred_2));
   1283 
   1284   auto inferred_status_3 = ShapeInference::InferConcatOpShape(
   1285       {&matrix_32_48_, &matrix_32_64_, &matrix_32_48_}, /*dimension=*/1);
   1286   ASSERT_IS_OK(inferred_status_3.status());
   1287   Shape inferred_3 = inferred_status_3.ValueOrDie();
   1288   ASSERT_TRUE(
   1289       ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {32, 160}), inferred_3));
   1290 }
   1291 
   1292 // Tests for the concatenate instruction with wrong shapes.
   1293 TEST_F(ShapeInferenceTest, ConcatenateWithBadShapes) {
   1294   auto inferred_status_error1 =
   1295       ShapeInference::InferConcatOpShape({}, /*dimension=*/0);
   1296   ASSERT_FALSE(inferred_status_error1.ok());
   1297   ASSERT_THAT(inferred_status_error1.status().error_message(),
   1298               HasSubstr("Concatenate expects at least one argument"));
   1299 
   1300   auto inferred_status_error2 =
   1301       ShapeInference::InferConcatOpShape({&vector_32_}, /*dimension=*/-1);
   1302   ASSERT_FALSE(inferred_status_error2.ok());
   1303   ASSERT_THAT(inferred_status_error2.status().error_message(),
   1304               HasSubstr("dimension to concatenate along out of bounds: -1"));
   1305 
   1306   auto inferred_status_error3 =
   1307       ShapeInference::InferConcatOpShape({&vector_32_}, /*dimension=*/1);
   1308   ASSERT_FALSE(inferred_status_error3.ok());
   1309   ASSERT_THAT(inferred_status_error3.status().error_message(),
   1310               HasSubstr("dimension to concatenate along out of bounds: 1"));
   1311 
   1312   Shape tuple = ShapeUtil::MakeTupleShape({vector_32_});
   1313   auto inferred_status_error4 = ShapeInference::InferConcatOpShape(
   1314       {&vector_32_, &tuple}, /*dimension=*/0);
   1315   ASSERT_FALSE(inferred_status_error4.ok());
   1316   ASSERT_THAT(
   1317       inferred_status_error4.status().error_message(),
   1318       HasSubstr("Expected non-tuple argument for operand of concatenation."));
   1319 
   1320   const Shape vector_s32 = ShapeUtil::MakeShape(S32, {32});
   1321   auto inferred_status_error5 = ShapeInference::InferConcatOpShape(
   1322       {&vector_32_, &vector_s32}, /*dimension=*/0);
   1323   ASSERT_FALSE(inferred_status_error5.ok());
   1324   ASSERT_THAT(
   1325       inferred_status_error5.status().error_message(),
   1326       HasSubstr("cannot concatenate arrays with different element types"));
   1327 
   1328   auto inferred_status_error6 = ShapeInference::InferConcatOpShape(
   1329       {&matrix_32_48_, &matrix_32_64_}, /*dimension=*/0);
   1330   ASSERT_FALSE(inferred_status_error6.ok());
   1331   ASSERT_THAT(inferred_status_error6.status().error_message(),
   1332               HasSubstr("cannot concatenate arrays that differ in "
   1333                         "dimensions other than the one being "
   1334                         "concatenated"));
   1335 }
   1336 
   1337 TEST_F(ShapeInferenceTest, Pad) {
   1338   Shape input_shape = ShapeUtil::MakeShape(F32, {10, 25});
   1339   Shape padding_value_shape = ShapeUtil::MakeShape(F32, {});
   1340   // Padding for dimension 0: {low: 0, high: 2, interior: 3}
   1341   // Padding for dimension 1: {low: 1, high: 5, interior: 0}
   1342   PaddingConfig padding_config;
   1343   auto dimension0 = padding_config.add_dimensions();
   1344   dimension0->set_edge_padding_low(0);
   1345   dimension0->set_edge_padding_high(2);
   1346   dimension0->set_interior_padding(3);
   1347   auto dimension1 = padding_config.add_dimensions();
   1348   dimension1->set_edge_padding_low(1);
   1349   dimension1->set_edge_padding_high(5);
   1350   dimension1->set_interior_padding(0);
   1351 
   1352   auto inferred_status = ShapeInference::InferPadShape(
   1353       input_shape, padding_value_shape, padding_config);
   1354   ASSERT_IS_OK(inferred_status.status());
   1355   Shape inferred_shape = inferred_status.ValueOrDie();
   1356   ASSERT_TRUE(
   1357       ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {39, 31}), inferred_shape));
   1358 }
   1359 
   1360 TEST_F(ShapeInferenceTest, Reverse) {
   1361   Shape input_shape = ShapeUtil::MakeShape(F32, {10, 25});
   1362 
   1363   auto inferred_status = ShapeInference::InferReverseShape(input_shape, {0, 1});
   1364   ASSERT_IS_OK(inferred_status.status());
   1365   Shape inferred_shape = inferred_status.ValueOrDie();
   1366   ASSERT_TRUE(ShapeUtil::Equal(input_shape, inferred_shape));
   1367 }
   1368 
   1369 TEST_F(ShapeInferenceTest, ReverseInvalidDimension) {
   1370   Shape input_shape = ShapeUtil::MakeShape(F32, {10, 25});
   1371 
   1372   auto inferred_status_error0 =
   1373       ShapeInference::InferReverseShape(input_shape, {0, 2});
   1374   ASSERT_FALSE(inferred_status_error0.ok());
   1375   ASSERT_THAT(inferred_status_error0.status().error_message(),
   1376               HasSubstr("out-of-bounds"));
   1377 
   1378   auto inferred_status_error1 =
   1379       ShapeInference::InferReverseShape(input_shape, {0, -1});
   1380   ASSERT_FALSE(inferred_status_error1.ok());
   1381   ASSERT_THAT(inferred_status_error1.status().error_message(),
   1382               HasSubstr("out-of-bounds"));
   1383 
   1384   auto inferred_status_error2 =
   1385       ShapeInference::InferReverseShape(input_shape, {0, 0});
   1386   ASSERT_FALSE(inferred_status_error2.ok());
   1387   ASSERT_THAT(inferred_status_error2.status().error_message(),
   1388               HasSubstr("duplicated"));
   1389 
   1390   Shape tuple_shape = ShapeUtil::MakeTupleShape({input_shape, input_shape});
   1391   auto inferred_status_error3 =
   1392       ShapeInference::InferReverseShape(tuple_shape, {0});
   1393   ASSERT_FALSE(inferred_status_error3.ok());
   1394   ASSERT_THAT(inferred_status_error3.status().error_message(),
   1395               HasSubstr("Expected non-tuple argument"));
   1396 }
   1397 
   1398 TEST_F(ShapeInferenceTest, Call) {
   1399   auto inferred_status0 =
   1400       ShapeInference::InferCallShape({}, ShapeUtil::MakeProgramShape({}, f32_));
   1401   EXPECT_IS_OK(inferred_status0.status());
   1402   EXPECT_TRUE(ShapeUtil::Equal(f32_, inferred_status0.ValueOrDie()));
   1403 
   1404   auto inferred_status1 = ShapeInference::InferCallShape(
   1405       {&f32_, &s32_, &pred_, &vector_32_, &matrix_32_48_},
   1406       ShapeUtil::MakeProgramShape(
   1407           {f32_, s32_, pred_, vector_32_, matrix_32_48_}, s32matrix_64_64_));
   1408   EXPECT_IS_OK(inferred_status1.status());
   1409   EXPECT_TRUE(
   1410       ShapeUtil::Equal(s32matrix_64_64_, inferred_status1.ValueOrDie()));
   1411 
   1412   auto inferred_status_error0 = ShapeInference::InferCallShape(
   1413       {}, ShapeUtil::MakeProgramShape({f32_}, f32_));
   1414   EXPECT_FALSE(inferred_status_error0.ok());
   1415   EXPECT_THAT(inferred_status_error0.status().error_message(),
   1416               HasSubstr("arity must match"));
   1417 
   1418   auto inferred_status_error1 = ShapeInference::InferCallShape(
   1419       {&f32_}, ShapeUtil::MakeProgramShape({}, f32_));
   1420   EXPECT_FALSE(inferred_status_error1.ok());
   1421   EXPECT_THAT(inferred_status_error1.status().error_message(),
   1422               HasSubstr("arity must match"));
   1423 
   1424   auto inferred_status_error2 = ShapeInference::InferCallShape(
   1425       {&f32_}, ShapeUtil::MakeProgramShape({s32_}, f32_));
   1426   EXPECT_FALSE(inferred_status_error2.ok());
   1427   EXPECT_THAT(inferred_status_error2.status().error_message(),
   1428               HasSubstr("parameter must match argument"));
   1429 }
   1430 
   1431 TEST_F(ShapeInferenceTest, Transpose) {
   1432   Shape a_shape = ShapeUtil::MakeShape(F32, {2, 3, 4, 5});
   1433   auto inferred_shape_and_status =
   1434       ShapeInference::InferTransposeShape(a_shape, {1, 2, 3, 0});
   1435   EXPECT_IS_OK(inferred_shape_and_status);
   1436   Shape inferred_shape = inferred_shape_and_status.ValueOrDie();
   1437   EXPECT_TRUE(ShapeUtil::Compatible(inferred_shape,
   1438                                     ShapeUtil::MakeShape(F32, {3, 4, 5, 2})));
   1439 }
   1440 
   1441 TEST_F(ShapeInferenceTest, Conditional) {
   1442   auto inferred_status0 = ShapeInference::InferConditionalShape(
   1443       pred_, vector_32_, vector_64_,
   1444       ShapeUtil::MakeProgramShape({vector_32_}, f32_),
   1445       ShapeUtil::MakeProgramShape({vector_64_}, f32_));
   1446   EXPECT_IS_OK(inferred_status0.status());
   1447   EXPECT_TRUE(ShapeUtil::Equal(f32_, inferred_status0.ValueOrDie()));
   1448 
   1449   auto inferred_status1 = ShapeInference::InferConditionalShape(
   1450       pred_, matrix_32_48_, vector_32_,
   1451       ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_),
   1452       ShapeUtil::MakeProgramShape({vector_32_}, vector_64_));
   1453   EXPECT_IS_OK(inferred_status1.status());
   1454   EXPECT_TRUE(ShapeUtil::Equal(vector_64_, inferred_status1.ValueOrDie()));
   1455 
   1456   auto tuple_f32_v32 = ShapeUtil::MakeTupleShape({f32_, vector_32_});
   1457   auto inferred_status2 = ShapeInference::InferConditionalShape(
   1458       pred_, matrix_32_48_, tuple_f32_v32,
   1459       ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_),
   1460       ShapeUtil::MakeProgramShape({tuple_f32_v32}, vector_32_));
   1461   EXPECT_IS_OK(inferred_status2.status());
   1462   EXPECT_TRUE(ShapeUtil::Equal(vector_32_, inferred_status2.ValueOrDie()));
   1463 
   1464   auto inferred_status_error0 = ShapeInference::InferConditionalShape(
   1465       s32_, vector_32_, vector_64_,
   1466       ShapeUtil::MakeProgramShape({vector_32_}, f32_),
   1467       ShapeUtil::MakeProgramShape({vector_64_}, f32_));
   1468   EXPECT_FALSE(inferred_status_error0.ok());
   1469   EXPECT_THAT(inferred_status_error0.status().error_message(),
   1470               HasSubstr("predicate must be a boolean"));
   1471 
   1472   auto inferred_status_error1 = ShapeInference::InferConditionalShape(
   1473       pred_, ShapeUtil::MakeTupleShape({f32_, vector_32_}), matrix_32_48_,
   1474       ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_),
   1475       ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_));
   1476   EXPECT_FALSE(inferred_status_error1.ok());
   1477   EXPECT_THAT(inferred_status_error1.status().error_message(),
   1478               HasSubstr("true_computation must take 1 argument"));
   1479 
   1480   auto inferred_status_error2 = ShapeInference::InferConditionalShape(
   1481       pred_, vector_32_, vector_64_,
   1482       ShapeUtil::MakeProgramShape({vector_64_}, f32_),
   1483       ShapeUtil::MakeProgramShape({vector_64_}, f32_));
   1484   EXPECT_FALSE(inferred_status_error2.ok());
   1485   EXPECT_THAT(inferred_status_error2.status().error_message(),
   1486               HasSubstr("true_operand must match the shape of the only "
   1487                         "parameter of true_computation"));
   1488 
   1489   auto inferred_status_error3 = ShapeInference::InferConditionalShape(
   1490       pred_, matrix_32_48_, ShapeUtil::MakeTupleShape({f32_, vector_32_}),
   1491       ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_),
   1492       ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_));
   1493   EXPECT_FALSE(inferred_status_error3.ok());
   1494   EXPECT_THAT(inferred_status_error3.status().error_message(),
   1495               HasSubstr("false_computation must take 1 argument"));
   1496 
   1497   auto inferred_status_error4 = ShapeInference::InferConditionalShape(
   1498       pred_, vector_32_, vector_64_,
   1499       ShapeUtil::MakeProgramShape({vector_32_}, f32_),
   1500       ShapeUtil::MakeProgramShape({vector_32_}, f32_));
   1501   EXPECT_FALSE(inferred_status_error4.ok());
   1502   EXPECT_THAT(inferred_status_error4.status().error_message(),
   1503               HasSubstr("false_operand must match the shape of the only "
   1504                         "parameter of false_computation"));
   1505 
   1506   auto inferred_status_error5 = ShapeInference::InferConditionalShape(
   1507       pred_, vector_32_, vector_64_,
   1508       ShapeUtil::MakeProgramShape({vector_32_}, f32_),
   1509       ShapeUtil::MakeProgramShape({vector_64_}, vector_32_));
   1510   EXPECT_FALSE(inferred_status_error5.ok());
   1511   EXPECT_THAT(inferred_status_error5.status().error_message(),
   1512               HasSubstr("the result of true_computation and false_computation "
   1513                         "must have the same shape"));
   1514 }
   1515 
   1516 TEST_F(ShapeInferenceTest, BadSlice) {
   1517   auto arg = ShapeUtil::MakeShape(F32, {4});
   1518   StatusOr<Shape> statusor =
   1519       ShapeInference::InferSliceShape(arg, {0}, {5}, {1});
   1520   ASSERT_FALSE(statusor.ok());
   1521 
   1522   LOG(INFO) << statusor.status();
   1523 
   1524   EXPECT_THAT(statusor.status().error_message(),
   1525               HasSubstr("less than or equal to dimension size"))
   1526       << statusor.status();
   1527   EXPECT_THAT(statusor.status().error_message(), HasSubstr("argument shape"))
   1528       << statusor.status();
   1529 }
   1530 
// Fixture providing operand, gather-indices, and tuple shapes that are shared
// by the gather shape-inference tests below.
class GatherShapeInferenceTest : public ShapeInferenceTest {
 protected:
  const Shape s64_vector_32_ = ShapeUtil::MakeShape(S64, {32});
  const Shape s64_4d_tensor_10_9_8_7_1_ =
      ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 1});
  const Shape s64_4d_tensor_10_9_8_7_5_ =
      ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 5});
  const Shape f32_5d_tensor_50_49_48_47_46_ =
      ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46});
  // Used by tests that check tuple-shaped inputs/indices are rejected.
  // NOTE: initialized from the member above, so declaration order matters.
  const Shape tuple_shape_ = ShapeUtil::MakeTupleShape(
      {s64_4d_tensor_10_9_8_7_1_, s64_4d_tensor_10_9_8_7_1_});
};
   1543 
   1544 TEST_F(GatherShapeInferenceTest, TensorFlowGather) {
   1545   TF_ASSERT_OK_AND_ASSIGN(
   1546       Shape gather_shape,
   1547       ShapeInference::InferGatherShape(matrix_64_48_, s64_vector_32_,
   1548                                        HloInstruction::MakeGatherDimNumbers(
   1549                                            /*output_window_dims=*/{0},
   1550                                            /*elided_window_dims=*/{1},
   1551                                            /*gather_dims_to_operand_dims=*/{1}),
   1552                                        /*window_bounds=*/{64, 1}));
   1553   EXPECT_TRUE(
   1554       ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {64, 32})))
   1555       << ShapeUtil::HumanString(gather_shape);
   1556 }
   1557 
   1558 TEST_F(GatherShapeInferenceTest, TensorFlowGatherV2) {
   1559   TF_ASSERT_OK_AND_ASSIGN(
   1560       Shape gather_shape,
   1561       ShapeInference::InferGatherShape(matrix_64_48_, s64_vector_32_,
   1562                                        HloInstruction::MakeGatherDimNumbers(
   1563                                            /*output_window_dims=*/{1},
   1564                                            /*elided_window_dims=*/{0},
   1565                                            /*gather_dims_to_operand_dims=*/{0}),
   1566                                        /*window_bounds=*/{1, 48}));
   1567   EXPECT_TRUE(
   1568       ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {32, 48})))
   1569       << ShapeUtil::HumanString(gather_shape);
   1570 }
   1571 
   1572 TEST_F(GatherShapeInferenceTest, TensorFlowGatherNd) {
   1573   TF_ASSERT_OK_AND_ASSIGN(
   1574       Shape gather_shape,
   1575       ShapeInference::InferGatherShape(matrix_64_48_, s64_4d_tensor_10_9_8_7_1_,
   1576                                        HloInstruction::MakeGatherDimNumbers(
   1577                                            /*output_window_dims=*/{4},
   1578                                            /*elided_window_dims=*/{0},
   1579                                            /*gather_dims_to_operand_dims=*/{0}),
   1580                                        /*window_bounds=*/{1, 48}));
   1581   EXPECT_TRUE(ShapeUtil::Equal(gather_shape,
   1582                                ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 48})))
   1583       << ShapeUtil::HumanString(gather_shape);
   1584 }
   1585 
   1586 TEST_F(GatherShapeInferenceTest, TensorFlowBatchDynamicSlice) {
   1587   TF_ASSERT_OK_AND_ASSIGN(
   1588       Shape gather_shape,
   1589       ShapeInference::InferGatherShape(
   1590           f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
   1591           HloInstruction::MakeGatherDimNumbers(
   1592               /*output_window_dims=*/{4, 5, 6, 7, 8},
   1593               /*elided_window_dims=*/{},
   1594               /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
   1595           /*window_bounds=*/{30, 29, 28, 27, 26}));
   1596   EXPECT_TRUE(ShapeUtil::Equal(
   1597       gather_shape,
   1598       ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26})))
   1599       << ShapeUtil::HumanString(gather_shape);
   1600 }
   1601 
   1602 TEST_F(GatherShapeInferenceTest, TupleShapedTensorInput) {
   1603   StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
   1604       tuple_shape_, s64_vector_32_,
   1605       HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0},
   1606                                            /*elided_window_dims=*/{1},
   1607                                            /*gather_dims_to_operand_dims=*/{1}),
   1608       /*window_bounds=*/{64, 1});
   1609   ASSERT_FALSE(statusor.ok());
   1610   EXPECT_THAT(statusor.status().error_message(),
   1611               HasSubstr("Expected non-tuple argument for input"))
   1612       << statusor.status();
   1613 }
   1614 
   1615 TEST_F(GatherShapeInferenceTest, TupleShapedGatherIndicesInput) {
   1616   StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
   1617       s64_vector_32_, tuple_shape_,
   1618       HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0},
   1619                                            /*elided_window_dims=*/{1},
   1620                                            /*gather_dims_to_operand_dims=*/{1}),
   1621       /*window_bounds=*/{64, 1});
   1622   ASSERT_FALSE(statusor.ok());
   1623   EXPECT_THAT(statusor.status().error_message(),
   1624               HasSubstr("Expected non-tuple argument for gather indices"))
   1625       << statusor.status();
   1626 }
   1627 
   1628 TEST_F(GatherShapeInferenceTest, ScalarGatherIndicesInput) {
   1629   StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
   1630       s64_vector_32_, s32_,
   1631       HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0},
   1632                                            /*elided_window_dims=*/{1},
   1633                                            /*gather_dims_to_operand_dims=*/{1}),
   1634       /*window_bounds=*/{64, 1});
   1635   ASSERT_FALSE(statusor.ok());
   1636   EXPECT_THAT(statusor.status().error_message(),
   1637               HasSubstr("Gather indices parameter must at least of rank 1"))
   1638       << statusor.status();
   1639 }
   1640 
   1641 TEST_F(GatherShapeInferenceTest, FloatingPointGatherIndicesInput) {
   1642   StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
   1643       s64_vector_32_, vector_32_,
   1644       HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0},
   1645                                            /*elided_window_dims=*/{1},
   1646                                            /*gather_dims_to_operand_dims=*/{1}),
   1647       /*window_bounds=*/{64, 1});
   1648   ASSERT_FALSE(statusor.ok());
   1649   EXPECT_THAT(statusor.status().error_message(),
   1650               HasSubstr("Gather indices parameter must be an integral tensor"))
   1651       << statusor.status();
   1652 }
   1653 
   1654 TEST_F(GatherShapeInferenceTest,
   1655        InvalidGatherDimNumbers_NonAscendingWindowIndices) {
   1656   StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
   1657       f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
   1658       HloInstruction::MakeGatherDimNumbers(
   1659           /*output_window_dims=*/{4, 5, 6, 8, 7},
   1660           /*elided_window_dims=*/{},
   1661           /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
   1662       /*window_bounds=*/{30, 29, 28, 27, 26});
   1663   ASSERT_FALSE(statusor.ok());
   1664   EXPECT_THAT(
   1665       statusor.status().error_message(),
   1666       HasSubstr("Output window dimensions in gather op must be ascending"))
   1667       << statusor.status();
   1668 }
   1669 
   1670 TEST_F(GatherShapeInferenceTest,
   1671        InvalidGatherDimNumbers_RepeatedWindowIndices) {
   1672   StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
   1673       f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
   1674       HloInstruction::MakeGatherDimNumbers(
   1675           /*output_window_dims=*/{4, 5, 6, 7, 7},
   1676           /*elided_window_dims=*/{},
   1677           /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
   1678       /*window_bounds=*/{30, 29, 28, 27, 26});
   1679   ASSERT_FALSE(statusor.ok());
   1680   EXPECT_THAT(
   1681       statusor.status().error_message(),
   1682       HasSubstr("Output window dimensions in gather op must not repeat"))
   1683       << statusor.status();
   1684 }
   1685 
   1686 TEST_F(GatherShapeInferenceTest,
   1687        InvalidGatherDimNumbers_WindowIndexOutOfBounds) {
   1688   StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
   1689       f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
   1690       HloInstruction::MakeGatherDimNumbers(
   1691           /*output_window_dims=*/{4, 5, 99, 100, 101},
   1692           /*elided_window_dims=*/{},
   1693           /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
   1694       /*window_bounds=*/{30, 29, 28, 27, 26});
   1695   ASSERT_FALSE(statusor.ok());
   1696   EXPECT_THAT(statusor.status().error_message(),
   1697               HasSubstr("Window index 2 in gather op is out of bounds"))
   1698       << statusor.status();
   1699 }
   1700 
   1701 TEST_F(GatherShapeInferenceTest,
   1702        InvalidGatherDimNumbers_MismatchingElidedWindowDims) {
   1703   StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
   1704       f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
   1705       HloInstruction::MakeGatherDimNumbers(
   1706           /*output_window_dims=*/{4, 5, 6, 7, 8},
   1707           /*elided_window_dims=*/{4},
   1708           /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
   1709       /*window_bounds=*/{30, 29, 28, 27, 26});
   1710   ASSERT_FALSE(statusor.ok());
   1711   EXPECT_THAT(
   1712       statusor.status().error_message(),
   1713       HasSubstr("All components of the window index in a gather op must either "
   1714                 "be a output window index or explicitly elided"))
   1715       << statusor.status();
   1716 }
   1717 
   1718 TEST_F(GatherShapeInferenceTest,
   1719        InvalidGatherDimNumbers_OutOfBoundsWindowToInputMapping) {
   1720   StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
   1721       f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
   1722       HloInstruction::MakeGatherDimNumbers(
   1723           /*output_window_dims=*/{4, 5, 6, 7, 8},
   1724           /*elided_window_dims=*/{0, 1, 2, 3, 19},
   1725           /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
   1726       /*window_bounds=*/{30, 29, 28, 27, 26});
   1727   ASSERT_FALSE(statusor.ok());
   1728   EXPECT_THAT(statusor.status().error_message(),
   1729               HasSubstr("Invalid elided_window_dims set in gather op; valid "
   1730                         "range is [0, 5), got: 19"))
   1731       << statusor.status();
   1732 }
   1733 
   1734 TEST_F(GatherShapeInferenceTest,
   1735        InvalidGatherDimNumbers_RepeatedWindowToInputMapping) {
   1736   StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
   1737       f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
   1738       HloInstruction::MakeGatherDimNumbers(
   1739           /*output_window_dims=*/{4, 5, 6, 7, 8},
   1740           /*elided_window_dims=*/{0, 1, 2, 3, 3},
   1741           /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
   1742       /*window_bounds=*/{30, 29, 28, 27, 26});
   1743   ASSERT_FALSE(statusor.ok());
   1744   EXPECT_THAT(
   1745       statusor.status().error_message(),
   1746       HasSubstr(
   1747           "Repeated dimensions not allowed in elided_window_dims in gather op"))
   1748       << statusor.status();
   1749 }
   1750 
   1751 TEST_F(GatherShapeInferenceTest,
   1752        InvalidGatherDimNumbers_MismatchingGatherToInputMapping) {
   1753   StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
   1754       f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
   1755       HloInstruction::MakeGatherDimNumbers(
   1756           /*output_window_dims=*/{4, 5, 6, 7, 8},
   1757           /*elided_window_dims=*/{},
   1758           /*gather_dims_to_operand_dims=*/{0, 1, 2, 3}),
   1759       /*window_bounds=*/{30, 29, 28, 27, 26});
   1760   ASSERT_FALSE(statusor.ok());
   1761   EXPECT_THAT(
   1762       statusor.status().error_message(),
   1763       HasSubstr(
   1764           "There must be exactly as many elements in "
   1765           "gather_dims_to_operand_dims "
   1766           "as there are elements in the last dimension of %gather_indices"))
   1767       << statusor.status();
   1768 }
   1769 
   1770 TEST_F(GatherShapeInferenceTest,
   1771        InvalidGatherDimNumbers_OutOfBoundsGatherToInputMapping) {
   1772   StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
   1773       f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
   1774       HloInstruction::MakeGatherDimNumbers(
   1775           /*output_window_dims=*/{4, 5, 6, 7, 8},
   1776           /*elided_window_dims=*/{},
   1777           /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 7}),
   1778       /*window_bounds=*/{30, 29, 28, 27, 26});
   1779   ASSERT_FALSE(statusor.ok());
   1780   EXPECT_THAT(
   1781       statusor.status().error_message(),
   1782       HasSubstr("Invalid gather_dims_to_operand_dims mapping; domain is "
   1783                 "[0, 5), got: 4->7"))
   1784       << statusor.status();
   1785 }
   1786 
   1787 TEST_F(GatherShapeInferenceTest,
   1788        InvalidGatherDimNumbers_RepeatedGatherToInputMapping) {
   1789   StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
   1790       f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
   1791       HloInstruction::MakeGatherDimNumbers(
   1792           /*output_window_dims=*/{4, 5, 6, 7, 8},
   1793           /*elided_window_dims=*/{},
   1794           /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 3}),
   1795       /*window_bounds=*/{30, 29, 28, 27, 26});
   1796   ASSERT_FALSE(statusor.ok());
   1797   EXPECT_THAT(
   1798       statusor.status().error_message(),
   1799       HasSubstr(
   1800           "Repeated dimensions are not allowed in gather_dims_to_operand_dims"))
   1801       << statusor.status();
   1802 }
   1803 
   1804 TEST_F(GatherShapeInferenceTest,
   1805        InvalidGatherDimNumbers_NonAscendingElidedWindowDims) {
   1806   StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
   1807       f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
   1808       HloInstruction::MakeGatherDimNumbers(
   1809           /*output_window_dims=*/{4, 5, 6, 7, 8},
   1810           /*elided_window_dims=*/{2, 1},
   1811           /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
   1812       /*window_bounds=*/{1, 1, 28, 27, 26});
   1813   ASSERT_FALSE(statusor.ok());
   1814   EXPECT_THAT(statusor.status().error_message(),
   1815               HasSubstr("elided_window_dims in gather op must be sorted"))
   1816       << statusor.status();
   1817 }
   1818 
   1819 TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowBoundsTooLarge) {
   1820   StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
   1821       f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
   1822       HloInstruction::MakeGatherDimNumbers(
   1823           /*output_window_dims=*/{4, 5, 6, 7},
   1824           /*elided_window_dims=*/{2},
   1825           /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
   1826       /*window_bounds=*/{30, 29, 1, 300, 26});
   1827   ASSERT_FALSE(statusor.ok());
   1828   EXPECT_THAT(statusor.status().error_message(),
   1829               HasSubstr("Window bound at index 3 in gather op is out of range, "
   1830                         "must be within [0, 48), got 300"))
   1831       << statusor.status();
   1832 }
   1833 
   1834 TEST_F(GatherShapeInferenceTest,
   1835        InvalidGatherDimNumbers_MismatchingNumberOfWindowBounds) {
   1836   StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
   1837       f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
   1838       HloInstruction::MakeGatherDimNumbers(
   1839           /*output_window_dims=*/{4, 5, 6, 7, 8},
   1840           /*elided_window_dims=*/{},
   1841           /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
   1842       /*window_bounds=*/{30, 29, 28, 26});
   1843   ASSERT_FALSE(statusor.ok());
   1844   EXPECT_THAT(
   1845       statusor.status().error_message(),
   1846       HasSubstr(
   1847           "Gather op must have one window bound for every input dimension"))
   1848       << statusor.status();
   1849 }
   1850 
   1851 TEST_F(GatherShapeInferenceTest,
   1852        InvalidGatherDimNumbers_WindowBoundsNot1ForElidedDim) {
   1853   StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
   1854       f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
   1855       HloInstruction::MakeGatherDimNumbers(
   1856           /*output_window_dims=*/{4, 5, 6, 7},
   1857           /*elided_window_dims=*/{1},
   1858           /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
   1859       /*window_bounds=*/{30, 29, 28, 26, 20});
   1860   ASSERT_FALSE(statusor.ok());
   1861   EXPECT_THAT(statusor.status().error_message(),
   1862               HasSubstr("Gather op can only elide window indices with bound 1, "
   1863                         "but bound is 29 for index 1 at position 0"))
   1864       << statusor.status();
   1865 }
   1866 
   1867 }  // namespace
   1868 }  // namespace xla
   1869