1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/shape_inference.h" 17 18 #include <string> 19 20 #include "tensorflow/compiler/xla/shape_util.h" 21 #include "tensorflow/compiler/xla/test.h" 22 #include "tensorflow/compiler/xla/test_helpers.h" 23 #include "tensorflow/compiler/xla/types.h" 24 #include "tensorflow/compiler/xla/xla_data.pb.h" 25 #include "tensorflow/core/lib/gtl/array_slice.h" 26 27 namespace xla { 28 namespace { 29 30 using ::tensorflow::gtl::ArraySlice; 31 using ::testing::ContainsRegex; 32 using ::testing::HasSubstr; 33 34 class ShapeInferenceTest : public ::testing::Test { 35 protected: 36 // Some handy scalar shapes. 37 const Shape s32_ = ShapeUtil::MakeShape(S32, {}); 38 const Shape f32_ = ShapeUtil::MakeShape(F32, {}); 39 const Shape f64_ = ShapeUtil::MakeShape(F64, {}); 40 const Shape pred_ = ShapeUtil::MakeShape(PRED, {}); 41 42 // Some handy vector and matrix shapes of F32 type. 
43 // Suffix: vector_length_, matrix_rows_cols_ 44 const Shape vector_32_ = ShapeUtil::MakeShape(F32, {32}); 45 const Shape vector_64_ = ShapeUtil::MakeShape(F32, {64}); 46 const Shape matrix_32_48_ = ShapeUtil::MakeShape(F32, {32, 48}); 47 const Shape matrix_32_64_ = ShapeUtil::MakeShape(F32, {32, 64}); 48 const Shape matrix_64_48_ = ShapeUtil::MakeShape(F32, {64, 48}); 49 50 // Some handy S32 arrays. 51 const Shape s32matrix_64_64_ = ShapeUtil::MakeShape(S32, {64, 64}); 52 }; 53 54 // Subclass for testing InferReduceShape. 55 class ReduceShapeInferenceTest : public ShapeInferenceTest { 56 protected: 57 // Helper that runs reduce shape inference with the input 'arg' and given 58 // dimensions to reduce, and checks the inferred shape is as expected. The 59 // element type here is hard-coded to F32. 60 void ExpectInferredReduceShape( 61 const Shape& expected_inferred_shape, const Shape& arg, 62 tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) { 63 ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_); 64 auto inferred_status = ShapeInference::InferReduceShape( 65 arg, f32_, dimensions_to_reduce, to_apply); 66 EXPECT_IS_OK(inferred_status.status()); 67 EXPECT_TRUE(ShapeUtil::Equal(expected_inferred_shape, 68 inferred_status.ValueOrDie())); 69 } 70 }; 71 72 // Subclass for testing InferSelectAndScatterShape. 
73 class SelectAndScatterShapeInferenceTest : public ShapeInferenceTest { 74 protected: 75 SelectAndScatterShapeInferenceTest() { 76 operand_shape_ = ShapeUtil::MakeShape(F32, {8, 16}); 77 source_shape_ = ShapeUtil::MakeShape(F32, {4, 8}); 78 WindowDimension dim; 79 dim.set_size(2); 80 dim.set_stride(2); 81 dim.set_padding_low(0); 82 dim.set_padding_high(0); 83 dim.set_window_dilation(1); 84 dim.set_base_dilation(1); 85 *window_.add_dimensions() = dim; 86 *window_.add_dimensions() = dim; 87 init_value_shape_ = ShapeUtil::MakeShape(F32, {}); 88 select_program_shape_ = ShapeUtil::MakeProgramShape( 89 {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, pred_); 90 scatter_program_shape_ = ShapeUtil::MakeProgramShape( 91 {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, f32_); 92 } 93 94 Shape operand_shape_; 95 Shape source_shape_; 96 Window window_; 97 Shape init_value_shape_; 98 ProgramShape select_program_shape_; 99 ProgramShape scatter_program_shape_; 100 }; 101 102 TEST_F(ShapeInferenceTest, UnaryNegateMatrix) { 103 Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64}); 104 auto inferred_status = ShapeInference::InferUnaryOpShape( 105 UnaryOperation::UNOP_NEGATE, matrix_shape); 106 ASSERT_IS_OK(inferred_status.status()); 107 ASSERT_TRUE(ShapeUtil::Equal(matrix_shape, inferred_status.ValueOrDie())); 108 } 109 110 TEST_F(ShapeInferenceTest, SelectScalarPredBetweenTuples) { 111 Shape tuple = ShapeUtil::MakeTupleShape({s32_, f32_}); 112 auto inferred_status = ShapeInference::InferTernaryOpShape( 113 TernaryOperation::TRIOP_SELECT, pred_, tuple, tuple); 114 ASSERT_IS_OK(inferred_status.status()); 115 ASSERT_TRUE(ShapeUtil::Equal(tuple, inferred_status.ValueOrDie())); 116 } 117 118 TEST_F(ShapeInferenceTest, SelectScalarPredBetweenArrays) { 119 auto inferred_status = ShapeInference::InferTernaryOpShape( 120 TernaryOperation::TRIOP_SELECT, pred_, matrix_64_48_, matrix_64_48_); 121 ASSERT_IS_OK(inferred_status.status()); 122 
ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); 123 } 124 125 TEST_F(ShapeInferenceTest, SelectArrayPredBetweenArrays) { 126 auto predarray = ShapeUtil::MakeShape(PRED, {64, 48}); 127 auto inferred_status = ShapeInference::InferTernaryOpShape( 128 TernaryOperation::TRIOP_SELECT, predarray, matrix_64_48_, matrix_64_48_); 129 ASSERT_IS_OK(inferred_status.status()); 130 ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); 131 } 132 133 TEST_F(ShapeInferenceTest, SelectBadShapes) { 134 auto inferred_status_error1 = ShapeInference::InferTernaryOpShape( 135 TernaryOperation::TRIOP_SELECT, pred_, matrix_64_48_, matrix_32_64_); 136 ASSERT_FALSE(inferred_status_error1.ok()); 137 ASSERT_THAT(inferred_status_error1.status().error_message(), 138 HasSubstr("operands to select must be the same shape")); 139 140 auto inferred_status_error2 = ShapeInference::InferTernaryOpShape( 141 TernaryOperation::TRIOP_SELECT, s32_, matrix_64_48_, matrix_64_48_); 142 ASSERT_FALSE(inferred_status_error2.ok()); 143 ASSERT_THAT(inferred_status_error2.status().error_message(), 144 HasSubstr("pred operand must have PRED")); 145 146 auto inferred_status_error3 = ShapeInference::InferTernaryOpShape( 147 TernaryOperation::TRIOP_SELECT, ShapeUtil::MakeShape(PRED, {64}), 148 matrix_64_48_, matrix_64_48_); 149 ASSERT_FALSE(inferred_status_error3.ok()); 150 ASSERT_THAT(inferred_status_error3.status().error_message(), 151 HasSubstr("with non-scalar predicate with dimensionality")); 152 153 // Tuples have a TUPLE element type and cannot be the pred of a select. 
154 auto inferred_status_error4 = ShapeInference::InferTernaryOpShape( 155 TernaryOperation::TRIOP_SELECT, ShapeUtil::MakeTupleShape({pred_, pred_}), 156 ShapeUtil::MakeTupleShape({f32_, f32_}), 157 ShapeUtil::MakeTupleShape({f32_, f32_})); 158 ASSERT_FALSE(inferred_status_error4.ok()); 159 ASSERT_THAT(inferred_status_error4.status().error_message(), 160 HasSubstr("pred operand must have PRED element type")); 161 } 162 163 TEST_F(ShapeInferenceTest, ClampAllMatrix) { 164 auto inferred_status = ShapeInference::InferTernaryOpShape( 165 TernaryOperation::TRIOP_CLAMP, matrix_64_48_, matrix_64_48_, 166 matrix_64_48_); 167 ASSERT_IS_OK(inferred_status.status()); 168 ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); 169 } 170 171 TEST_F(ShapeInferenceTest, ClampAllScalar) { 172 auto inferred_status = ShapeInference::InferTernaryOpShape( 173 TernaryOperation::TRIOP_CLAMP, f32_, f32_, f32_); 174 ASSERT_IS_OK(inferred_status.status()); 175 ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred_status.ValueOrDie())); 176 } 177 178 TEST_F(ShapeInferenceTest, ClampMinScalar) { 179 auto inferred_status = ShapeInference::InferTernaryOpShape( 180 TernaryOperation::TRIOP_CLAMP, f32_, matrix_64_48_, matrix_64_48_); 181 ASSERT_IS_OK(inferred_status.status()); 182 ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); 183 } 184 185 TEST_F(ShapeInferenceTest, ClampMaxScalar) { 186 auto inferred_status = ShapeInference::InferTernaryOpShape( 187 TernaryOperation::TRIOP_CLAMP, matrix_64_48_, matrix_64_48_, f32_); 188 ASSERT_IS_OK(inferred_status.status()); 189 ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); 190 } 191 192 TEST_F(ShapeInferenceTest, ClampOperandScalar) { 193 auto inferred_status = ShapeInference::InferTernaryOpShape( 194 TernaryOperation::TRIOP_CLAMP, matrix_64_48_, f32_, matrix_64_48_); 195 ASSERT_IS_OK(inferred_status.status()); 196 ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); 
197 } 198 199 TEST_F(ShapeInferenceTest, ClampMinMatrix) { 200 auto inferred_status = ShapeInference::InferTernaryOpShape( 201 TernaryOperation::TRIOP_CLAMP, matrix_64_48_, f32_, f32_); 202 ASSERT_IS_OK(inferred_status.status()); 203 ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); 204 } 205 206 TEST_F(ShapeInferenceTest, ClampMaxMatrix) { 207 auto inferred_status = ShapeInference::InferTernaryOpShape( 208 TernaryOperation::TRIOP_CLAMP, f32_, f32_, matrix_64_48_); 209 ASSERT_IS_OK(inferred_status.status()); 210 ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); 211 } 212 213 TEST_F(ShapeInferenceTest, ClampOperandMatrix) { 214 auto inferred_status = ShapeInference::InferTernaryOpShape( 215 TernaryOperation::TRIOP_CLAMP, f32_, matrix_64_48_, f32_); 216 ASSERT_IS_OK(inferred_status.status()); 217 ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); 218 } 219 220 TEST_F(ShapeInferenceTest, ClampBadShapes) { 221 // Type mismatch 222 ASSERT_FALSE(ShapeInference::InferTernaryOpShape( 223 TernaryOperation::TRIOP_CLAMP, s32_, f32_, f32_) 224 .ok()); 225 ASSERT_FALSE(ShapeInference::InferTernaryOpShape( 226 TernaryOperation::TRIOP_CLAMP, f32_, s32_, f32_) 227 .ok()); 228 ASSERT_FALSE(ShapeInference::InferTernaryOpShape( 229 TernaryOperation::TRIOP_CLAMP, f32_, f32_, s32_) 230 .ok()); 231 // Dimension mismatch 232 ASSERT_FALSE( 233 ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP, 234 vector_64_, vector_32_, vector_32_) 235 .ok()); 236 ASSERT_FALSE( 237 ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP, 238 vector_32_, vector_64_, vector_32_) 239 .ok()); 240 ASSERT_FALSE( 241 ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP, 242 vector_32_, vector_32_, vector_64_) 243 .ok()); 244 // Dimension mismatch, where one operand is a scalar 245 ASSERT_FALSE(ShapeInference::InferTernaryOpShape( 246 TernaryOperation::TRIOP_CLAMP, vector_64_, vector_32_, f32_) 247 
.ok()); 248 ASSERT_FALSE(ShapeInference::InferTernaryOpShape( 249 TernaryOperation::TRIOP_CLAMP, vector_64_, f32_, vector_32_) 250 .ok()); 251 ASSERT_FALSE(ShapeInference::InferTernaryOpShape( 252 TernaryOperation::TRIOP_CLAMP, f32_, vector_64_, vector_32_) 253 .ok()); 254 } 255 256 TEST_F(ShapeInferenceTest, Complex) { 257 auto complex_shape = [&](const Shape& lhs, const Shape& rhs, 258 const tensorflow::gtl::ArraySlice<int64>& bcast) { 259 return ShapeInference::InferBinaryOpShape(BinaryOperation::BINOP_COMPLEX, 260 lhs, rhs, bcast); 261 }; 262 // Inputs must be FP. 263 ASSERT_FALSE(complex_shape(s32_, s32_, {}).ok()); 264 ASSERT_FALSE(complex_shape(pred_, pred_, {}).ok()); 265 // Component types must match. 266 ASSERT_FALSE(complex_shape(f32_, f64_, {}).ok()); 267 // Only F32->C64 supported. 268 ASSERT_FALSE(complex_shape(f64_, f64_, {}).ok()); 269 // Validate correct uses. 270 Shape c64_32 = ShapeUtil::MakeShape(C64, {32}); 271 TF_ASSERT_OK_AND_ASSIGN(Shape result, complex_shape(f32_, f32_, {})); 272 ASSERT_TRUE(ShapeUtil::Equal(result, ShapeUtil::MakeShape(C64, {}))); 273 TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(vector_32_, f32_, {})); 274 ASSERT_TRUE(ShapeUtil::Equal(result, c64_32)); 275 TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(f32_, vector_32_, {})); 276 ASSERT_TRUE(ShapeUtil::Equal(result, c64_32)); 277 TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(vector_32_, f32_, {})); 278 ASSERT_TRUE(ShapeUtil::Equal(result, c64_32)); 279 280 Shape c64_32_64 = ShapeUtil::MakeShape(C64, {32, 64}); 281 TF_ASSERT_OK_AND_ASSIGN(result, 282 complex_shape(vector_64_, matrix_32_64_, {1})); 283 ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64)); 284 TF_ASSERT_OK_AND_ASSIGN(result, 285 complex_shape(matrix_32_64_, vector_64_, {1})); 286 ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64)); 287 TF_ASSERT_OK_AND_ASSIGN(result, 288 complex_shape(matrix_32_64_, matrix_32_64_, {})); 289 ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64)); 290 TF_ASSERT_OK_AND_ASSIGN(result, 
complex_shape(matrix_32_64_, f32_, {})); 291 ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64)); 292 } 293 294 TEST_F(ShapeInferenceTest, VariadicOpTuplify) { 295 StatusOr<Shape> result = ShapeInference::InferVariadicOpShape( 296 VariadicOperation::VAROP_TUPLE, {&s32_, &f32_}); 297 ASSERT_IS_OK(result.status()); 298 ASSERT_TRUE(ShapeUtil::Equal(result.ValueOrDie(), 299 ShapeUtil::MakeTupleShape({s32_, f32_}))); 300 } 301 302 TEST_F(ShapeInferenceTest, ReduceWindowInHalf) { 303 Shape matrix_shape = ShapeUtil::MakeShape(F32, {8, 8}); 304 Window window; 305 WindowDimension dim; 306 dim.set_size(2); 307 dim.set_stride(2); 308 dim.set_padding_low(0); 309 dim.set_padding_high(0); 310 dim.set_window_dilation(1); 311 dim.set_base_dilation(1); 312 *window.add_dimensions() = dim; 313 *window.add_dimensions() = dim; 314 Shape window_shape = ShapeUtil::MakeShape(F32, {2, 2}); 315 Shape init_value_shape = ShapeUtil::MakeShape(F32, {}); 316 Shape float_scalar = ShapeUtil::MakeShape(F32, {}); 317 ProgramShape to_apply = ShapeUtil::MakeProgramShape( 318 {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, f32_); 319 auto inferred_status = ShapeInference::InferReduceWindowShape( 320 matrix_shape, init_value_shape, window, to_apply); 321 322 ASSERT_IS_OK(inferred_status.status()); 323 Shape inferred = inferred_status.ValueOrDie(); 324 ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {4, 4}), inferred)); 325 } 326 327 TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterProperShapes) { 328 auto inferred_status_ok = ShapeInference::InferSelectAndScatterShape( 329 operand_shape_, select_program_shape_, window_, source_shape_, 330 init_value_shape_, scatter_program_shape_); 331 ASSERT_IS_OK(inferred_status_ok.status()); 332 Shape inferred = inferred_status_ok.ValueOrDie(); 333 ASSERT_TRUE(ShapeUtil::Equal(operand_shape_, inferred)); 334 } 335 336 TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSourceShape) { 337 Shape source_shape_fail = 
ShapeUtil::MakeShape(F32, {4, 6}); 338 auto inferred_status_fail = ShapeInference::InferSelectAndScatterShape( 339 operand_shape_, select_program_shape_, window_, source_shape_fail, 340 init_value_shape_, scatter_program_shape_); 341 ASSERT_FALSE(inferred_status_fail.ok()); 342 ASSERT_THAT(inferred_status_fail.status().error_message(), 343 HasSubstr("source shape does not match")); 344 } 345 346 TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape1) { 347 ProgramShape select_program_shape_fail = 348 ShapeUtil::MakeProgramShape({ShapeUtil::MakeShape(F32, {})}, pred_); 349 auto inferred_status_fail = ShapeInference::InferSelectAndScatterShape( 350 operand_shape_, select_program_shape_fail, window_, source_shape_, 351 init_value_shape_, scatter_program_shape_); 352 ASSERT_FALSE(inferred_status_fail.ok()); 353 ASSERT_THAT(inferred_status_fail.status().error_message(), 354 HasSubstr("select function must take 2 parameters")); 355 } 356 357 TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape2) { 358 ProgramShape select_program_shape_fail = ShapeUtil::MakeProgramShape( 359 {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, f32_); 360 auto inferred_status_fail = ShapeInference::InferSelectAndScatterShape( 361 operand_shape_, select_program_shape_fail, window_, source_shape_, 362 init_value_shape_, scatter_program_shape_); 363 ASSERT_FALSE(inferred_status_fail.ok()); 364 ASSERT_THAT(inferred_status_fail.status().error_message(), 365 HasSubstr("select function must have rank-0 PRED")); 366 } 367 368 TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape3) { 369 ProgramShape select_program_shape_fail = ShapeUtil::MakeProgramShape( 370 {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {})}, pred_); 371 auto inferred_status_fail = ShapeInference::InferSelectAndScatterShape( 372 operand_shape_, select_program_shape_fail, window_, source_shape_, 373 init_value_shape_, scatter_program_shape_); 
374 ASSERT_FALSE(inferred_status_fail.ok()); 375 ASSERT_THAT(inferred_status_fail.status().error_message(), 376 HasSubstr("select function's first parameter")); 377 } 378 379 TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape4) { 380 ProgramShape select_program_shape_fail = ShapeUtil::MakeProgramShape( 381 {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(U32, {})}, pred_); 382 auto inferred_status_fail = ShapeInference::InferSelectAndScatterShape( 383 operand_shape_, select_program_shape_fail, window_, source_shape_, 384 init_value_shape_, scatter_program_shape_); 385 ASSERT_FALSE(inferred_status_fail.ok()); 386 ASSERT_THAT(inferred_status_fail.status().error_message(), 387 HasSubstr("select function's second parameter")); 388 } 389 390 TEST_F(ShapeInferenceTest, Convolve) { 391 ConvolutionDimensionNumbers dnums; 392 393 // Dimension order: batch, feature, x0, x1 394 Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 3, 4}); 395 dnums.set_input_batch_dimension(0); 396 dnums.set_output_batch_dimension(0); 397 dnums.set_input_feature_dimension(1); 398 dnums.set_output_feature_dimension(1); 399 dnums.add_input_spatial_dimensions(2); 400 dnums.add_output_spatial_dimensions(2); 401 dnums.add_input_spatial_dimensions(3); 402 dnums.add_output_spatial_dimensions(3); 403 404 // Dimension order: x1, batch, feature, x0 405 Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 3}); 406 dnums.set_kernel_input_feature_dimension(2); 407 dnums.set_kernel_output_feature_dimension(1); 408 dnums.add_kernel_spatial_dimensions(3); 409 dnums.add_kernel_spatial_dimensions(0); 410 411 Window window; 412 auto dim0 = window.add_dimensions(); 413 auto dim1 = window.add_dimensions(); 414 dim0->set_size(3); 415 dim0->set_stride(2); 416 dim0->set_padding_low(1); 417 dim0->set_padding_high(1); 418 dim0->set_window_dilation(1); 419 dim0->set_base_dilation(1); 420 dim1->set_size(2); 421 dim1->set_stride(1); 422 dim1->set_padding_low(0); 423 
dim1->set_padding_high(0); 424 dim1->set_window_dilation(1); 425 dim1->set_base_dilation(1); 426 auto inferred_status = 427 ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums); 428 ASSERT_IS_OK(inferred_status.status()); 429 Shape inferred_shape = inferred_status.ValueOrDie(); 430 ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 2, 3}), 431 inferred_shape)); 432 } 433 434 TEST_F(ShapeInferenceTest, ConvolveWithWindowDilation) { 435 ConvolutionDimensionNumbers dnums; 436 437 // Dimension order: batch, feature, x0, x1 438 Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 103, 4}); 439 dnums.set_input_batch_dimension(0); 440 dnums.set_output_batch_dimension(0); 441 dnums.set_input_feature_dimension(1); 442 dnums.set_output_feature_dimension(1); 443 dnums.add_input_spatial_dimensions(2); 444 dnums.add_output_spatial_dimensions(2); 445 dnums.add_input_spatial_dimensions(3); 446 dnums.add_output_spatial_dimensions(3); 447 448 // Dimension order: x1, batch, feature, x0 449 Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 3}); 450 dnums.set_kernel_input_feature_dimension(2); 451 dnums.set_kernel_output_feature_dimension(1); 452 dnums.add_kernel_spatial_dimensions(3); 453 dnums.add_kernel_spatial_dimensions(0); 454 455 Window window; 456 auto dim0 = window.add_dimensions(); 457 dim0->set_size(3); 458 dim0->set_stride(3); 459 dim0->set_padding_low(0); 460 dim0->set_padding_high(0); 461 dim0->set_window_dilation(6); 462 dim0->set_base_dilation(1); 463 464 auto dim1 = window.add_dimensions(); 465 dim1->set_size(2); 466 dim1->set_stride(1); 467 dim1->set_padding_low(2); 468 dim1->set_padding_high(1); 469 dim1->set_window_dilation(2); 470 dim1->set_base_dilation(1); 471 auto inferred_status = 472 ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums); 473 ASSERT_IS_OK(inferred_status.status()); 474 Shape inferred_shape = inferred_status.ValueOrDie(); 475 ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 
31, 5}), 476 inferred_shape)); 477 } 478 479 TEST_F(ShapeInferenceTest, ConvolveWithBaseDilation) { 480 ConvolutionDimensionNumbers dnums; 481 482 // Dimension order: batch, feature, x0, x1 483 Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 3, 4}); 484 dnums.set_input_batch_dimension(0); 485 dnums.set_output_batch_dimension(0); 486 dnums.set_input_feature_dimension(1); 487 dnums.set_output_feature_dimension(1); 488 dnums.add_input_spatial_dimensions(2); 489 dnums.add_output_spatial_dimensions(2); 490 dnums.add_input_spatial_dimensions(3); 491 dnums.add_output_spatial_dimensions(3); 492 493 // Dimension order: x1, batch, feature, x0 494 Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 4}); 495 dnums.set_kernel_input_feature_dimension(2); 496 dnums.set_kernel_output_feature_dimension(1); 497 dnums.add_kernel_spatial_dimensions(3); 498 dnums.add_kernel_spatial_dimensions(0); 499 500 Window window; 501 auto dim0 = window.add_dimensions(); 502 dim0->set_size(4); 503 dim0->set_stride(3); 504 dim0->set_padding_low(0); 505 dim0->set_padding_high(0); 506 dim0->set_window_dilation(1); 507 dim0->set_base_dilation(6); 508 509 auto dim1 = window.add_dimensions(); 510 dim1->set_size(2); 511 dim1->set_stride(1); 512 dim1->set_padding_low(2); 513 dim1->set_padding_high(1); 514 dim1->set_window_dilation(1); 515 dim1->set_base_dilation(2); 516 auto inferred_status = 517 ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums); 518 ASSERT_IS_OK(inferred_status.status()); 519 Shape inferred_shape = inferred_status.ValueOrDie(); 520 ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 4, 9}), 521 inferred_shape)); 522 } 523 524 TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) { 525 // Dimension order for this test: batch, feature, x0, x1 526 Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 3, 4}); 527 Shape rhs_shape = ShapeUtil::MakeShape(F32, {12, 11, 3, 2}); 528 529 ConvolutionDimensionNumbers dnums; 530 
dnums.set_input_batch_dimension(3); 531 dnums.set_output_batch_dimension(3); 532 dnums.set_input_feature_dimension(2); 533 dnums.set_output_feature_dimension(2); 534 dnums.add_input_spatial_dimensions(0); 535 dnums.add_output_spatial_dimensions(0); 536 dnums.add_input_spatial_dimensions(1); 537 dnums.add_output_spatial_dimensions(1); 538 dnums.set_kernel_input_feature_dimension(0); // duplicated with kernel_x0 539 dnums.set_kernel_output_feature_dimension(3); 540 dnums.add_kernel_spatial_dimensions(0); 541 dnums.add_kernel_spatial_dimensions(1); 542 543 Window window; 544 auto dim0 = window.add_dimensions(); 545 auto dim1 = window.add_dimensions(); 546 dim0->set_size(2); 547 dim0->set_stride(1); 548 dim0->set_padding_low(0); 549 dim0->set_padding_high(0); 550 dim1->set_size(3); 551 dim1->set_stride(2); 552 dim1->set_padding_low(1); 553 dim1->set_padding_high(1); 554 auto inferred_status = 555 ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums); 556 ASSERT_FALSE(inferred_status.ok()); 557 ASSERT_THAT(inferred_status.status().error_message(), 558 HasSubstr("each dimension exactly once")); 559 } 560 561 TEST_F(ShapeInferenceTest, MapThatChangesElementType) { 562 Shape arg = ShapeUtil::MakeShape(F32, {20}); 563 ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_}, s32_); 564 auto inferred_status = ShapeInference::InferMapShape({&arg}, to_apply, {0}); 565 EXPECT_IS_OK(inferred_status.status()); 566 Shape expected = ShapeUtil::MakeShape(S32, {20}); 567 EXPECT_TRUE(ShapeUtil::Equal(expected, inferred_status.ValueOrDie())); 568 } 569 570 TEST_F(ShapeInferenceTest, Map) { 571 auto inferred_status_r1f32 = ShapeInference::InferMapShape( 572 {&vector_32_, &vector_32_}, 573 ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0}); 574 EXPECT_IS_OK(inferred_status_r1f32.status()); 575 EXPECT_TRUE(ShapeUtil::Equal(vector_32_, inferred_status_r1f32.ValueOrDie())); 576 577 // It's OK to provide a single argument, as long as the applied arity matches 578 // 
(this degenerates to a Map). 579 auto inferred_status_r1f32_one = ShapeInference::InferMapShape( 580 {&vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_), {0}); 581 EXPECT_IS_OK(inferred_status_r1f32_one.status()); 582 EXPECT_TRUE( 583 ShapeUtil::Equal(vector_32_, inferred_status_r1f32_one.ValueOrDie())); 584 585 auto inferred_status_r2s32 = ShapeInference::InferMapShape( 586 {&s32matrix_64_64_, &s32matrix_64_64_, &s32matrix_64_64_}, 587 ShapeUtil::MakeProgramShape({s32_, s32_, s32_}, s32_), {0, 1}); 588 EXPECT_IS_OK(inferred_status_r2s32.status()); 589 EXPECT_TRUE( 590 ShapeUtil::Equal(s32matrix_64_64_, inferred_status_r2s32.ValueOrDie())); 591 592 auto no_args_error = ShapeInference::InferMapShape( 593 {}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {}); 594 ASSERT_FALSE(no_args_error.ok()); 595 ASSERT_THAT(no_args_error.status().error_message(), 596 HasSubstr("expects at least one argument")); 597 598 auto args_diff_shapes_error = ShapeInference::InferMapShape( 599 {&vector_32_, &vector_64_}, 600 ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0}); 601 ASSERT_FALSE(args_diff_shapes_error.ok()); 602 ASSERT_THAT(args_diff_shapes_error.status().error_message(), 603 HasSubstr("requires all operands to have the same shape")); 604 605 auto arity_error = ShapeInference::InferMapShape( 606 {&vector_32_, &vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_), 607 {0}); 608 ASSERT_FALSE(arity_error.ok()); 609 ASSERT_THAT(arity_error.status().error_message(), 610 HasSubstr("function arity must match")); 611 612 auto output_shape_error = ShapeInference::InferMapShape( 613 {&vector_32_, &vector_32_}, 614 ShapeUtil::MakeProgramShape({f32_, f32_}, vector_32_), {0}); 615 ASSERT_FALSE(output_shape_error.ok()); 616 ASSERT_THAT(output_shape_error.status().error_message(), 617 HasSubstr("result has to be a scalar")); 618 619 auto param_shape_error = ShapeInference::InferMapShape( 620 {&vector_32_, &vector_32_}, 621 ShapeUtil::MakeProgramShape({vector_32_, f32_}, f32_), 
{0}); 622 ASSERT_FALSE(param_shape_error.ok()); 623 ASSERT_THAT(param_shape_error.status().error_message(), 624 HasSubstr("parameter has to be a scalar")); 625 626 auto param_element_type_error = ShapeInference::InferMapShape( 627 {&vector_32_, &vector_32_}, 628 ShapeUtil::MakeProgramShape({f32_, s32_}, f32_), {0}); 629 ASSERT_FALSE(param_element_type_error.ok()); 630 ASSERT_THAT(param_element_type_error.status().error_message(), 631 HasSubstr("parameter type has to match argument")); 632 633 Shape arg = ShapeUtil::MakeShape(F32, {20}); 634 ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_}, f32_); 635 auto inferred_status = ShapeInference::InferMapShape({&arg}, to_apply, {0}); 636 EXPECT_IS_OK(inferred_status.status()); 637 EXPECT_TRUE(ShapeUtil::Equal(arg, inferred_status.ValueOrDie())); 638 639 auto inferred_status_error1 = ShapeInference::InferMapShape( 640 {&arg}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0}); 641 ASSERT_FALSE(inferred_status_error1.ok()); 642 ASSERT_THAT(inferred_status_error1.status().error_message(), 643 HasSubstr("arity must match number of arguments")); 644 645 auto inferred_status_error2 = ShapeInference::InferMapShape( 646 {&arg}, ShapeUtil::MakeProgramShape({vector_32_}, f32_), {0}); 647 ASSERT_FALSE(inferred_status_error2.ok()); 648 ASSERT_THAT(inferred_status_error2.status().error_message(), 649 HasSubstr("has to be a scalar")); 650 651 auto inferred_status_error3 = ShapeInference::InferMapShape( 652 {&arg}, ShapeUtil::MakeProgramShape({f32_}, vector_32_), {0}); 653 ASSERT_FALSE(inferred_status_error3.ok()); 654 ASSERT_THAT(inferred_status_error3.status().error_message(), 655 HasSubstr("has to be a scalar")); 656 657 auto inferred_status_error5 = ShapeInference::InferMapShape( 658 {&arg}, ShapeUtil::MakeProgramShape({s32_}, s32_), {0}); 659 ASSERT_FALSE(inferred_status_error5.ok()); 660 ASSERT_THAT(inferred_status_error5.status().error_message(), 661 HasSubstr("parameter type has to match argument")); 662 } 663 664 
TEST_F(ReduceShapeInferenceTest, ReduceVectorToScalar) {
  ExpectInferredReduceShape(f32_, ShapeUtil::MakeShape(F32, {128}),
                            /*dimensions_to_reduce=*/{0});
}

TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongFirstDimension) {
  ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {3, 4}),
                            ShapeUtil::MakeShape(F32, {2, 3, 4}),
                            /*dimensions_to_reduce=*/{0});
}

TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongMiddleDimension) {
  ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {2, 4}),
                            ShapeUtil::MakeShape(F32, {2, 3, 4}),
                            /*dimensions_to_reduce=*/{1});
}

TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongFirstTwoDimensions) {
  ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {4}),
                            ShapeUtil::MakeShape(F32, {2, 3, 4}),
                            /*dimensions_to_reduce=*/{0, 1});
}

TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongLastTwoDimensions) {
  ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {2}),
                            ShapeUtil::MakeShape(F32, {2, 3, 4}),
                            /*dimensions_to_reduce=*/{1, 2});
}

TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongFirstAndLastDimensions) {
  ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {3}),
                            ShapeUtil::MakeShape(F32, {2, 3, 4}),
                            /*dimensions_to_reduce=*/{0, 2});

  // Check that the order of dimensions_to_reduce doesn't matter.
  ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {3}),
                            ShapeUtil::MakeShape(F32, {2, 3, 4}),
                            /*dimensions_to_reduce=*/{2, 0});
}

TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongAllDimensions) {
  ExpectInferredReduceShape(f32_, ShapeUtil::MakeShape(F32, {2, 3, 4}),
                            /*dimensions_to_reduce=*/{0, 1, 2});
}

TEST_F(ReduceShapeInferenceTest, ErrorOutOfBoundsDimension) {
  // A rank-2 argument has no dimensions 3 or 4.
  ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_);
  auto inferred_status = ShapeInference::InferReduceShape(
      ShapeUtil::MakeShape(F32, {5, 3}), f32_, /*dimensions_to_reduce=*/{3, 4},
      to_apply);
  EXPECT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("out-of-bounds dimension"));
}

TEST_F(ReduceShapeInferenceTest, ErrorToApplyArity) {
  // The reduction computation must be binary.
  ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_, f32_}, f32_);
  auto inferred_status =
      ShapeInference::InferReduceShape(ShapeUtil::MakeShape(F32, {5, 3}), f32_,
                                       /*dimensions_to_reduce=*/{0}, to_apply);
  EXPECT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("take 2 parameters"));
}

TEST_F(ReduceShapeInferenceTest, ErrorElementTypeVsApplyType) {
  // The computation returns S32 while its parameters (and the init value) are
  // F32.
  ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, s32_);
  auto inferred_status =
      ShapeInference::InferReduceShape(ShapeUtil::MakeShape(F32, {5, 3}), f32_,
                                       /*dimensions_to_reduce=*/{0}, to_apply);
  EXPECT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("first parameter shape differs"));
}

TEST_F(ShapeInferenceTest, InferSliceShapeRank2) {
  Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
  auto inferred_status =
      ShapeInference::InferSliceShape(matrix_shape, {32, 0}, {64, 64}, {1, 1});
  ASSERT_IS_OK(inferred_status.status());
  Shape inferred = inferred_status.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {32, 64}), inferred));
}

TEST_F(ShapeInferenceTest, InferSliceShapeRank2WithStrides) {
  Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
  auto inferred_status =
      ShapeInference::InferSliceShape(matrix_shape, {32, 0}, {64, 64}, {2, 4});
  ASSERT_IS_OK(inferred_status.status());
  Shape inferred = inferred_status.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {16, 16}), inferred));
}

TEST_F(ShapeInferenceTest, InferSliceShapeRank2WithStridesNotIntegral) {
  // Partial trailing strides still produce an element:
  // ceil((20 - 15) / 2) = 3 and ceil((13 - 0) / 4) = 4.
  Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
  auto inferred_status =
      ShapeInference::InferSliceShape(matrix_shape, {15, 0}, {20, 13}, {2, 4});
  ASSERT_IS_OK(inferred_status.status());
  Shape inferred = inferred_status.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {3, 4}), inferred));
}

TEST_F(ShapeInferenceTest, InferInvalidStride) {
  // Start/limit indices are kept in bounds (limit 128 == dimension size) so
  // that the zero stride is the only invalid input. A limit of 129 would make
  // this an out-of-bounds test like InferOobSliceShapeRank2, and the test
  // would pass even without stride validation.
  Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
  auto inferred_status =
      ShapeInference::InferSliceShape(matrix_shape, {127, 0}, {128, 2}, {0, 1});
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_EQ(tensorflow::error::INVALID_ARGUMENT,
            inferred_status.status().code());
}

TEST_F(ShapeInferenceTest, InferOobSliceShapeRank2) {
  // Limit index 129 exceeds the 128-element dimension.
  Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
  auto inferred_status =
      ShapeInference::InferSliceShape(matrix_shape, {127, 0}, {129, 2}, {1, 1});
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_EQ(tensorflow::error::INVALID_ARGUMENT,
            inferred_status.status().code());
}

TEST_F(ShapeInferenceTest, InferSliceShapeRank1) {
  Shape vector_shape = ShapeUtil::MakeShape(F32, {17});
  auto inferred_status =
      ShapeInference::InferSliceShape(vector_shape, {2}, {4}, {1});
  // ASSERT_IS_OK (rather than ASSERT_TRUE(ok())) reports the failing status.
  ASSERT_IS_OK(inferred_status.status());
  Shape inferred = inferred_status.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(inferred, ShapeUtil::MakeShape(F32, {2})));
}

TEST_F(ShapeInferenceTest, InferConstIndexShape) {
  Shape tuple_shape = ShapeUtil::MakeTupleShape({f32_, s32_});
  auto inferred0_status =
      ShapeInference::InferGetTupleElementShape(tuple_shape, 0);
  auto inferred1_status =
      ShapeInference::InferGetTupleElementShape(tuple_shape, 1);
  ASSERT_IS_OK(inferred0_status.status());
  ASSERT_IS_OK(inferred1_status.status());
  ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred0_status.ValueOrDie()));
  ASSERT_TRUE(ShapeUtil::Equal(s32_, inferred1_status.ValueOrDie()));
}

TEST_F(ShapeInferenceTest, InferPowShape) {
  auto ten_floats = ShapeUtil::MakeShape(F32, {10});
  auto inferred_status =
      ShapeInference::InferBinaryOpShape(BINOP_POW, ten_floats, f32_, {});
  ASSERT_IS_OK(inferred_status.status());
  ASSERT_TRUE(ShapeUtil::Equal(ten_floats, inferred_status.ValueOrDie()));
}

// Comparisons keep the operand dimensions but produce a PRED element type.

TEST_F(ShapeInferenceTest, InferCompareShapeEq) {
  auto ten_floats = ShapeUtil::MakeShape(F32, {10});
  auto inferred_status =
      ShapeInference::InferBinaryOpShape(BINOP_EQ, ten_floats, f32_, {});
  ASSERT_IS_OK(inferred_status.status());
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
                               inferred_status.ValueOrDie()));
}

TEST_F(ShapeInferenceTest, InferCompareShapeGe) {
  auto ten_floats = ShapeUtil::MakeShape(F32, {10});
  auto inferred_status =
      ShapeInference::InferBinaryOpShape(BINOP_GE, ten_floats, f32_, {});
  ASSERT_IS_OK(inferred_status.status());
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
                               inferred_status.ValueOrDie()));
}

TEST_F(ShapeInferenceTest, InferCompareShapeGt) {
  auto ten_floats = ShapeUtil::MakeShape(F32, {10});
  auto inferred_status =
      ShapeInference::InferBinaryOpShape(BINOP_GT, ten_floats, f32_, {});
  ASSERT_IS_OK(inferred_status.status());
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
                               inferred_status.ValueOrDie()));
}

TEST_F(ShapeInferenceTest, InferCompareShapeLe) {
  auto ten_floats = ShapeUtil::MakeShape(F32, {10});
  auto inferred_status =
      ShapeInference::InferBinaryOpShape(BINOP_LE, ten_floats, f32_, {});
  ASSERT_IS_OK(inferred_status.status());
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
                               inferred_status.ValueOrDie()));
}

TEST_F(ShapeInferenceTest, InferCompareShapeLt) {
  auto ten_floats = ShapeUtil::MakeShape(F32, {10});
  auto inferred_status =
      ShapeInference::InferBinaryOpShape(BINOP_LT, ten_floats, f32_, {});
  ASSERT_IS_OK(inferred_status.status());
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
                               inferred_status.ValueOrDie()));
}

TEST_F(ShapeInferenceTest, InferCompareShapeNe) {
  auto ten_floats = ShapeUtil::MakeShape(F32, {10});
  auto inferred_status =
      ShapeInference::InferBinaryOpShape(BINOP_NE, ten_floats, f32_, {});
  ASSERT_IS_OK(inferred_status.status());
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
                               inferred_status.ValueOrDie()));
}

TEST_F(ShapeInferenceTest, BroadcastScalar) {
  for (auto element_type : {F32, U32, S8}) {
    const Shape scalar_shape = ShapeUtil::MakeShape(element_type, {});
    {  // no-op scalar broadcast
      auto status = ShapeInference::InferBroadcastShape(scalar_shape, {});
      ASSERT_IS_OK(status.status());
      ASSERT_TRUE(ShapeUtil::Equal(scalar_shape, status.ValueOrDie()));
    }
    const Shape oned_shape = ShapeUtil::MakeShape(element_type, {3});
    {  // scalar -> 1d broadcast
      auto status = ShapeInference::InferBroadcastShape(scalar_shape, {3});
      ASSERT_IS_OK(status.status());
      ASSERT_TRUE(ShapeUtil::Equal(oned_shape, status.ValueOrDie()));
    }
    {  // no-op 1d broadcast
      auto status = ShapeInference::InferBroadcastShape(oned_shape, {});
ASSERT_IS_OK(status.status()); 884 ASSERT_TRUE(ShapeUtil::Equal(oned_shape, status.ValueOrDie())); 885 } 886 const Shape twod_shape = ShapeUtil::MakeShape(element_type, {2, 3}); 887 { // scalar -> 2d broadcast 888 auto status = ShapeInference::InferBroadcastShape(scalar_shape, {2, 3}); 889 ASSERT_IS_OK(status.status()); 890 ASSERT_TRUE(ShapeUtil::Equal(twod_shape, status.ValueOrDie())); 891 } 892 { // 1d -> 2d broadcast 893 auto status = ShapeInference::InferBroadcastShape(oned_shape, {2}); 894 ASSERT_IS_OK(status.status()); 895 ASSERT_TRUE(ShapeUtil::Equal(twod_shape, status.ValueOrDie())); 896 } 897 } 898 } 899 900 // scalar <dot> vector: error 901 TEST_F(ShapeInferenceTest, ScalarDotVector) { 902 DotDimensionNumbers dot_dnums; 903 dot_dnums.add_lhs_contracting_dimensions(1); 904 dot_dnums.add_rhs_contracting_dimensions(0); 905 auto inferred_status = 906 ShapeInference::InferDotOpShape(f32_, vector_32_, dot_dnums); 907 ASSERT_FALSE(inferred_status.ok()); 908 ASSERT_THAT(inferred_status.status().error_message(), 909 HasSubstr("dot only supports rank")); 910 } 911 912 // 3D <dot> 2D: error 913 TEST_F(ShapeInferenceTest, DotWithRankHigherThanTwo) { 914 DotDimensionNumbers dot_dnums; 915 dot_dnums.add_lhs_contracting_dimensions(1); 916 dot_dnums.add_rhs_contracting_dimensions(0); 917 auto inferred_status = ShapeInference::InferDotOpShape( 918 ShapeUtil::MakeShape(F32, {32, 32, 32}), matrix_32_64_, dot_dnums); 919 ASSERT_FALSE(inferred_status.ok()); 920 ASSERT_THAT(inferred_status.status().error_message(), 921 HasSubstr("batch and contracting dimension number mismatch")); 922 } 923 924 // vector <dot> vector -> scalar 925 TEST_F(ShapeInferenceTest, VectorDotVector) { 926 DotDimensionNumbers dot_dnums; 927 dot_dnums.add_lhs_contracting_dimensions(0); 928 dot_dnums.add_rhs_contracting_dimensions(0); 929 auto inferred_status = 930 ShapeInference::InferDotOpShape(vector_64_, vector_64_, dot_dnums); 931 ASSERT_IS_OK(inferred_status.status()); 932 
ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred_status.ValueOrDie())); 933 auto inferred_status_mismatch = 934 ShapeInference::InferDotOpShape(vector_64_, vector_32_, dot_dnums); 935 ASSERT_FALSE(inferred_status_mismatch.ok()); 936 } 937 938 // matrix <dot> vector -> vector 939 TEST_F(ShapeInferenceTest, MatrixDotVector) { 940 DotDimensionNumbers dot_dnums; 941 dot_dnums.add_lhs_contracting_dimensions(1); 942 dot_dnums.add_rhs_contracting_dimensions(0); 943 auto inferred_status = 944 ShapeInference::InferDotOpShape(matrix_32_64_, vector_64_, dot_dnums); 945 ASSERT_IS_OK(inferred_status.status()); 946 ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), vector_32_)); 947 auto inferred_status_mismatch = 948 ShapeInference::InferDotOpShape(matrix_32_64_, vector_32_, dot_dnums); 949 ASSERT_FALSE(inferred_status_mismatch.ok()); 950 } 951 952 // vector <dot> matrix -> vector 953 TEST_F(ShapeInferenceTest, VectorDotMatrix) { 954 DotDimensionNumbers dot_dnums; 955 dot_dnums.add_lhs_contracting_dimensions(0); 956 dot_dnums.add_rhs_contracting_dimensions(0); 957 auto inferred_status = 958 ShapeInference::InferDotOpShape(vector_32_, matrix_32_64_, dot_dnums); 959 ASSERT_IS_OK(inferred_status.status()); 960 ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), vector_64_)); 961 auto inferred_status_mismatch = 962 ShapeInference::InferDotOpShape(vector_64_, matrix_32_64_, dot_dnums); 963 ASSERT_FALSE(inferred_status_mismatch.ok()); 964 } 965 966 // matrix <dot> matrix -> matrix 967 TEST_F(ShapeInferenceTest, MatrixDotMatrix) { 968 DotDimensionNumbers dot_dnums; 969 dot_dnums.add_lhs_contracting_dimensions(1); 970 dot_dnums.add_rhs_contracting_dimensions(0); 971 auto inferred_status_match = 972 ShapeInference::InferDotOpShape(matrix_32_64_, matrix_64_48_, dot_dnums); 973 ASSERT_IS_OK(inferred_status_match.status()); 974 ASSERT_TRUE( 975 ShapeUtil::Equal(inferred_status_match.ValueOrDie(), matrix_32_48_)) 976 << "inferred: " 977 << 
ShapeUtil::HumanString(inferred_status_match.ValueOrDie()) 978 << " expected: " << ShapeUtil::HumanString(matrix_64_48_); 979 auto inferred_status_mismatch = 980 ShapeInference::InferDotOpShape(matrix_32_64_, matrix_32_64_, dot_dnums); 981 ASSERT_FALSE(inferred_status_mismatch.ok()); 982 } 983 984 // BatchMatMul with two batch dimensions and one contracting dimension. 985 TEST_F(ShapeInferenceTest, DotGeneral) { 986 Shape lhs_shape = ShapeUtil::MakeShape(F32, {5, 2, 11, 3}); 987 Shape rhs_shape = ShapeUtil::MakeShape(F32, {5, 2, 3, 14}); 988 Shape output_shape = ShapeUtil::MakeShape(F32, {5, 2, 11, 14}); 989 990 DotDimensionNumbers dot_dnums; 991 dot_dnums.add_lhs_contracting_dimensions(3); 992 dot_dnums.add_lhs_batch_dimensions(0); 993 dot_dnums.add_lhs_batch_dimensions(1); 994 995 dot_dnums.add_rhs_contracting_dimensions(2); 996 dot_dnums.add_rhs_batch_dimensions(0); 997 dot_dnums.add_rhs_batch_dimensions(1); 998 999 auto inferred_status_match = 1000 ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); 1001 ASSERT_IS_OK(inferred_status_match.status()); 1002 ASSERT_TRUE( 1003 ShapeUtil::Equal(inferred_status_match.ValueOrDie(), output_shape)) 1004 << "inferred: " 1005 << ShapeUtil::HumanString(inferred_status_match.ValueOrDie()) 1006 << " expected: " << ShapeUtil::HumanString(output_shape); 1007 } 1008 1009 // BatchMatMul with two contracting dimensions fails. 
1010 TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsFails) { 1011 Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3, 2}); 1012 Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14}); 1013 Shape output_shape = ShapeUtil::MakeShape(F32, {2, 11, 14}); 1014 1015 DotDimensionNumbers dot_dnums; 1016 dot_dnums.add_lhs_contracting_dimensions(2); 1017 dot_dnums.add_lhs_contracting_dimensions(3); 1018 dot_dnums.add_lhs_batch_dimensions(0); 1019 1020 dot_dnums.add_rhs_contracting_dimensions(1); 1021 dot_dnums.add_rhs_batch_dimensions(0); 1022 1023 auto inferred_status = 1024 ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); 1025 ASSERT_FALSE(inferred_status.ok()); 1026 ASSERT_THAT(inferred_status.status().error_message(), 1027 HasSubstr("must specify one contracting dimension for both " 1028 "lhs and rhs")); 1029 } 1030 1031 // BatchMatMul with different batch dimension sizes fails. 1032 TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimSizesFails) { 1033 Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3}); 1034 Shape rhs_shape = ShapeUtil::MakeShape(F32, {3, 3, 14}); 1035 1036 DotDimensionNumbers dot_dnums; 1037 dot_dnums.add_lhs_contracting_dimensions(2); 1038 dot_dnums.add_lhs_batch_dimensions(0); 1039 1040 dot_dnums.add_rhs_contracting_dimensions(1); 1041 dot_dnums.add_rhs_batch_dimensions(0); 1042 1043 auto inferred_status = 1044 ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); 1045 ASSERT_FALSE(inferred_status.ok()); 1046 ASSERT_THAT(inferred_status.status().error_message(), 1047 HasSubstr("batch dimension numbers and sizes must match")); 1048 } 1049 1050 // BatchMatMul with different batch dimension numbers fails. 
1051 TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimNumbersFails) { 1052 Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3}); 1053 Shape rhs_shape = ShapeUtil::MakeShape(F32, {3, 2, 14}); 1054 1055 DotDimensionNumbers dot_dnums; 1056 dot_dnums.add_lhs_contracting_dimensions(2); 1057 dot_dnums.add_lhs_batch_dimensions(0); 1058 1059 dot_dnums.add_rhs_contracting_dimensions(0); 1060 dot_dnums.add_rhs_batch_dimensions(1); 1061 1062 auto inferred_status = 1063 ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); 1064 ASSERT_FALSE(inferred_status.ok()); 1065 ASSERT_THAT(inferred_status.status().error_message(), 1066 HasSubstr("batch dimension numbers must precede non-batch")); 1067 } 1068 1069 // BatchMatMul with out-of-range dimension numbers fails. 1070 TEST_F(ShapeInferenceTest, DotWithContractingDimNumberOutOfRange) { 1071 Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3}); 1072 Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14}); 1073 1074 DotDimensionNumbers dot_dnums; 1075 dot_dnums.add_lhs_contracting_dimensions(3); 1076 dot_dnums.add_lhs_batch_dimensions(0); 1077 1078 dot_dnums.add_rhs_contracting_dimensions(0); 1079 dot_dnums.add_rhs_batch_dimensions(1); 1080 1081 auto inferred_status = 1082 ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); 1083 ASSERT_FALSE(inferred_status.ok()); 1084 ASSERT_THAT(inferred_status.status().error_message(), 1085 HasSubstr("A dimension number is out of range")); 1086 } 1087 1088 // BatchMatMul with non-unique dimension numbers fails. 
1089 TEST_F(ShapeInferenceTest, DotWithContractingNonUniqueDimNumber) { 1090 Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3}); 1091 Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14}); 1092 1093 DotDimensionNumbers dot_dnums; 1094 dot_dnums.add_lhs_contracting_dimensions(0); 1095 dot_dnums.add_lhs_batch_dimensions(0); 1096 1097 dot_dnums.add_rhs_contracting_dimensions(0); 1098 dot_dnums.add_rhs_batch_dimensions(1); 1099 1100 auto inferred_status = 1101 ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); 1102 ASSERT_FALSE(inferred_status.ok()); 1103 ASSERT_THAT(inferred_status.status().error_message(), 1104 HasSubstr("A dimension number is not unique")); 1105 } 1106 1107 TEST_F(ShapeInferenceTest, BinOpBroadcastMatrixVector) { 1108 // Test variations of broadcasting a vector for a binary add with a 1109 // matrix. 1110 const Shape mat = ShapeUtil::MakeShape(F32, {16, 8}); 1111 const Shape vec8 = ShapeUtil::MakeShape(F32, {8}); 1112 const Shape vec16 = ShapeUtil::MakeShape(F32, {16}); 1113 1114 auto inferred_status_match = ShapeInference::InferBinaryOpShape( 1115 BinaryOperation::BINOP_ADD, mat, vec8, {1}); 1116 ASSERT_IS_OK(inferred_status_match.status()); 1117 ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), mat)); 1118 1119 auto inferred_status_mismatch = ShapeInference::InferBinaryOpShape( 1120 BinaryOperation::BINOP_ADD, mat, vec8, {0}); 1121 ASSERT_FALSE(inferred_status_mismatch.ok()); 1122 1123 inferred_status_match = ShapeInference::InferBinaryOpShape( 1124 BinaryOperation::BINOP_ADD, mat, vec16, {0}); 1125 ASSERT_IS_OK(inferred_status_match.status()); 1126 ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), mat)); 1127 1128 inferred_status_mismatch = ShapeInference::InferBinaryOpShape( 1129 BinaryOperation::BINOP_ADD, mat, vec16, {1}); 1130 ASSERT_FALSE(inferred_status_mismatch.ok()); 1131 } 1132 1133 TEST_F(ShapeInferenceTest, BinOpBroadcastCubeMatrix) { 1134 // Test variations of broadcasting a matrix 
for a binary add with a cube. 1135 const Shape cube = ShapeUtil::MakeShape(F32, {16, 8, 4}); 1136 const Shape matrix8_4 = ShapeUtil::MakeShape(F32, {8, 4}); 1137 const Shape matrix16_4 = ShapeUtil::MakeShape(F32, {16, 4}); 1138 const Shape matrix16_8 = ShapeUtil::MakeShape(F32, {16, 8}); 1139 1140 auto inferred_status_match = ShapeInference::InferBinaryOpShape( 1141 BinaryOperation::BINOP_ADD, cube, matrix8_4, {1, 2}); 1142 ASSERT_IS_OK(inferred_status_match.status()); 1143 ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), cube)); 1144 1145 inferred_status_match = ShapeInference::InferBinaryOpShape( 1146 BinaryOperation::BINOP_ADD, cube, matrix16_4, {0, 2}); 1147 ASSERT_IS_OK(inferred_status_match.status()); 1148 ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), cube)); 1149 1150 inferred_status_match = ShapeInference::InferBinaryOpShape( 1151 BinaryOperation::BINOP_ADD, cube, matrix16_8, {0, 1}); 1152 ASSERT_IS_OK(inferred_status_match.status()); 1153 ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), cube)); 1154 } 1155 1156 TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) { 1157 // Test various errors with the broadcast argument. 
1158 const Shape tensor = ShapeUtil::MakeShape(F32, {16, 8, 4}); 1159 const Shape tensor8_8_8 = ShapeUtil::MakeShape(F32, {8, 8, 8}); 1160 const Shape vec8 = ShapeUtil::MakeShape(F32, {8}); 1161 const Shape matrix8_4 = ShapeUtil::MakeShape(F32, {8, 4}); 1162 const Shape matrix8_8 = ShapeUtil::MakeShape(F32, {8, 8}); 1163 1164 // "magical" broadcast rejected 1165 auto inferred_status_error1 = ShapeInference::InferBinaryOpShape( 1166 BinaryOperation::BINOP_ADD, tensor, vec8, {}); 1167 ASSERT_FALSE(inferred_status_error1.ok()); 1168 ASSERT_THAT(inferred_status_error1.status().error_message(), 1169 HasSubstr("automatic")); 1170 1171 // broadcast_dimension out of bounds for tensor's rank 1172 auto inferred_status_error2 = ShapeInference::InferBinaryOpShape( 1173 BinaryOperation::BINOP_ADD, tensor, vec8, {3}); 1174 ASSERT_FALSE(inferred_status_error2.ok()); 1175 ASSERT_THAT(inferred_status_error2.status().error_message(), 1176 ContainsRegex("broadcast dimension number .* too large")); 1177 1178 // broadcast_dimension doesn't match corresponding dimension 1179 auto inferred_status_error3 = ShapeInference::InferBinaryOpShape( 1180 BinaryOperation::BINOP_ADD, tensor, vec8, {0}); 1181 ASSERT_FALSE(inferred_status_error3.ok()); 1182 ASSERT_THAT(inferred_status_error3.status().error_message(), 1183 HasSubstr("broadcast dimension 0 mismatch")); 1184 1185 // broadcast_dimensions list too long 1186 auto inferred_status_error4 = ShapeInference::InferBinaryOpShape( 1187 BinaryOperation::BINOP_ADD, tensor, matrix8_4, {0, 1, 2}); 1188 ASSERT_FALSE(inferred_status_error4.ok()); 1189 ASSERT_THAT(inferred_status_error4.status().error_message(), 1190 HasSubstr("size of broadcast_dimensions has to match")); 1191 1192 // there's a dimension above the rank of the tensor 1193 auto inferred_status_error5 = ShapeInference::InferBinaryOpShape( 1194 BinaryOperation::BINOP_ADD, tensor, matrix8_4, {3, 0}); 1195 ASSERT_FALSE(inferred_status_error5.ok()); 1196 
ASSERT_THAT(inferred_status_error5.status().error_message(), 1197 ContainsRegex("broadcast dimension number .* too large")); 1198 1199 // broadcasting dimensions don't match in this order 1200 auto inferred_status_error6 = ShapeInference::InferBinaryOpShape( 1201 BinaryOperation::BINOP_ADD, tensor, matrix8_4, {2, 1}); 1202 ASSERT_FALSE(inferred_status_error6.ok()); 1203 ASSERT_THAT(inferred_status_error6.status().error_message(), 1204 HasSubstr("broadcast dimension 0 mismatch")); 1205 1206 // The following two tests make sure that broadcasting dimensions are listed 1207 // in a proper (strictly increasing) order, even if the lower-rank array 1208 // matches the higher-rank array in many different ways. 1209 auto inferred_status_error7 = ShapeInference::InferBinaryOpShape( 1210 BinaryOperation::BINOP_ADD, tensor8_8_8, matrix8_8, {0, 0}); 1211 ASSERT_FALSE(inferred_status_error7.ok()); 1212 ASSERT_THAT(inferred_status_error7.status().error_message(), 1213 HasSubstr("broadcast dimensions order is wrong")); 1214 1215 auto inferred_status_error8 = ShapeInference::InferBinaryOpShape( 1216 BinaryOperation::BINOP_ADD, tensor8_8_8, matrix8_8, {1, 0}); 1217 ASSERT_FALSE(inferred_status_error8.ok()); 1218 ASSERT_THAT(inferred_status_error8.status().error_message(), 1219 HasSubstr("broadcast dimensions order is wrong")); 1220 } 1221 1222 // Tests for the while instruction with proper shapes. 
1223 TEST_F(ShapeInferenceTest, WhileWithCorrectShapes) { 1224 Shape result_shape = ShapeUtil::MakeTupleShape({s32_, vector_32_}); 1225 ProgramShape cond = ShapeUtil::MakeProgramShape({result_shape}, pred_); 1226 ProgramShape body = ShapeUtil::MakeProgramShape({result_shape}, result_shape); 1227 auto inferred_status = 1228 ShapeInference::InferWhileShape(cond, body, result_shape); 1229 ASSERT_IS_OK(inferred_status.status()); 1230 Shape inferred = inferred_status.ValueOrDie(); 1231 ASSERT_TRUE(ShapeUtil::Equal(result_shape, inferred)); 1232 } 1233 1234 // Tests for the while instruction with wrong shapes. 1235 TEST_F(ShapeInferenceTest, WhileWithBadShapes) { 1236 Shape result_shape = ShapeUtil::MakeTupleShape({s32_, vector_32_}); 1237 ProgramShape cond = ShapeUtil::MakeProgramShape({result_shape}, pred_); 1238 ProgramShape body = ShapeUtil::MakeProgramShape({result_shape}, result_shape); 1239 1240 auto bad_shape_1 = ShapeUtil::MakeProgramShape({s32_, result_shape}, pred_); 1241 auto inferred_status_error1 = 1242 ShapeInference::InferWhileShape(bad_shape_1, body, result_shape); 1243 ASSERT_FALSE(inferred_status_error1.ok()); 1244 ASSERT_THAT(inferred_status_error1.status().error_message(), 1245 HasSubstr("condition must take 1 arguments")); 1246 1247 auto bad_shape_2 = 1248 ShapeUtil::MakeProgramShape({s32_, result_shape}, result_shape); 1249 auto inferred_status_error2 = 1250 ShapeInference::InferWhileShape(cond, bad_shape_2, result_shape); 1251 ASSERT_FALSE(inferred_status_error2.ok()); 1252 ASSERT_THAT(inferred_status_error2.status().error_message(), 1253 HasSubstr("body must take 1 arguments")); 1254 1255 auto bad_shape_3 = ShapeUtil::MakeProgramShape({result_shape}, s32_); 1256 auto inferred_status_error3 = 1257 ShapeInference::InferWhileShape(bad_shape_3, body, result_shape); 1258 ASSERT_FALSE(inferred_status_error3.ok()); 1259 ASSERT_THAT(inferred_status_error3.status().error_message(), 1260 HasSubstr("condition must return a boolean")); 1261 1262 auto 
bad_shape_4 = ShapeUtil::MakeProgramShape({result_shape}, vector_32_); 1263 auto inferred_status_error4 = 1264 ShapeInference::InferWhileShape(cond, bad_shape_4, result_shape); 1265 ASSERT_FALSE(inferred_status_error4.ok()); 1266 ASSERT_THAT(inferred_status_error4.status().error_message(), 1267 HasSubstr("parameter of condition and body")); 1268 } 1269 1270 // Tests for the concatenate instruction with proper shapes. 1271 TEST_F(ShapeInferenceTest, ConcatenateWithCorrectShapes) { 1272 auto inferred_status_1 = ShapeInference::InferConcatOpShape( 1273 {&vector_32_, &vector_64_}, /*dimension=*/0); 1274 ASSERT_IS_OK(inferred_status_1.status()); 1275 Shape inferred_1 = inferred_status_1.ValueOrDie(); 1276 ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {96}), inferred_1)); 1277 1278 auto inferred_status_2 = ShapeInference::InferConcatOpShape( 1279 {&vector_32_, &vector_64_, &vector_32_}, /*dimension=*/0); 1280 ASSERT_IS_OK(inferred_status_2.status()); 1281 Shape inferred_2 = inferred_status_2.ValueOrDie(); 1282 ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {128}), inferred_2)); 1283 1284 auto inferred_status_3 = ShapeInference::InferConcatOpShape( 1285 {&matrix_32_48_, &matrix_32_64_, &matrix_32_48_}, /*dimension=*/1); 1286 ASSERT_IS_OK(inferred_status_3.status()); 1287 Shape inferred_3 = inferred_status_3.ValueOrDie(); 1288 ASSERT_TRUE( 1289 ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {32, 160}), inferred_3)); 1290 } 1291 1292 // Tests for the concatenate instruction with wrong shapes. 
1293 TEST_F(ShapeInferenceTest, ConcatenateWithBadShapes) { 1294 auto inferred_status_error1 = 1295 ShapeInference::InferConcatOpShape({}, /*dimension=*/0); 1296 ASSERT_FALSE(inferred_status_error1.ok()); 1297 ASSERT_THAT(inferred_status_error1.status().error_message(), 1298 HasSubstr("Concatenate expects at least one argument")); 1299 1300 auto inferred_status_error2 = 1301 ShapeInference::InferConcatOpShape({&vector_32_}, /*dimension=*/-1); 1302 ASSERT_FALSE(inferred_status_error2.ok()); 1303 ASSERT_THAT(inferred_status_error2.status().error_message(), 1304 HasSubstr("dimension to concatenate along out of bounds: -1")); 1305 1306 auto inferred_status_error3 = 1307 ShapeInference::InferConcatOpShape({&vector_32_}, /*dimension=*/1); 1308 ASSERT_FALSE(inferred_status_error3.ok()); 1309 ASSERT_THAT(inferred_status_error3.status().error_message(), 1310 HasSubstr("dimension to concatenate along out of bounds: 1")); 1311 1312 Shape tuple = ShapeUtil::MakeTupleShape({vector_32_}); 1313 auto inferred_status_error4 = ShapeInference::InferConcatOpShape( 1314 {&vector_32_, &tuple}, /*dimension=*/0); 1315 ASSERT_FALSE(inferred_status_error4.ok()); 1316 ASSERT_THAT( 1317 inferred_status_error4.status().error_message(), 1318 HasSubstr("Expected non-tuple argument for operand of concatenation.")); 1319 1320 const Shape vector_s32 = ShapeUtil::MakeShape(S32, {32}); 1321 auto inferred_status_error5 = ShapeInference::InferConcatOpShape( 1322 {&vector_32_, &vector_s32}, /*dimension=*/0); 1323 ASSERT_FALSE(inferred_status_error5.ok()); 1324 ASSERT_THAT( 1325 inferred_status_error5.status().error_message(), 1326 HasSubstr("cannot concatenate arrays with different element types")); 1327 1328 auto inferred_status_error6 = ShapeInference::InferConcatOpShape( 1329 {&matrix_32_48_, &matrix_32_64_}, /*dimension=*/0); 1330 ASSERT_FALSE(inferred_status_error6.ok()); 1331 ASSERT_THAT(inferred_status_error6.status().error_message(), 1332 HasSubstr("cannot concatenate arrays that differ in " 
1333 "dimensions other than the one being " 1334 "concatenated")); 1335 } 1336 1337 TEST_F(ShapeInferenceTest, Pad) { 1338 Shape input_shape = ShapeUtil::MakeShape(F32, {10, 25}); 1339 Shape padding_value_shape = ShapeUtil::MakeShape(F32, {}); 1340 // Padding for dimension 0: {low: 0, high: 2, interior: 3} 1341 // Padding for dimension 1: {low: 1, high: 5, interior: 0} 1342 PaddingConfig padding_config; 1343 auto dimension0 = padding_config.add_dimensions(); 1344 dimension0->set_edge_padding_low(0); 1345 dimension0->set_edge_padding_high(2); 1346 dimension0->set_interior_padding(3); 1347 auto dimension1 = padding_config.add_dimensions(); 1348 dimension1->set_edge_padding_low(1); 1349 dimension1->set_edge_padding_high(5); 1350 dimension1->set_interior_padding(0); 1351 1352 auto inferred_status = ShapeInference::InferPadShape( 1353 input_shape, padding_value_shape, padding_config); 1354 ASSERT_IS_OK(inferred_status.status()); 1355 Shape inferred_shape = inferred_status.ValueOrDie(); 1356 ASSERT_TRUE( 1357 ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {39, 31}), inferred_shape)); 1358 } 1359 1360 TEST_F(ShapeInferenceTest, Reverse) { 1361 Shape input_shape = ShapeUtil::MakeShape(F32, {10, 25}); 1362 1363 auto inferred_status = ShapeInference::InferReverseShape(input_shape, {0, 1}); 1364 ASSERT_IS_OK(inferred_status.status()); 1365 Shape inferred_shape = inferred_status.ValueOrDie(); 1366 ASSERT_TRUE(ShapeUtil::Equal(input_shape, inferred_shape)); 1367 } 1368 1369 TEST_F(ShapeInferenceTest, ReverseInvalidDimension) { 1370 Shape input_shape = ShapeUtil::MakeShape(F32, {10, 25}); 1371 1372 auto inferred_status_error0 = 1373 ShapeInference::InferReverseShape(input_shape, {0, 2}); 1374 ASSERT_FALSE(inferred_status_error0.ok()); 1375 ASSERT_THAT(inferred_status_error0.status().error_message(), 1376 HasSubstr("out-of-bounds")); 1377 1378 auto inferred_status_error1 = 1379 ShapeInference::InferReverseShape(input_shape, {0, -1}); 1380 ASSERT_FALSE(inferred_status_error1.ok()); 
1381 ASSERT_THAT(inferred_status_error1.status().error_message(), 1382 HasSubstr("out-of-bounds")); 1383 1384 auto inferred_status_error2 = 1385 ShapeInference::InferReverseShape(input_shape, {0, 0}); 1386 ASSERT_FALSE(inferred_status_error2.ok()); 1387 ASSERT_THAT(inferred_status_error2.status().error_message(), 1388 HasSubstr("duplicated")); 1389 1390 Shape tuple_shape = ShapeUtil::MakeTupleShape({input_shape, input_shape}); 1391 auto inferred_status_error3 = 1392 ShapeInference::InferReverseShape(tuple_shape, {0}); 1393 ASSERT_FALSE(inferred_status_error3.ok()); 1394 ASSERT_THAT(inferred_status_error3.status().error_message(), 1395 HasSubstr("Expected non-tuple argument")); 1396 } 1397 1398 TEST_F(ShapeInferenceTest, Call) { 1399 auto inferred_status0 = 1400 ShapeInference::InferCallShape({}, ShapeUtil::MakeProgramShape({}, f32_)); 1401 EXPECT_IS_OK(inferred_status0.status()); 1402 EXPECT_TRUE(ShapeUtil::Equal(f32_, inferred_status0.ValueOrDie())); 1403 1404 auto inferred_status1 = ShapeInference::InferCallShape( 1405 {&f32_, &s32_, &pred_, &vector_32_, &matrix_32_48_}, 1406 ShapeUtil::MakeProgramShape( 1407 {f32_, s32_, pred_, vector_32_, matrix_32_48_}, s32matrix_64_64_)); 1408 EXPECT_IS_OK(inferred_status1.status()); 1409 EXPECT_TRUE( 1410 ShapeUtil::Equal(s32matrix_64_64_, inferred_status1.ValueOrDie())); 1411 1412 auto inferred_status_error0 = ShapeInference::InferCallShape( 1413 {}, ShapeUtil::MakeProgramShape({f32_}, f32_)); 1414 EXPECT_FALSE(inferred_status_error0.ok()); 1415 EXPECT_THAT(inferred_status_error0.status().error_message(), 1416 HasSubstr("arity must match")); 1417 1418 auto inferred_status_error1 = ShapeInference::InferCallShape( 1419 {&f32_}, ShapeUtil::MakeProgramShape({}, f32_)); 1420 EXPECT_FALSE(inferred_status_error1.ok()); 1421 EXPECT_THAT(inferred_status_error1.status().error_message(), 1422 HasSubstr("arity must match")); 1423 1424 auto inferred_status_error2 = ShapeInference::InferCallShape( 1425 {&f32_}, 
ShapeUtil::MakeProgramShape({s32_}, f32_)); 1426 EXPECT_FALSE(inferred_status_error2.ok()); 1427 EXPECT_THAT(inferred_status_error2.status().error_message(), 1428 HasSubstr("parameter must match argument")); 1429 } 1430 1431 TEST_F(ShapeInferenceTest, Transpose) { 1432 Shape a_shape = ShapeUtil::MakeShape(F32, {2, 3, 4, 5}); 1433 auto inferred_shape_and_status = 1434 ShapeInference::InferTransposeShape(a_shape, {1, 2, 3, 0}); 1435 EXPECT_IS_OK(inferred_shape_and_status); 1436 Shape inferred_shape = inferred_shape_and_status.ValueOrDie(); 1437 EXPECT_TRUE(ShapeUtil::Compatible(inferred_shape, 1438 ShapeUtil::MakeShape(F32, {3, 4, 5, 2}))); 1439 } 1440 1441 TEST_F(ShapeInferenceTest, Conditional) { 1442 auto inferred_status0 = ShapeInference::InferConditionalShape( 1443 pred_, vector_32_, vector_64_, 1444 ShapeUtil::MakeProgramShape({vector_32_}, f32_), 1445 ShapeUtil::MakeProgramShape({vector_64_}, f32_)); 1446 EXPECT_IS_OK(inferred_status0.status()); 1447 EXPECT_TRUE(ShapeUtil::Equal(f32_, inferred_status0.ValueOrDie())); 1448 1449 auto inferred_status1 = ShapeInference::InferConditionalShape( 1450 pred_, matrix_32_48_, vector_32_, 1451 ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_), 1452 ShapeUtil::MakeProgramShape({vector_32_}, vector_64_)); 1453 EXPECT_IS_OK(inferred_status1.status()); 1454 EXPECT_TRUE(ShapeUtil::Equal(vector_64_, inferred_status1.ValueOrDie())); 1455 1456 auto tuple_f32_v32 = ShapeUtil::MakeTupleShape({f32_, vector_32_}); 1457 auto inferred_status2 = ShapeInference::InferConditionalShape( 1458 pred_, matrix_32_48_, tuple_f32_v32, 1459 ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_), 1460 ShapeUtil::MakeProgramShape({tuple_f32_v32}, vector_32_)); 1461 EXPECT_IS_OK(inferred_status2.status()); 1462 EXPECT_TRUE(ShapeUtil::Equal(vector_32_, inferred_status2.ValueOrDie())); 1463 1464 auto inferred_status_error0 = ShapeInference::InferConditionalShape( 1465 s32_, vector_32_, vector_64_, 1466 ShapeUtil::MakeProgramShape({vector_32_}, 
f32_), 1467 ShapeUtil::MakeProgramShape({vector_64_}, f32_)); 1468 EXPECT_FALSE(inferred_status_error0.ok()); 1469 EXPECT_THAT(inferred_status_error0.status().error_message(), 1470 HasSubstr("predicate must be a boolean")); 1471 1472 auto inferred_status_error1 = ShapeInference::InferConditionalShape( 1473 pred_, ShapeUtil::MakeTupleShape({f32_, vector_32_}), matrix_32_48_, 1474 ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_), 1475 ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_)); 1476 EXPECT_FALSE(inferred_status_error1.ok()); 1477 EXPECT_THAT(inferred_status_error1.status().error_message(), 1478 HasSubstr("true_computation must take 1 argument")); 1479 1480 auto inferred_status_error2 = ShapeInference::InferConditionalShape( 1481 pred_, vector_32_, vector_64_, 1482 ShapeUtil::MakeProgramShape({vector_64_}, f32_), 1483 ShapeUtil::MakeProgramShape({vector_64_}, f32_)); 1484 EXPECT_FALSE(inferred_status_error2.ok()); 1485 EXPECT_THAT(inferred_status_error2.status().error_message(), 1486 HasSubstr("true_operand must match the shape of the only " 1487 "parameter of true_computation")); 1488 1489 auto inferred_status_error3 = ShapeInference::InferConditionalShape( 1490 pred_, matrix_32_48_, ShapeUtil::MakeTupleShape({f32_, vector_32_}), 1491 ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_), 1492 ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_)); 1493 EXPECT_FALSE(inferred_status_error3.ok()); 1494 EXPECT_THAT(inferred_status_error3.status().error_message(), 1495 HasSubstr("false_computation must take 1 argument")); 1496 1497 auto inferred_status_error4 = ShapeInference::InferConditionalShape( 1498 pred_, vector_32_, vector_64_, 1499 ShapeUtil::MakeProgramShape({vector_32_}, f32_), 1500 ShapeUtil::MakeProgramShape({vector_32_}, f32_)); 1501 EXPECT_FALSE(inferred_status_error4.ok()); 1502 EXPECT_THAT(inferred_status_error4.status().error_message(), 1503 HasSubstr("false_operand must match the shape of the only " 1504 "parameter of 
false_computation")); 1505 1506 auto inferred_status_error5 = ShapeInference::InferConditionalShape( 1507 pred_, vector_32_, vector_64_, 1508 ShapeUtil::MakeProgramShape({vector_32_}, f32_), 1509 ShapeUtil::MakeProgramShape({vector_64_}, vector_32_)); 1510 EXPECT_FALSE(inferred_status_error5.ok()); 1511 EXPECT_THAT(inferred_status_error5.status().error_message(), 1512 HasSubstr("the result of true_computation and false_computation " 1513 "must have the same shape")); 1514 } 1515 1516 TEST_F(ShapeInferenceTest, BadSlice) { 1517 auto arg = ShapeUtil::MakeShape(F32, {4}); 1518 StatusOr<Shape> statusor = 1519 ShapeInference::InferSliceShape(arg, {0}, {5}, {1}); 1520 ASSERT_FALSE(statusor.ok()); 1521 1522 LOG(INFO) << statusor.status(); 1523 1524 EXPECT_THAT(statusor.status().error_message(), 1525 HasSubstr("less than or equal to dimension size")) 1526 << statusor.status(); 1527 EXPECT_THAT(statusor.status().error_message(), HasSubstr("argument shape")) 1528 << statusor.status(); 1529 } 1530 1531 class GatherShapeInferenceTest : public ShapeInferenceTest { 1532 protected: 1533 const Shape s64_vector_32_ = ShapeUtil::MakeShape(S64, {32}); 1534 const Shape s64_4d_tensor_10_9_8_7_1_ = 1535 ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 1}); 1536 const Shape s64_4d_tensor_10_9_8_7_5_ = 1537 ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 5}); 1538 const Shape f32_5d_tensor_50_49_48_47_46_ = 1539 ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46}); 1540 const Shape tuple_shape_ = ShapeUtil::MakeTupleShape( 1541 {s64_4d_tensor_10_9_8_7_1_, s64_4d_tensor_10_9_8_7_1_}); 1542 }; 1543 1544 TEST_F(GatherShapeInferenceTest, TensorFlowGather) { 1545 TF_ASSERT_OK_AND_ASSIGN( 1546 Shape gather_shape, 1547 ShapeInference::InferGatherShape(matrix_64_48_, s64_vector_32_, 1548 HloInstruction::MakeGatherDimNumbers( 1549 /*output_window_dims=*/{0}, 1550 /*elided_window_dims=*/{1}, 1551 /*gather_dims_to_operand_dims=*/{1}), 1552 /*window_bounds=*/{64, 1})); 1553 EXPECT_TRUE( 1554 ShapeUtil::Equal(gather_shape, 
ShapeUtil::MakeShape(F32, {64, 32}))) 1555 << ShapeUtil::HumanString(gather_shape); 1556 } 1557 1558 TEST_F(GatherShapeInferenceTest, TensorFlowGatherV2) { 1559 TF_ASSERT_OK_AND_ASSIGN( 1560 Shape gather_shape, 1561 ShapeInference::InferGatherShape(matrix_64_48_, s64_vector_32_, 1562 HloInstruction::MakeGatherDimNumbers( 1563 /*output_window_dims=*/{1}, 1564 /*elided_window_dims=*/{0}, 1565 /*gather_dims_to_operand_dims=*/{0}), 1566 /*window_bounds=*/{1, 48})); 1567 EXPECT_TRUE( 1568 ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {32, 48}))) 1569 << ShapeUtil::HumanString(gather_shape); 1570 } 1571 1572 TEST_F(GatherShapeInferenceTest, TensorFlowGatherNd) { 1573 TF_ASSERT_OK_AND_ASSIGN( 1574 Shape gather_shape, 1575 ShapeInference::InferGatherShape(matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, 1576 HloInstruction::MakeGatherDimNumbers( 1577 /*output_window_dims=*/{4}, 1578 /*elided_window_dims=*/{0}, 1579 /*gather_dims_to_operand_dims=*/{0}), 1580 /*window_bounds=*/{1, 48})); 1581 EXPECT_TRUE(ShapeUtil::Equal(gather_shape, 1582 ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 48}))) 1583 << ShapeUtil::HumanString(gather_shape); 1584 } 1585 1586 TEST_F(GatherShapeInferenceTest, TensorFlowBatchDynamicSlice) { 1587 TF_ASSERT_OK_AND_ASSIGN( 1588 Shape gather_shape, 1589 ShapeInference::InferGatherShape( 1590 f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, 1591 HloInstruction::MakeGatherDimNumbers( 1592 /*output_window_dims=*/{4, 5, 6, 7, 8}, 1593 /*elided_window_dims=*/{}, 1594 /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), 1595 /*window_bounds=*/{30, 29, 28, 27, 26})); 1596 EXPECT_TRUE(ShapeUtil::Equal( 1597 gather_shape, 1598 ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}))) 1599 << ShapeUtil::HumanString(gather_shape); 1600 } 1601 1602 TEST_F(GatherShapeInferenceTest, TupleShapedTensorInput) { 1603 StatusOr<Shape> statusor = ShapeInference::InferGatherShape( 1604 tuple_shape_, s64_vector_32_, 1605 
HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0}, 1606 /*elided_window_dims=*/{1}, 1607 /*gather_dims_to_operand_dims=*/{1}), 1608 /*window_bounds=*/{64, 1}); 1609 ASSERT_FALSE(statusor.ok()); 1610 EXPECT_THAT(statusor.status().error_message(), 1611 HasSubstr("Expected non-tuple argument for input")) 1612 << statusor.status(); 1613 } 1614 1615 TEST_F(GatherShapeInferenceTest, TupleShapedGatherIndicesInput) { 1616 StatusOr<Shape> statusor = ShapeInference::InferGatherShape( 1617 s64_vector_32_, tuple_shape_, 1618 HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0}, 1619 /*elided_window_dims=*/{1}, 1620 /*gather_dims_to_operand_dims=*/{1}), 1621 /*window_bounds=*/{64, 1}); 1622 ASSERT_FALSE(statusor.ok()); 1623 EXPECT_THAT(statusor.status().error_message(), 1624 HasSubstr("Expected non-tuple argument for gather indices")) 1625 << statusor.status(); 1626 } 1627 1628 TEST_F(GatherShapeInferenceTest, ScalarGatherIndicesInput) { 1629 StatusOr<Shape> statusor = ShapeInference::InferGatherShape( 1630 s64_vector_32_, s32_, 1631 HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0}, 1632 /*elided_window_dims=*/{1}, 1633 /*gather_dims_to_operand_dims=*/{1}), 1634 /*window_bounds=*/{64, 1}); 1635 ASSERT_FALSE(statusor.ok()); 1636 EXPECT_THAT(statusor.status().error_message(), 1637 HasSubstr("Gather indices parameter must at least of rank 1")) 1638 << statusor.status(); 1639 } 1640 1641 TEST_F(GatherShapeInferenceTest, FloatingPointGatherIndicesInput) { 1642 StatusOr<Shape> statusor = ShapeInference::InferGatherShape( 1643 s64_vector_32_, vector_32_, 1644 HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0}, 1645 /*elided_window_dims=*/{1}, 1646 /*gather_dims_to_operand_dims=*/{1}), 1647 /*window_bounds=*/{64, 1}); 1648 ASSERT_FALSE(statusor.ok()); 1649 EXPECT_THAT(statusor.status().error_message(), 1650 HasSubstr("Gather indices parameter must be an integral tensor")) 1651 << statusor.status(); 1652 } 1653 1654 
TEST_F(GatherShapeInferenceTest, 1655 InvalidGatherDimNumbers_NonAscendingWindowIndices) { 1656 StatusOr<Shape> statusor = ShapeInference::InferGatherShape( 1657 f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, 1658 HloInstruction::MakeGatherDimNumbers( 1659 /*output_window_dims=*/{4, 5, 6, 8, 7}, 1660 /*elided_window_dims=*/{}, 1661 /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), 1662 /*window_bounds=*/{30, 29, 28, 27, 26}); 1663 ASSERT_FALSE(statusor.ok()); 1664 EXPECT_THAT( 1665 statusor.status().error_message(), 1666 HasSubstr("Output window dimensions in gather op must be ascending")) 1667 << statusor.status(); 1668 } 1669 1670 TEST_F(GatherShapeInferenceTest, 1671 InvalidGatherDimNumbers_RepeatedWindowIndices) { 1672 StatusOr<Shape> statusor = ShapeInference::InferGatherShape( 1673 f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, 1674 HloInstruction::MakeGatherDimNumbers( 1675 /*output_window_dims=*/{4, 5, 6, 7, 7}, 1676 /*elided_window_dims=*/{}, 1677 /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), 1678 /*window_bounds=*/{30, 29, 28, 27, 26}); 1679 ASSERT_FALSE(statusor.ok()); 1680 EXPECT_THAT( 1681 statusor.status().error_message(), 1682 HasSubstr("Output window dimensions in gather op must not repeat")) 1683 << statusor.status(); 1684 } 1685 1686 TEST_F(GatherShapeInferenceTest, 1687 InvalidGatherDimNumbers_WindowIndexOutOfBounds) { 1688 StatusOr<Shape> statusor = ShapeInference::InferGatherShape( 1689 f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, 1690 HloInstruction::MakeGatherDimNumbers( 1691 /*output_window_dims=*/{4, 5, 99, 100, 101}, 1692 /*elided_window_dims=*/{}, 1693 /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), 1694 /*window_bounds=*/{30, 29, 28, 27, 26}); 1695 ASSERT_FALSE(statusor.ok()); 1696 EXPECT_THAT(statusor.status().error_message(), 1697 HasSubstr("Window index 2 in gather op is out of bounds")) 1698 << statusor.status(); 1699 } 1700 1701 TEST_F(GatherShapeInferenceTest, 1702 
InvalidGatherDimNumbers_MismatchingElidedWindowDims) { 1703 StatusOr<Shape> statusor = ShapeInference::InferGatherShape( 1704 f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, 1705 HloInstruction::MakeGatherDimNumbers( 1706 /*output_window_dims=*/{4, 5, 6, 7, 8}, 1707 /*elided_window_dims=*/{4}, 1708 /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), 1709 /*window_bounds=*/{30, 29, 28, 27, 26}); 1710 ASSERT_FALSE(statusor.ok()); 1711 EXPECT_THAT( 1712 statusor.status().error_message(), 1713 HasSubstr("All components of the window index in a gather op must either " 1714 "be a output window index or explicitly elided")) 1715 << statusor.status(); 1716 } 1717 1718 TEST_F(GatherShapeInferenceTest, 1719 InvalidGatherDimNumbers_OutOfBoundsWindowToInputMapping) { 1720 StatusOr<Shape> statusor = ShapeInference::InferGatherShape( 1721 f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, 1722 HloInstruction::MakeGatherDimNumbers( 1723 /*output_window_dims=*/{4, 5, 6, 7, 8}, 1724 /*elided_window_dims=*/{0, 1, 2, 3, 19}, 1725 /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), 1726 /*window_bounds=*/{30, 29, 28, 27, 26}); 1727 ASSERT_FALSE(statusor.ok()); 1728 EXPECT_THAT(statusor.status().error_message(), 1729 HasSubstr("Invalid elided_window_dims set in gather op; valid " 1730 "range is [0, 5), got: 19")) 1731 << statusor.status(); 1732 } 1733 1734 TEST_F(GatherShapeInferenceTest, 1735 InvalidGatherDimNumbers_RepeatedWindowToInputMapping) { 1736 StatusOr<Shape> statusor = ShapeInference::InferGatherShape( 1737 f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, 1738 HloInstruction::MakeGatherDimNumbers( 1739 /*output_window_dims=*/{4, 5, 6, 7, 8}, 1740 /*elided_window_dims=*/{0, 1, 2, 3, 3}, 1741 /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), 1742 /*window_bounds=*/{30, 29, 28, 27, 26}); 1743 ASSERT_FALSE(statusor.ok()); 1744 EXPECT_THAT( 1745 statusor.status().error_message(), 1746 HasSubstr( 1747 "Repeated dimensions not allowed in elided_window_dims in 
gather op")) 1748 << statusor.status(); 1749 } 1750 1751 TEST_F(GatherShapeInferenceTest, 1752 InvalidGatherDimNumbers_MismatchingGatherToInputMapping) { 1753 StatusOr<Shape> statusor = ShapeInference::InferGatherShape( 1754 f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, 1755 HloInstruction::MakeGatherDimNumbers( 1756 /*output_window_dims=*/{4, 5, 6, 7, 8}, 1757 /*elided_window_dims=*/{}, 1758 /*gather_dims_to_operand_dims=*/{0, 1, 2, 3}), 1759 /*window_bounds=*/{30, 29, 28, 27, 26}); 1760 ASSERT_FALSE(statusor.ok()); 1761 EXPECT_THAT( 1762 statusor.status().error_message(), 1763 HasSubstr( 1764 "There must be exactly as many elements in " 1765 "gather_dims_to_operand_dims " 1766 "as there are elements in the last dimension of %gather_indices")) 1767 << statusor.status(); 1768 } 1769 1770 TEST_F(GatherShapeInferenceTest, 1771 InvalidGatherDimNumbers_OutOfBoundsGatherToInputMapping) { 1772 StatusOr<Shape> statusor = ShapeInference::InferGatherShape( 1773 f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, 1774 HloInstruction::MakeGatherDimNumbers( 1775 /*output_window_dims=*/{4, 5, 6, 7, 8}, 1776 /*elided_window_dims=*/{}, 1777 /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 7}), 1778 /*window_bounds=*/{30, 29, 28, 27, 26}); 1779 ASSERT_FALSE(statusor.ok()); 1780 EXPECT_THAT( 1781 statusor.status().error_message(), 1782 HasSubstr("Invalid gather_dims_to_operand_dims mapping; domain is " 1783 "[0, 5), got: 4->7")) 1784 << statusor.status(); 1785 } 1786 1787 TEST_F(GatherShapeInferenceTest, 1788 InvalidGatherDimNumbers_RepeatedGatherToInputMapping) { 1789 StatusOr<Shape> statusor = ShapeInference::InferGatherShape( 1790 f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, 1791 HloInstruction::MakeGatherDimNumbers( 1792 /*output_window_dims=*/{4, 5, 6, 7, 8}, 1793 /*elided_window_dims=*/{}, 1794 /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 3}), 1795 /*window_bounds=*/{30, 29, 28, 27, 26}); 1796 ASSERT_FALSE(statusor.ok()); 1797 EXPECT_THAT( 1798 
statusor.status().error_message(), 1799 HasSubstr( 1800 "Repeated dimensions are not allowed in gather_dims_to_operand_dims")) 1801 << statusor.status(); 1802 } 1803 1804 TEST_F(GatherShapeInferenceTest, 1805 InvalidGatherDimNumbers_NonAscendingElidedWindowDims) { 1806 StatusOr<Shape> statusor = ShapeInference::InferGatherShape( 1807 f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, 1808 HloInstruction::MakeGatherDimNumbers( 1809 /*output_window_dims=*/{4, 5, 6, 7, 8}, 1810 /*elided_window_dims=*/{2, 1}, 1811 /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), 1812 /*window_bounds=*/{1, 1, 28, 27, 26}); 1813 ASSERT_FALSE(statusor.ok()); 1814 EXPECT_THAT(statusor.status().error_message(), 1815 HasSubstr("elided_window_dims in gather op must be sorted")) 1816 << statusor.status(); 1817 } 1818 1819 TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowBoundsTooLarge) { 1820 StatusOr<Shape> statusor = ShapeInference::InferGatherShape( 1821 f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, 1822 HloInstruction::MakeGatherDimNumbers( 1823 /*output_window_dims=*/{4, 5, 6, 7}, 1824 /*elided_window_dims=*/{2}, 1825 /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), 1826 /*window_bounds=*/{30, 29, 1, 300, 26}); 1827 ASSERT_FALSE(statusor.ok()); 1828 EXPECT_THAT(statusor.status().error_message(), 1829 HasSubstr("Window bound at index 3 in gather op is out of range, " 1830 "must be within [0, 48), got 300")) 1831 << statusor.status(); 1832 } 1833 1834 TEST_F(GatherShapeInferenceTest, 1835 InvalidGatherDimNumbers_MismatchingNumberOfWindowBounds) { 1836 StatusOr<Shape> statusor = ShapeInference::InferGatherShape( 1837 f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, 1838 HloInstruction::MakeGatherDimNumbers( 1839 /*output_window_dims=*/{4, 5, 6, 7, 8}, 1840 /*elided_window_dims=*/{}, 1841 /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), 1842 /*window_bounds=*/{30, 29, 28, 26}); 1843 ASSERT_FALSE(statusor.ok()); 1844 EXPECT_THAT( 1845 
statusor.status().error_message(), 1846 HasSubstr( 1847 "Gather op must have one window bound for every input dimension")) 1848 << statusor.status(); 1849 } 1850 1851 TEST_F(GatherShapeInferenceTest, 1852 InvalidGatherDimNumbers_WindowBoundsNot1ForElidedDim) { 1853 StatusOr<Shape> statusor = ShapeInference::InferGatherShape( 1854 f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, 1855 HloInstruction::MakeGatherDimNumbers( 1856 /*output_window_dims=*/{4, 5, 6, 7}, 1857 /*elided_window_dims=*/{1}, 1858 /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), 1859 /*window_bounds=*/{30, 29, 28, 26, 20}); 1860 ASSERT_FALSE(statusor.ok()); 1861 EXPECT_THAT(statusor.status().error_message(), 1862 HasSubstr("Gather op can only elide window indices with bound 1, " 1863 "but bound is 29 for index 1 at position 0")) 1864 << statusor.status(); 1865 } 1866 1867 } // namespace 1868 } // namespace xla 1869