1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include "tensorflow/core/framework/shape_inference.h" 16 17 #include "tensorflow/core/framework/fake_input.h" 18 #include "tensorflow/core/framework/node_def_builder.h" 19 #include "tensorflow/core/framework/op_def_builder.h" 20 #include "tensorflow/core/framework/tensor_shape.pb.h" 21 #include "tensorflow/core/framework/tensor_testutil.h" 22 #include "tensorflow/core/framework/types.pb.h" 23 #include "tensorflow/core/lib/core/status_test_util.h" 24 #include "tensorflow/core/lib/strings/strcat.h" 25 #include "tensorflow/core/platform/test.h" 26 27 namespace tensorflow { 28 namespace shape_inference { 29 namespace { 30 31 OpDef MakeOpDefWithLists() { 32 OpRegistrationData op_reg_data; 33 OpDefBuilder b("dummy"); 34 b.Input(strings::StrCat("input: N * float")); 35 b.Output(strings::StrCat("output: N * float")); 36 CHECK(b.Attr("N:int >= 1").Finalize(&op_reg_data).ok()); 37 return op_reg_data.op_def; 38 } 39 40 PartialTensorShape S(std::initializer_list<int64> dims) { 41 return PartialTensorShape(dims); 42 } 43 44 PartialTensorShape Unknown() { return PartialTensorShape(); } 45 46 } // namespace 47 48 class ShapeInferenceTest : public ::testing::Test { 49 protected: 50 // These give access to private functions of DimensionHandle and ShapeHandle. 51 bool SameHandle(DimensionHandle a, DimensionHandle b) { 52 return a.SameHandle(b); 53 } 54 bool SameHandle(ShapeHandle a, ShapeHandle b) { return a.SameHandle(b); } 55 bool IsSet(DimensionHandle d) { return d.IsSet(); } 56 bool IsSet(ShapeHandle s) { return s.IsSet(); } 57 void Relax(InferenceContext* c, DimensionHandle d0, DimensionHandle d1, 58 DimensionHandle* out) { 59 c->Relax(d0, d1, out); 60 } 61 void Relax(InferenceContext* c, ShapeHandle s0, ShapeHandle s1, 62 ShapeHandle* out) { 63 c->Relax(s0, s1, out); 64 } 65 void TestMergeHandles(bool input_not_output); 66 void TestRelaxHandles(bool input_not_output); 67 68 static const int kVersion = 0; // used for graph-def version. 69 }; 70 71 TEST_F(ShapeInferenceTest, InputOutputByName) { 72 // Setup test to contain an input tensor list of size 3. 73 OpDef op_def = MakeOpDefWithLists(); 74 NodeDef def; 75 auto s = NodeDefBuilder("dummy", &op_def) 76 .Attr("N", 3) 77 .Input(FakeInput(DT_FLOAT)) 78 .Finalize(&def); 79 InferenceContext c(kVersion, &def, op_def, {S({1, 5}), S({2, 5}), S({1, 3})}, 80 {}, {}, {}); 81 82 EXPECT_EQ("5", c.DebugString(c.NumElements(c.input(0)))); 83 EXPECT_EQ("10", c.DebugString(c.NumElements(c.input(1)))); 84 EXPECT_EQ("3", c.DebugString(c.NumElements(c.input(2)))); 85 // Test getters. 86 std::vector<ShapeHandle> shapes; 87 EXPECT_FALSE(c.input("nonexistent", &shapes).ok()); 88 TF_EXPECT_OK(c.input("input", &shapes)); 89 EXPECT_EQ("[1,5]", c.DebugString(shapes[0])); 90 EXPECT_EQ("[2,5]", c.DebugString(shapes[1])); 91 EXPECT_EQ("[1,3]", c.DebugString(shapes[2])); 92 93 // Test setters. 94 EXPECT_FALSE(c.set_output("nonexistent", shapes).ok()); 95 TF_EXPECT_OK(c.set_output("output", shapes)); 96 EXPECT_EQ("5", c.DebugString(c.NumElements(c.output(0)))); 97 EXPECT_EQ("10", c.DebugString(c.NumElements(c.output(1)))); 98 EXPECT_EQ("3", c.DebugString(c.NumElements(c.output(2)))); 99 } 100 101 static OpDef MakeOpDef(int num_inputs, int num_outputs) { 102 OpRegistrationData op_reg_data; 103 OpDefBuilder b("dummy"); 104 for (int i = 0; i < num_inputs; ++i) { 105 b.Input(strings::StrCat("i", i, ": float")); 106 } 107 for (int i = 0; i < num_outputs; ++i) { 108 b.Output(strings::StrCat("o", i, ": float")); 109 } 110 CHECK(b.Attr("foo:string").Finalize(&op_reg_data).ok()); 111 return op_reg_data.op_def; 112 } 113 114 TEST_F(ShapeInferenceTest, DimensionOrConstant) { 115 NodeDef def; 116 InferenceContext c(kVersion, &def, MakeOpDef(1, 1), {Unknown()}, {}, {}, {}); 117 EXPECT_EQ(InferenceContext::kUnknownDim, 118 c.Value(InferenceContext::kUnknownDim)); 119 EXPECT_EQ(1, c.Value(1)); 120 121 #ifndef NDEBUG 122 // Only run death test if DCHECKS are enabled. 123 EXPECT_DEATH(c.Value(-7), "Dimension must be non\\-negative or equal to"); 124 #endif 125 } 126 127 TEST_F(ShapeInferenceTest, Run) { 128 NodeDef def; 129 def.set_name("foo"); 130 def.set_op("foo_op"); 131 InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1})}, {}, {}, {}); 132 TF_ASSERT_OK(c.construction_status()); 133 134 { 135 auto fn = [](InferenceContext* c) { 136 ShapeHandle h; 137 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 6, &h)); 138 c->set_output(0, c->input(0)); 139 c->set_output(1, c->input(0)); 140 return Status::OK(); 141 }; 142 TF_ASSERT_OK(c.Run(fn)); 143 } 144 145 { 146 auto fn = [](InferenceContext* c) { 147 ShapeHandle h; 148 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &h)); 149 c->set_output(0, c->input(0)); 150 c->set_output(1, c->input(0)); 151 return Status::OK(); 152 }; 153 Status s = c.Run(fn); 154 // Extra error message is attached when Run fails. 155 EXPECT_TRUE(StringPiece(s.ToString()) 156 .contains("Shape must be at most rank 0 but " 157 "is rank 1 for 'foo' (op: " 158 "'foo_op')")) 159 << s; 160 } 161 } 162 163 // Tests different context data added when Run returns error. 164 TEST_F(ShapeInferenceTest, AttachContext) { 165 NodeDef def; 166 def.set_name("foo"); 167 def.set_op("foo_op"); 168 // Error when no constant tensors were requested. 169 { 170 InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, 3})}, {}, {}, 171 {}); 172 TF_ASSERT_OK(c.construction_status()); 173 auto fn = [](InferenceContext* c) { 174 ShapeHandle h; 175 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &h)); 176 c->set_output(0, c->input(0)); 177 return Status::OK(); 178 }; 179 EXPECT_EQ( 180 "Invalid argument: Shape must be at most rank 0 but is rank 3 for " 181 "'foo' (op: 'foo_op') with input shapes: [1,2,3].", 182 c.Run(fn).ToString()); 183 } 184 185 // Error when a constant tensor value was requested. 186 { 187 Tensor input_t = 188 ::tensorflow::test::AsTensor<float>({1.1, 2.2, 3.3, 4.4, 5.5}); 189 InferenceContext c(kVersion, &def, MakeOpDef(2, 2), 190 {S({1, 2, 3}), S({4, 5})}, {nullptr, &input_t}, {}, {}); 191 TF_ASSERT_OK(c.construction_status()); 192 auto fn = [](InferenceContext* c) { 193 c->input_tensor(0); // get this one, but it's null - won't be in error. 194 c->input_tensor(1); // get this one, will now be in error. 195 ShapeHandle h; 196 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &h)); 197 c->set_output(0, c->input(0)); 198 return Status::OK(); 199 }; 200 EXPECT_EQ( 201 "Invalid argument: Shape must be at most rank 0 but is rank 3 for " 202 "'foo' (op: 'foo_op') with input shapes: [1,2,3], [4,5] and with " 203 "computed input tensors: input[1] = <1.1 2.2 3.3 4.4 5.5>.", 204 c.Run(fn).ToString()); 205 } 206 207 // Error when a constant tensor value as shape was requested, but no partial 208 // shapes provided. 209 { 210 Tensor input_t = ::tensorflow::test::AsTensor<int32>({1, 2, 3, 4, 5}); 211 InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({3}), S({4})}, 212 {nullptr, &input_t}, {}, {}); 213 TF_ASSERT_OK(c.construction_status()); 214 auto fn = [](InferenceContext* c) { 215 ShapeHandle s; 216 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s)); 217 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s)); 218 ShapeHandle h; 219 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &h)); 220 c->set_output(0, c->input(0)); 221 return Status::OK(); 222 }; 223 EXPECT_EQ( 224 "Invalid argument: Shape must be at most rank 0 but is rank 1 for " 225 "'foo' (op: 'foo_op') with input shapes: [3], [4] and with computed " 226 "input tensors: input[1] = <1 2 3 4 5>.", 227 c.Run(fn).ToString()); 228 } 229 230 // Error when a constant tensor value as shape was requested, and a partial 231 // shape was provided. 232 { 233 Tensor input_t = ::tensorflow::test::AsTensor<int32>({1, 2, 3, 4, 5}); 234 InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({3}), S({4})}, 235 {nullptr, &input_t}, {S({10, -1, 5}), Unknown()}, {}); 236 TF_ASSERT_OK(c.construction_status()); 237 auto fn = [](InferenceContext* c) { 238 ShapeHandle s; 239 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s)); 240 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s)); 241 ShapeHandle h; 242 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &h)); 243 c->set_output(0, c->input(0)); 244 return Status::OK(); 245 }; 246 EXPECT_EQ( 247 "Invalid argument: Shape must be at most rank 0 but is rank 1 for " 248 "'foo' (op: 'foo_op') with input shapes: [3], [4] and with computed " 249 "input tensors: input[1] = <1 2 3 4 5> and with input tensors computed " 250 "as partial shapes: input[0] = [10,?,5].", 251 c.Run(fn).ToString()); 252 } 253 } 254 255 TEST_F(ShapeInferenceTest, RankAndDimInspection) { 256 NodeDef def; 257 InferenceContext c(kVersion, &def, MakeOpDef(3, 2), 258 {Unknown(), S({1, -1, 3}), S({})}, {}, {}, {}); 259 EXPECT_EQ(3, c.num_inputs()); 260 EXPECT_EQ(2, c.num_outputs()); 261 262 auto in0 = c.input(0); 263 EXPECT_EQ("?", c.DebugString(in0)); 264 EXPECT_FALSE(c.RankKnown(in0)); 265 EXPECT_EQ(InferenceContext::kUnknownRank, c.Rank(in0)); 266 EXPECT_EQ("?", c.DebugString(c.Dim(in0, 0))); 267 EXPECT_EQ("?", c.DebugString(c.Dim(in0, -1))); 268 EXPECT_EQ("?", c.DebugString(c.Dim(in0, 1000))); 269 270 auto in1 = c.input(1); 271 EXPECT_EQ("[1,?,3]", c.DebugString(in1)); 272 EXPECT_TRUE(c.RankKnown(in1)); 273 EXPECT_EQ(3, c.Rank(in1)); 274 auto d = c.Dim(in1, 0); 275 EXPECT_EQ(1, c.Value(d)); 276 EXPECT_TRUE(SameHandle(d, c.Dim(in1, -3))); 277 EXPECT_TRUE(c.ValueKnown(d)); 278 EXPECT_EQ("1", c.DebugString(d)); 279 d = c.Dim(in1, 1); 280 EXPECT_EQ(InferenceContext::kUnknownDim, c.Value(d)); 281 EXPECT_FALSE(c.ValueKnown(d)); 282 EXPECT_TRUE(SameHandle(d, c.Dim(in1, -2))); 283 EXPECT_EQ("?", c.DebugString(d)); 284 d = c.Dim(in1, 2); 285 EXPECT_EQ(3, c.Value(d)); 286 EXPECT_TRUE(SameHandle(d, c.Dim(in1, -1))); 287 EXPECT_TRUE(c.ValueKnown(d)); 288 EXPECT_EQ("3", c.DebugString(d)); 289 290 auto in2 = c.input(2); 291 EXPECT_EQ("[]", c.DebugString(in2)); 292 EXPECT_TRUE(c.RankKnown(in2)); 293 EXPECT_EQ(0, c.Rank(in2)); 294 } 295 296 TEST_F(ShapeInferenceTest, NumElements) { 297 NodeDef def; 298 InferenceContext c(kVersion, &def, MakeOpDef(3, 2), 299 {Unknown(), S({1, -1, 3}), S({5, 4, 3, 2})}, {}, {}, {}); 300 301 EXPECT_EQ("?", c.DebugString(c.NumElements(c.input(0)))); 302 EXPECT_EQ("?", c.DebugString(c.NumElements(c.input(1)))); 303 304 // Different handles (not the same unknown value). 305 EXPECT_FALSE(SameHandle(c.Dim(c.input(1), 1), c.NumElements(c.input(1)))); 306 307 EXPECT_EQ("120", c.DebugString(c.NumElements(c.input(2)))); 308 } 309 310 TEST_F(ShapeInferenceTest, WithRank) { 311 NodeDef def; 312 InferenceContext c(kVersion, &def, MakeOpDef(2, 2), 313 {Unknown(), S({1, -1, 3})}, {}, {}, {}); 314 315 auto in0 = c.input(0); 316 auto in1 = c.input(1); 317 ShapeHandle s1; 318 ShapeHandle s2; 319 320 // WithRank on a shape with unknown dimensionality always succeeds. 321 EXPECT_TRUE(c.WithRank(in0, 1, &s1).ok()); 322 EXPECT_EQ("[?]", c.DebugString(s1)); 323 324 EXPECT_TRUE(c.WithRank(in0, 2, &s2).ok()); 325 EXPECT_EQ("[?,?]", c.DebugString(s2)); 326 EXPECT_FALSE(SameHandle(s1, s2)); 327 EXPECT_FALSE(SameHandle(c.Dim(s2, 0), c.Dim(s2, 1))); 328 329 EXPECT_TRUE(c.WithRank(in0, 1, &s2).ok()); 330 EXPECT_EQ("[?]", c.DebugString(s2)); 331 EXPECT_FALSE(SameHandle(s1, s2)); 332 333 EXPECT_TRUE(c.WithRank(in0, 0, &s1).ok()); 334 EXPECT_EQ("[]", c.DebugString(s1)); 335 336 // WithRank on shape with known dimensionality. 337 s1 = in1; 338 EXPECT_EQ("Invalid argument: Shape must be rank 2 but is rank 3", 339 c.WithRank(in1, 2, &s1).ToString()); 340 EXPECT_FALSE(IsSet(s1)); 341 EXPECT_TRUE(c.WithRank(in1, 3, &s1).ok()); 342 EXPECT_TRUE(SameHandle(s1, in1)); 343 344 // Inputs are unchanged. 345 EXPECT_EQ("?", c.DebugString(in0)); 346 EXPECT_EQ("[1,?,3]", c.DebugString(in1)); 347 } 348 349 TEST_F(ShapeInferenceTest, WithRankAtMost) { 350 NodeDef def; 351 InferenceContext c(kVersion, &def, MakeOpDef(2, 2), 352 {Unknown(), S({1, -1, 3})}, {}, {}, {}); 353 354 auto in0 = c.input(0); 355 auto in1 = c.input(1); 356 ShapeHandle s1; 357 ShapeHandle s2; 358 359 // WithRankAtMost on a shape with unknown dimensionality always succeeds. 360 EXPECT_TRUE(c.WithRankAtMost(in0, 1, &s1).ok()); 361 EXPECT_EQ("?", c.DebugString(s1)); 362 EXPECT_TRUE(SameHandle(in0, s1)); 363 364 EXPECT_TRUE(c.WithRankAtMost(in0, 2, &s2).ok()); 365 EXPECT_EQ("?", c.DebugString(s2)); 366 EXPECT_TRUE(SameHandle(s1, s2)); 367 368 // WithRankAtMost on shape with known dimensionality. 369 s1 = in1; 370 EXPECT_TRUE( 371 StringPiece(c.WithRankAtMost(in1, 2, &s1).ToString()) 372 .contains( 373 "Invalid argument: Shape must be at most rank 2 but is rank 3")); 374 375 EXPECT_FALSE(IsSet(s1)); 376 EXPECT_TRUE(c.WithRankAtMost(in1, 3, &s1).ok()); 377 EXPECT_TRUE(SameHandle(s1, in1)); 378 EXPECT_TRUE(c.WithRankAtMost(in1, 4, &s1).ok()); 379 EXPECT_TRUE(SameHandle(s1, in1)); 380 EXPECT_TRUE(c.WithRankAtMost(in1, 5, &s1).ok()); 381 EXPECT_TRUE(SameHandle(s1, in1)); 382 383 // Inputs are unchanged. 384 EXPECT_EQ("?", c.DebugString(in0)); 385 EXPECT_EQ("[1,?,3]", c.DebugString(in1)); 386 } 387 388 TEST_F(ShapeInferenceTest, WithRankAtLeast) { 389 NodeDef def; 390 InferenceContext c(kVersion, &def, MakeOpDef(2, 2), 391 {Unknown(), S({1, -1, 3})}, {}, {}, {}); 392 393 auto in0 = c.input(0); 394 auto in1 = c.input(1); 395 ShapeHandle s1; 396 ShapeHandle s2; 397 398 // WithRankAtLeast on a shape with unknown dimensionality always succeeds. 399 EXPECT_TRUE(c.WithRankAtLeast(in0, 1, &s1).ok()); 400 EXPECT_EQ("?", c.DebugString(s1)); 401 EXPECT_TRUE(SameHandle(in0, s1)); 402 403 EXPECT_TRUE(c.WithRankAtLeast(in0, 2, &s2).ok()); 404 EXPECT_EQ("?", c.DebugString(s2)); 405 EXPECT_TRUE(SameHandle(s1, s2)); 406 407 // WithRankAtLeast on shape with known dimensionality. 408 s1 = in1; 409 EXPECT_TRUE( 410 StringPiece(c.WithRankAtLeast(in1, 4, &s1).ToString()) 411 .contains( 412 "Invalid argument: Shape must be at least rank 4 but is rank 3")); 413 414 EXPECT_FALSE(IsSet(s1)); 415 EXPECT_TRUE(c.WithRankAtLeast(in1, 3, &s1).ok()); 416 EXPECT_TRUE(SameHandle(s1, in1)); 417 EXPECT_TRUE(c.WithRankAtLeast(in1, 2, &s1).ok()); 418 EXPECT_TRUE(SameHandle(s1, in1)); 419 EXPECT_TRUE(c.WithRankAtLeast(in1, 0, &s1).ok()); 420 EXPECT_TRUE(SameHandle(s1, in1)); 421 422 // Inputs are unchanged. 423 EXPECT_EQ("?", c.DebugString(in0)); 424 EXPECT_EQ("[1,?,3]", c.DebugString(in1)); 425 } 426 427 TEST_F(ShapeInferenceTest, WithValue) { 428 NodeDef def; 429 InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, -1})}, {}, {}, {}); 430 431 auto d0 = c.Dim(c.input(0), 0); 432 auto d1 = c.Dim(c.input(0), 1); 433 DimensionHandle out1; 434 DimensionHandle out2; 435 436 // WithValue on a dimension with unknown value always succeeds. 437 EXPECT_TRUE(c.WithValue(d1, 1, &out1).ok()); 438 EXPECT_EQ(1, c.Value(out1)); 439 440 EXPECT_TRUE(c.WithValue(d1, 2, &out2).ok()); 441 EXPECT_EQ(2, c.Value(out2)); 442 EXPECT_FALSE(SameHandle(out1, out2)); 443 EXPECT_FALSE(SameHandle(out1, d1)); 444 445 EXPECT_TRUE(c.WithValue(d1, 1, &out2).ok()); 446 EXPECT_EQ(1, c.Value(out2)); 447 EXPECT_FALSE(SameHandle(out1, out2)); 448 449 // WithValue on dimension with known size. 450 out1 = d0; 451 452 EXPECT_TRUE(StringPiece(c.WithValue(d0, 0, &out1).ToString()) 453 .contains("Invalid argument: Dimension must be 0 but is 1")); 454 EXPECT_FALSE(IsSet(out1)); 455 out1 = d0; 456 EXPECT_TRUE(StringPiece(c.WithValue(d0, 2, &out1).ToString()) 457 .contains("Invalid argument: Dimension must be 2 but is 1")); 458 459 EXPECT_FALSE(IsSet(out1)); 460 EXPECT_TRUE(c.WithValue(d0, 1, &out1).ok()); 461 EXPECT_TRUE(SameHandle(d0, out1)); 462 463 // Inputs are unchanged. 464 EXPECT_EQ("1", c.DebugString(d0)); 465 EXPECT_EQ("?", c.DebugString(d1)); 466 } 467 468 TEST_F(ShapeInferenceTest, MergeDim) { 469 NodeDef def; 470 InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({2, -1, 2, 1, -1})}, 471 {}, {}, {}); 472 473 auto d2 = c.Dim(c.input(0), 0); 474 auto d_unknown = c.Dim(c.input(0), 1); 475 auto d2_b = c.Dim(c.input(0), 2); 476 auto d1 = c.Dim(c.input(0), 3); 477 auto d_unknown_b = c.Dim(c.input(0), 4); 478 DimensionHandle out; 479 480 // Merging anything with unknown returns the same pointer. 481 EXPECT_TRUE(c.Merge(d2, d_unknown, &out).ok()); 482 EXPECT_TRUE(SameHandle(d2, out)); 483 EXPECT_TRUE(c.Merge(d_unknown, d2, &out).ok()); 484 EXPECT_TRUE(SameHandle(d2, out)); 485 EXPECT_TRUE(c.Merge(d_unknown, d_unknown_b, &out).ok()); 486 EXPECT_TRUE(SameHandle(d_unknown, out)); 487 488 auto merged_dims = c.MergedDims(); 489 ASSERT_EQ(3, merged_dims.size()); 490 EXPECT_TRUE(merged_dims[0].first.SameHandle(d2)); 491 EXPECT_TRUE(merged_dims[0].second.SameHandle(d_unknown)); 492 EXPECT_TRUE(merged_dims[1].first.SameHandle(d_unknown)); 493 EXPECT_TRUE(merged_dims[1].second.SameHandle(d2)); 494 EXPECT_TRUE(merged_dims[2].first.SameHandle(d_unknown)); 495 EXPECT_TRUE(merged_dims[2].second.SameHandle(d_unknown_b)); 496 497 // Merging with self is a no-op and returns self. 498 EXPECT_TRUE(c.Merge(d2, d2, &out).ok()); 499 EXPECT_TRUE(SameHandle(d2, out)); 500 EXPECT_TRUE(c.Merge(d_unknown, d_unknown, &out).ok()); 501 EXPECT_TRUE(SameHandle(d_unknown, out)); 502 503 merged_dims = c.MergedDims(); 504 EXPECT_EQ(3, merged_dims.size()); 505 506 // Merging equal values is a no op and returns first one. 507 EXPECT_TRUE(c.Merge(d2, d2_b, &out).ok()); 508 EXPECT_TRUE(SameHandle(d2, out)); 509 EXPECT_TRUE(c.Merge(d2_b, d2, &out).ok()); 510 EXPECT_TRUE(SameHandle(d2_b, out)); 511 512 merged_dims = c.MergedDims(); 513 EXPECT_EQ(3, merged_dims.size()); 514 515 // Merging unequal values is an error. 516 EXPECT_TRUE( 517 StringPiece(c.Merge(d2, d1, &out).ToString()) 518 .contains( 519 "Invalid argument: Dimensions must be equal, but are 2 and 1")); 520 521 EXPECT_FALSE(IsSet(out)); 522 EXPECT_TRUE( 523 StringPiece(c.Merge(d1, d2, &out).ToString()) 524 .contains( 525 "Invalid argument: Dimensions must be equal, but are 1 and 2")); 526 527 EXPECT_FALSE(IsSet(out)); 528 529 merged_dims = c.MergedDims(); 530 EXPECT_EQ(3, merged_dims.size()); 531 } 532 533 TEST_F(ShapeInferenceTest, RelaxDim) { 534 NodeDef def; 535 InferenceContext c(kVersion, &def, MakeOpDef(1, 2), 536 {S({2, InferenceContext::kUnknownDim, 2, 1, 537 InferenceContext::kUnknownDim})}, 538 {}, {}, {}); 539 540 auto d2 = c.Dim(c.input(0), 0); 541 auto d_unknown = c.Dim(c.input(0), 1); 542 auto d2_b = c.Dim(c.input(0), 2); 543 auto d1 = c.Dim(c.input(0), 3); 544 auto d_unknown_b = c.Dim(c.input(0), 4); 545 DimensionHandle out; 546 547 // Relaxing anything with unknown returns a new unknown or the existing 548 // unknown. 549 Relax(&c, d2, d_unknown, &out); 550 EXPECT_TRUE(SameHandle(d_unknown, out)); 551 EXPECT_FALSE(SameHandle(d_unknown_b, out)); 552 EXPECT_EQ(InferenceContext::kUnknownDim, c.Value(out)); 553 Relax(&c, d_unknown, d2, &out); 554 EXPECT_FALSE(SameHandle(d_unknown, out)); 555 EXPECT_EQ(InferenceContext::kUnknownDim, c.Value(out)); 556 Relax(&c, d_unknown, d_unknown_b, &out); 557 EXPECT_FALSE(SameHandle(d_unknown, out)); 558 EXPECT_TRUE(SameHandle(d_unknown_b, out)); 559 EXPECT_EQ(InferenceContext::kUnknownDim, c.Value(out)); 560 561 // Relaxing with self returns self. 562 Relax(&c, d2, d2, &out); 563 EXPECT_TRUE(SameHandle(d2, out)); 564 Relax(&c, d_unknown, d_unknown, &out); 565 EXPECT_TRUE(SameHandle(d_unknown, out)); 566 567 // Relaxing equal values returns first one. 568 Relax(&c, d2, d2_b, &out); 569 EXPECT_TRUE(SameHandle(d2, out)); 570 Relax(&c, d2_b, d2, &out); 571 EXPECT_TRUE(SameHandle(d2_b, out)); 572 573 // Relaxing unequal values returns a new unknown. 574 Relax(&c, d2, d1, &out); 575 EXPECT_EQ(InferenceContext::kUnknownDim, c.Value(out)); 576 Relax(&c, d1, d2, &out); 577 EXPECT_EQ(InferenceContext::kUnknownDim, c.Value(out)); 578 } 579 580 TEST_F(ShapeInferenceTest, RelaxShape) { 581 NodeDef def; 582 InferenceContext c( 583 kVersion, &def, MakeOpDef(7, 2), 584 {Unknown(), S({1, 2}), S({InferenceContext::kUnknownDim, 2}), 585 S({1, InferenceContext::kUnknownDim}), S({1, 3}), Unknown(), S({1})}, 586 {}, {}, {}); 587 588 auto s_unknown = c.input(0); 589 auto s_1_2 = c.input(1); 590 auto s_u_2 = c.input(2); 591 auto s_1_u = c.input(3); 592 auto s_1_3 = c.input(4); 593 auto s_unknown_b = c.input(5); 594 auto s_1 = c.input(6); 595 ShapeHandle out; 596 597 // Relaxing any shape with unknown returns a new unknown. 598 Relax(&c, s_unknown, s_1_2, &out); 599 EXPECT_FALSE(SameHandle(s_u_2, s_unknown)); 600 EXPECT_EQ("?", c.DebugString(out)); 601 Relax(&c, s_u_2, s_unknown, &out); 602 EXPECT_FALSE(SameHandle(s_u_2, out)); 603 EXPECT_EQ("?", c.DebugString(out)); 604 Relax(&c, s_unknown, s_unknown_b, &out); 605 EXPECT_FALSE(SameHandle(s_unknown, out)); 606 EXPECT_TRUE(SameHandle(s_unknown_b, out)); 607 EXPECT_EQ("?", c.DebugString(out)); 608 609 // Relaxing with self returns self. 610 Relax(&c, s_1_2, s_1_2, &out); 611 EXPECT_TRUE(SameHandle(out, s_1_2)); 612 613 // Relaxing where one of the inputs has less information. 614 out = ShapeHandle(); 615 Relax(&c, s_1_2, s_u_2, &out); 616 EXPECT_FALSE(SameHandle(s_u_2, out)); 617 EXPECT_EQ("[?,2]", c.DebugString(out)); 618 out = ShapeHandle(); 619 Relax(&c, s_u_2, s_1_2, &out); 620 EXPECT_FALSE(SameHandle(s_u_2, out)); 621 EXPECT_EQ("[?,2]", c.DebugString(out)); 622 623 // Relaxing where each input has one distinct unknown dimension. 624 Relax(&c, s_u_2, s_1_u, &out); 625 EXPECT_EQ("[?,?]", c.DebugString(out)); 626 EXPECT_FALSE(SameHandle(c.Dim(s_u_2, 0), c.Dim(out, 0))); 627 EXPECT_TRUE(SameHandle(c.Dim(s_1_u, 1), c.Dim(out, 1))); 628 auto s_u1 = c.UnknownShapeOfRank(1); 629 auto s_u2 = c.UnknownShapeOfRank(1); 630 Relax(&c, s_u1, s_u2, &out); 631 EXPECT_FALSE(SameHandle(s_u1, out)); 632 633 // Relaxing with mismatched values in a dimension returns a shape with that 634 // dimension unknown. 635 out = s_unknown; 636 Relax(&c, s_u_2, s_1_3, &out); 637 EXPECT_FALSE(SameHandle(c.Dim(s_u_2, 0), c.Dim(out, 0))); 638 EXPECT_EQ("[?,?]", c.DebugString(out)); 639 out = s_unknown; 640 Relax(&c, s_1_3, s_u_2, &out); 641 EXPECT_TRUE(SameHandle(c.Dim(s_u_2, 0), c.Dim(out, 0))); 642 EXPECT_EQ("[?,?]", c.DebugString(out)); 643 out = s_unknown; 644 645 // Relaxing with mismatched ranks returns a new unknown. 646 Relax(&c, s_1, s_1_2, &out); 647 EXPECT_EQ("?", c.DebugString(out)); 648 } 649 650 TEST_F(ShapeInferenceTest, MergeShape) { 651 NodeDef def; 652 InferenceContext c(kVersion, &def, MakeOpDef(7, 2), 653 {Unknown(), S({1, 2}), S({-1, 2}), S({1, -1}), S({1, 3}), 654 Unknown(), S({1})}, 655 {}, {}, {}); 656 657 auto s_unknown = c.input(0); 658 auto s_1_2 = c.input(1); 659 auto s_u_2 = c.input(2); 660 auto s_1_u = c.input(3); 661 auto s_1_3 = c.input(4); 662 auto s_unknown_b = c.input(5); 663 auto s_1 = c.input(6); 664 ShapeHandle out; 665 666 // Merging any shape with unknown returns the shape. 667 EXPECT_TRUE(c.Merge(s_unknown, s_1_2, &out).ok()); 668 EXPECT_TRUE(SameHandle(s_1_2, out)); 669 EXPECT_TRUE(c.Merge(s_u_2, s_unknown, &out).ok()); 670 EXPECT_TRUE(SameHandle(s_u_2, out)); 671 EXPECT_TRUE(c.Merge(s_unknown, s_unknown_b, &out).ok()); 672 EXPECT_TRUE(SameHandle(s_unknown, out)); 673 674 auto merged_shapes = c.MergedShapes(); 675 ASSERT_EQ(3, merged_shapes.size()); 676 EXPECT_TRUE(merged_shapes[0].first.SameHandle(s_unknown)); 677 EXPECT_TRUE(merged_shapes[0].second.SameHandle(s_1_2)); 678 EXPECT_TRUE(merged_shapes[1].first.SameHandle(s_u_2)); 679 EXPECT_TRUE(merged_shapes[1].second.SameHandle(s_unknown)); 680 EXPECT_TRUE(merged_shapes[2].first.SameHandle(s_unknown)); 681 EXPECT_TRUE(merged_shapes[2].second.SameHandle(s_unknown_b)); 682 683 // Merging with self returns self. 684 EXPECT_TRUE(c.Merge(s_1_2, s_1_2, &out).ok()); 685 EXPECT_TRUE(SameHandle(out, s_1_2)); 686 687 merged_shapes = c.MergedShapes(); 688 EXPECT_EQ(3, merged_shapes.size()); 689 690 // Merging where one of the inputs is the right answer - return that input. 691 out = ShapeHandle(); 692 EXPECT_TRUE(c.Merge(s_1_2, s_u_2, &out).ok()); 693 EXPECT_TRUE(SameHandle(s_1_2, out)); 694 out = ShapeHandle(); 695 EXPECT_TRUE(c.Merge(s_u_2, s_1_2, &out).ok()); 696 EXPECT_TRUE(SameHandle(s_1_2, out)); 697 698 merged_shapes = c.MergedShapes(); 699 ASSERT_EQ(5, merged_shapes.size()); 700 EXPECT_TRUE(merged_shapes[3].first.SameHandle(s_1_2)); 701 EXPECT_TRUE(merged_shapes[3].second.SameHandle(s_u_2)); 702 EXPECT_TRUE(merged_shapes[4].first.SameHandle(s_u_2)); 703 EXPECT_TRUE(merged_shapes[4].second.SameHandle(s_1_2)); 704 705 // Merging where neither input is the right answer. 706 EXPECT_TRUE(c.Merge(s_u_2, s_1_u, &out).ok()); 707 EXPECT_FALSE(SameHandle(out, s_u_2)); 708 EXPECT_FALSE(SameHandle(out, s_1_u)); 709 EXPECT_EQ("[1,2]", c.DebugString(out)); 710 EXPECT_TRUE(SameHandle(c.Dim(s_1_u, 0), c.Dim(out, 0))); 711 EXPECT_TRUE(SameHandle(c.Dim(s_u_2, 1), c.Dim(out, 1))); 712 713 merged_shapes = c.MergedShapes(); 714 ASSERT_EQ(7, merged_shapes.size()); 715 EXPECT_TRUE(merged_shapes[5].first.SameHandle(s_u_2)); 716 EXPECT_TRUE(merged_shapes[5].second.SameHandle(s_1_u)); 717 EXPECT_TRUE(merged_shapes[6].first.SameHandle(s_u_2)); 718 EXPECT_TRUE(merged_shapes[6].second.SameHandle(out)); 719 720 auto s_u1 = c.UnknownShapeOfRank(1); 721 auto s_u2 = c.UnknownShapeOfRank(1); 722 TF_EXPECT_OK(c.Merge(s_u1, s_u2, &out)); 723 EXPECT_TRUE(SameHandle(s_u1, out)); 724 725 merged_shapes = c.MergedShapes(); 726 ASSERT_EQ(8, merged_shapes.size()); 727 EXPECT_TRUE(merged_shapes[7].first.SameHandle(s_u1)); 728 EXPECT_TRUE(merged_shapes[7].second.SameHandle(s_u2)); 729 730 // Incompatible merges give errors and set out to nullptr. 731 out = s_unknown; 732 EXPECT_TRUE( 733 StringPiece(c.Merge(s_u_2, s_1_3, &out).ToString()) 734 .contains( 735 "Invalid argument: Dimension 1 in both shapes must be equal, but " 736 "are 2 and 3")); 737 738 EXPECT_FALSE(IsSet(out)); 739 out = s_unknown; 740 EXPECT_TRUE( 741 StringPiece(c.Merge(s_1_3, s_u_2, &out).ToString()) 742 .contains( 743 "Invalid argument: Dimension 1 in both shapes must be equal, but " 744 "are 3 and 2")); 745 746 EXPECT_FALSE(IsSet(out)); 747 out = s_unknown; 748 EXPECT_TRUE( 749 StringPiece(c.Merge(s_1, s_1_2, &out).ToString()) 750 .contains( 751 "Invalid argument: Shapes must be equal rank, but are 1 and 2")); 752 753 EXPECT_FALSE(IsSet(out)); 754 755 merged_shapes = c.MergedShapes(); 756 EXPECT_EQ(8, merged_shapes.size()); 757 } 758 759 TEST_F(ShapeInferenceTest, MergePrefix) { 760 NodeDef def; 761 InferenceContext c(kVersion, &def, MakeOpDef(4, 2), 762 { 763 Unknown(), 764 S({-1, 2}), 765 S({1, -1, 3}), 766 S({2, 4}), 767 }, 768 {}, {}, {}); 769 770 auto s_unknown = c.input(0); 771 auto s_u_2 = c.input(1); 772 auto s_1_u_3 = c.input(2); 773 auto s_2_4 = c.input(3); 774 775 ShapeHandle s_out; 776 ShapeHandle s_prefix_out; 777 778 // Merging with unknown returns the inputs. 779 EXPECT_TRUE(c.MergePrefix(s_unknown, s_u_2, &s_out, &s_prefix_out).ok()); 780 EXPECT_TRUE(SameHandle(s_out, s_unknown)); 781 EXPECT_TRUE(SameHandle(s_prefix_out, s_u_2)); 782 EXPECT_TRUE(c.MergePrefix(s_1_u_3, s_unknown, &s_out, &s_prefix_out).ok()); 783 EXPECT_TRUE(SameHandle(s_out, s_1_u_3)); 784 EXPECT_TRUE(SameHandle(s_prefix_out, s_unknown)); 785 786 EXPECT_TRUE(c.MergePrefix(s_1_u_3, s_u_2, &s_out, &s_prefix_out).ok()); 787 EXPECT_FALSE(SameHandle(s_out, s_1_u_3)); 788 EXPECT_EQ("[1,2]", c.DebugString(s_prefix_out)); 789 EXPECT_EQ("[1,2,3]", c.DebugString(s_out)); 790 EXPECT_TRUE(SameHandle(c.Dim(s_prefix_out, 0), c.Dim(s_out, 0))); 791 EXPECT_TRUE(SameHandle(c.Dim(s_out, 0), c.Dim(s_1_u_3, 0))); 792 EXPECT_TRUE(SameHandle(c.Dim(s_prefix_out, 1), c.Dim(s_out, 1))); 793 EXPECT_TRUE(SameHandle(c.Dim(s_prefix_out, 1), c.Dim(s_u_2, 1))); 794 795 // Incompatible merges give errors and set outs to nullptr. 796 s_out = s_unknown; 797 s_prefix_out = s_unknown; 798 EXPECT_TRUE( 799 StringPiece( 800 c.MergePrefix(s_1_u_3, s_2_4, &s_out, &s_prefix_out).ToString()) 801 .contains( 802 "Invalid argument: Dimensions must be equal, but are 1 and 2")); 803 804 EXPECT_FALSE(IsSet(s_out)); 805 EXPECT_FALSE(IsSet(s_prefix_out)); 806 807 s_out = s_unknown; 808 s_prefix_out = s_unknown; 809 EXPECT_TRUE( 810 StringPiece( 811 c.MergePrefix(s_2_4, s_1_u_3, &s_out, &s_prefix_out).ToString()) 812 .contains( 813 "Invalid argument: Shape must be at least rank 3 but is rank 2")); 814 EXPECT_FALSE(IsSet(s_out)); 815 EXPECT_FALSE(IsSet(s_prefix_out)); 816 } 817 818 TEST_F(ShapeInferenceTest, Subshape) { 819 NodeDef def; 820 InferenceContext c(kVersion, &def, MakeOpDef(2, 2), 821 {S({1, 2, 3, -1, 5}), Unknown()}, {}, {}, {}); 822 823 ShapeHandle unknown = c.input(1); 824 ShapeHandle out; 825 EXPECT_TRUE(c.Subshape(unknown, 0, &out).ok()); 826 EXPECT_EQ("?", c.DebugString(out)); 827 EXPECT_TRUE(SameHandle(out, unknown)); 828 EXPECT_TRUE(c.Subshape(unknown, 1, &out).ok()); 829 EXPECT_EQ("?", c.DebugString(out)); 830 EXPECT_FALSE(SameHandle(out, unknown)); 831 EXPECT_TRUE(c.Subshape(unknown, 200, &out).ok()); 832 EXPECT_EQ("?", c.DebugString(out)); 833 EXPECT_FALSE(SameHandle(out, unknown)); 834 835 const int kFullRank = 5; 836 ShapeHandle out_arr[4]; 837 auto in0 = c.input(0); 838 EXPECT_TRUE(c.Subshape(in0, 0, &out).ok()); 839 EXPECT_EQ("[1,2,3,?,5]", c.DebugString(out)); 840 EXPECT_TRUE(SameHandle(out, in0)); 841 EXPECT_EQ(kFullRank, c.Rank(out)); 842 for (int start = 0; start <= kFullRank + 1; ++start) { 843 for (int end = start; end <= kFullRank + 1; ++end) { 844 // Get subshapes using different start and end values that give the same 845 // range. 846 const int neg_start = 847 start >= kFullRank ? kFullRank : (start - kFullRank); 848 const int neg_end = end >= kFullRank ? kFullRank : (end - kFullRank); 849 ASSERT_TRUE(c.Subshape(in0, start, end, &out_arr[0]).ok()); 850 ASSERT_TRUE(c.Subshape(in0, neg_start, end, &out_arr[1]).ok()); 851 ASSERT_TRUE(c.Subshape(in0, start, neg_end, &out_arr[2]).ok()); 852 ASSERT_TRUE(c.Subshape(in0, neg_start, neg_end, &out_arr[3]).ok()); 853 854 // Verify all computed subshapes. 855 for (int arr_idx = 0; arr_idx < 4; ++arr_idx) { 856 out = out_arr[arr_idx]; 857 ASSERT_EQ(std::min(kFullRank, end) - std::min(kFullRank, start), 858 c.Rank(out)) 859 << "start: " << start << " end: " << end << " arr_idx: " << arr_idx 860 << " in0: " << c.DebugString(in0) << " out: " << c.DebugString(out); 861 for (int d = 0; d < c.Rank(out); ++d) { 862 EXPECT_TRUE(SameHandle(c.Dim(in0, start + d), c.Dim(out, d))) 863 << "arr_idx: " << arr_idx; 864 } 865 } 866 } 867 } 868 869 // Errors. 870 out = unknown; 871 EXPECT_TRUE(StringPiece(c.Subshape(in0, 6, -3, &out).ToString()) 872 .contains("Invalid argument: Subshape must have computed " 873 "start <= end, but is 5 " 874 "and 2 (computed from start 6 and end -3 over " 875 "shape with rank 5)")); 876 EXPECT_FALSE(IsSet(out)); 877 out = unknown; 878 EXPECT_TRUE(StringPiece(c.Subshape(in0, -50, 100, &out).ToString()) 879 .contains("Invalid argument: Subshape start out of " 880 "bounds: -50, for shape with " 881 "rank 5")); 882 883 EXPECT_FALSE(IsSet(out)); 884 out = unknown; 885 EXPECT_TRUE(StringPiece(c.Subshape(in0, 0, -50, &out).ToString()) 886 .contains("Invalid argument: Subshape end out of bounds: " 887 "-50, for shape with rank " 888 "5")); 889 890 EXPECT_FALSE(IsSet(out)); 891 } 892 893 TEST_F(ShapeInferenceTest, Concatenate) { 894 NodeDef def; 895 InferenceContext c(kVersion, &def, MakeOpDef(3, 2), 896 {S({1, -1, 3}), S({4, 5}), Unknown()}, {}, {}, {}); 897 898 auto in0 = c.input(0); 899 auto in1 = c.input(1); 900 ShapeHandle unknown = c.input(2); 901 ShapeHandle out; 902 EXPECT_TRUE(c.Concatenate(unknown, unknown, &out).ok()); 903 EXPECT_EQ("?", c.DebugString(out)); 904 EXPECT_FALSE(SameHandle(out, unknown)); 905 EXPECT_TRUE(c.Concatenate(unknown, in0, &out).ok()); 906 EXPECT_EQ("?", c.DebugString(out)); 907 EXPECT_FALSE(SameHandle(out, unknown)); 908 909 EXPECT_TRUE(c.Concatenate(in0, in1, &out).ok()); 910 EXPECT_EQ("[1,?,3,4,5]", c.DebugString(out)); 911 int out_i = 0; 912 for (int i = 0; i < c.Rank(in0); ++i, ++out_i) { 913 EXPECT_TRUE(SameHandle(c.Dim(in0, i), c.Dim(out, out_i))); 914 } 915 for (int i = 0; i < c.Rank(in1); ++i, ++out_i) { 916 EXPECT_TRUE(SameHandle(c.Dim(in1, i), c.Dim(out, out_i))); 917 } 918 } 919 920 TEST_F(ShapeInferenceTest, ReplaceDim) { 921 NodeDef def; 922 InferenceContext c(kVersion, &def, MakeOpDef(2, 0), {S({1, 2, 3}), Unknown()}, 923 {}, {}, {}); 924 925 auto in = c.input(0); 926 auto unknown = c.input(1); 927 928 ShapeHandle replaced; 929 EXPECT_TRUE(c.ReplaceDim(in, 0, c.Dim(in, 1), &replaced).ok()); 930 EXPECT_EQ("[2,2,3]", c.DebugString(replaced)); 931 EXPECT_TRUE(c.ReplaceDim(in, 2, c.Dim(in, 1), &replaced).ok()); 932 EXPECT_EQ("[1,2,2]", c.DebugString(replaced)); 933 EXPECT_TRUE(c.ReplaceDim(in, 1, c.Dim(in, 2), &replaced).ok()); 934 EXPECT_EQ("[1,3,3]", c.DebugString(replaced)); 935 EXPECT_TRUE(c.ReplaceDim(unknown, 0, c.Dim(in, 1), &replaced).ok()); 936 EXPECT_EQ("?", c.DebugString(replaced)); 937 938 // Negative indexing. 939 EXPECT_TRUE(c.ReplaceDim(in, -1, c.Dim(in, 1), &replaced).ok()); 940 EXPECT_EQ("[1,2,2]", c.DebugString(replaced)); 941 EXPECT_TRUE(c.ReplaceDim(unknown, -1, c.Dim(in, 1), &replaced).ok()); 942 EXPECT_EQ("?", c.DebugString(replaced)); 943 944 // out of range indexing. 945 EXPECT_FALSE(c.ReplaceDim(in, 3, c.Dim(in, 1), &replaced).ok()); 946 EXPECT_FALSE(IsSet(replaced)); 947 replaced = in; 948 EXPECT_FALSE(c.ReplaceDim(in, -4, c.Dim(in, 1), &replaced).ok()); 949 EXPECT_FALSE(IsSet(replaced)); 950 } 951 952 TEST_F(ShapeInferenceTest, MakeShape) { 953 NodeDef def; 954 InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, 3, -1, 5})}, {}, 955 {}, {}); 956 957 std::vector<DimensionHandle> dims; 958 auto in0 = c.input(0); 959 const int rank = c.Rank(in0); 960 dims.reserve(rank); 961 for (int i = 0; i < rank; ++i) { 962 dims.push_back(c.Dim(in0, rank - i - 1)); 963 } 964 965 auto s = c.MakeShape(dims); 966 EXPECT_EQ("[5,?,3,2,1]", c.DebugString(s)); 967 EXPECT_TRUE(SameHandle(c.Dim(s, 0), c.Dim(in0, rank - 1))); 968 969 auto s2 = c.MakeShape(dims); 970 EXPECT_FALSE(SameHandle(s, s2)); 971 EXPECT_TRUE(SameHandle(c.Dim(s2, 0), c.Dim(in0, rank - 1))); 972 973 auto s3 = c.MakeShape({1, 2, dims[2]}); 974 EXPECT_FALSE(SameHandle(s, s3)); 975 EXPECT_EQ("[1,2,3]", c.DebugString(s3)); 976 } 977 978 TEST_F(ShapeInferenceTest, UnknownShape) { 979 NodeDef def; 980 std::vector<ShapeHandle> empty; 981 InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}); 982 983 auto u0 = c.UnknownShape(); 984 auto u1 = c.UnknownShape(); 985 EXPECT_EQ("?", c.DebugString(u0)); 986 EXPECT_EQ("?", c.DebugString(u1)); 987 EXPECT_FALSE(SameHandle(u0, u1)); 988 } 989 990 TEST_F(ShapeInferenceTest, KnownShapeToProto) { 991 NodeDef def; 992 std::vector<ShapeHandle> empty; 993 InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}); 994 995 auto s = c.MakeShape({1, 2, 3}); 996 TensorShapeProto proto; 997 c.ShapeHandleToProto(s, &proto); 998 999 EXPECT_FALSE(proto.unknown_rank()); 1000 EXPECT_EQ(3, proto.dim_size()); 1001 EXPECT_EQ(1, proto.dim(0).size()); 1002 } 1003 1004 TEST_F(ShapeInferenceTest, UnknownShapeToProto) { 1005 NodeDef def; 1006 std::vector<ShapeHandle> empty; 1007 InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}); 1008 1009 auto u0 = c.UnknownShape(); 1010 TensorShapeProto proto; 1011 c.ShapeHandleToProto(u0, &proto); 1012 1013 EXPECT_TRUE(proto.unknown_rank()); 1014 EXPECT_EQ(0, proto.dim_size()); 1015 } 1016 1017 TEST_F(ShapeInferenceTest, Scalar) { 1018 NodeDef def; 1019 std::vector<ShapeHandle> empty; 1020 InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}); 1021 1022 auto s0 = c.Scalar(); 1023 EXPECT_EQ("[]", c.DebugString(s0)); 1024 auto s1 = c.Scalar(); 1025 EXPECT_EQ("[]", c.DebugString(s1)); 1026 } 1027 1028 TEST_F(ShapeInferenceTest, Vector) { 1029 NodeDef def; 1030 std::vector<ShapeHandle> empty; 1031 InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}); 1032 1033 auto s0 = c.Vector(1); 1034 EXPECT_EQ("[1]", c.DebugString(s0)); 1035 auto s1 = c.Vector(InferenceContext::kUnknownDim); 1036 EXPECT_EQ("[?]", c.DebugString(s1)); 1037 1038 auto d1 = c.UnknownDim(); 1039 auto s2 = c.Vector(d1); 1040 EXPECT_EQ("[?]", c.DebugString(s2)); 1041 EXPECT_TRUE(SameHandle(d1, c.Dim(s2, 0))); 1042 } 1043 1044 TEST_F(ShapeInferenceTest, Matrix) { 1045 NodeDef def; 1046 std::vector<ShapeHandle> empty; 1047 InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}); 1048 1049 auto s0 = c.Matrix(1, 2); 1050 EXPECT_EQ("[1,2]", c.DebugString(s0)); 1051 auto s1 = c.Matrix(0, InferenceContext::kUnknownDim); 1052 EXPECT_EQ("[0,?]", c.DebugString(s1)); 1053 1054 auto d1 = c.UnknownDim(); 1055 auto d2 = c.UnknownDim(); 1056 auto s2 = c.Matrix(d1, d2); 1057 EXPECT_EQ("[?,?]", c.DebugString(s2)); 1058 EXPECT_TRUE(SameHandle(d1, c.Dim(s2, 0))); 1059 EXPECT_TRUE(SameHandle(d2, c.Dim(s2, 1))); 1060 1061 auto s3 = c.Matrix(d1, 100); 1062 EXPECT_EQ("[?,100]", c.DebugString(s3)); 1063 EXPECT_TRUE(SameHandle(d1, c.Dim(s2, 0))); 1064 } 1065 1066 TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) { 1067 auto create = [&](Tensor* t) { 1068 NodeDef def; 1069 InferenceContext c(kVersion, &def, MakeOpDef(1, 0), {Unknown()}, {t}, {}, 1070 {}); 1071 ShapeHandle out; 1072 Status s = c.MakeShapeFromShapeTensor(0, &out); 1073 if (s.ok()) { 1074 return c.DebugString(out); 1075 } else { 1076 EXPECT_FALSE(IsSet(out)); 1077 return s.error_message(); 1078 } 1079 }; 1080 1081 Tensor t; 1082 EXPECT_EQ("?", create(nullptr)); 1083 1084 t = ::tensorflow::test::AsTensor<int32>({1, 2, 3}); 1085 EXPECT_EQ("[1,2,3]", create(&t)); 1086 1087 t = ::tensorflow::test::AsTensor<int64>({3, 2, 1}); 1088 EXPECT_EQ("[3,2,1]", create(&t)); 1089 1090 t = ::tensorflow::test::AsTensor<int64>({3, -1, 1}); 1091 EXPECT_EQ("[3,?,1]", create(&t)); 1092 1093 t = ::tensorflow::test::AsTensor<int64>({}); 1094 EXPECT_EQ("[]", create(&t)); 1095 1096 t = ::tensorflow::test::AsTensor<float>({1, 2, 3}); 1097 EXPECT_TRUE( 1098 StringPiece(create(&t)) 1099 .contains("Input tensor must be int32 or int64, but was float")); 1100 1101 t = ::tensorflow::test::AsScalar<int32>(1); 1102 EXPECT_TRUE(StringPiece(create(&t)) 1103 .contains("Input tensor must be rank 1, but was rank 0")); 1104 1105 t = ::tensorflow::test::AsTensor<int32>({1, 2}, TensorShape{2, 1}); 1106 EXPECT_TRUE(StringPiece(create(&t)) 1107 .contains("Input tensor must be rank 1, but was rank 2")); 1108 1109 // Test negative values for the dims. 1110 t = ::tensorflow::test::AsTensor<int64>({3, -2, 1}); 1111 EXPECT_TRUE(StringPiece(create(&t)) 1112 .contains("Invalid value in tensor used for shape: -2")); 1113 1114 // Test negative values for the dims. 1115 t = ::tensorflow::test::AsTensor<int32>({3, -2, 1}); 1116 EXPECT_TRUE(StringPiece(create(&t)) 1117 .contains("Invalid value in tensor used for shape: -2")); 1118 1119 // Test when the input shape is wrong. 1120 { 1121 NodeDef def; 1122 InferenceContext c(kVersion, &def, MakeOpDef(1, 0), {S({1, -1})}, {nullptr}, 1123 {}, {}); 1124 ShapeHandle out; 1125 EXPECT_EQ("Shape must be rank 1 but is rank 2", 1126 c.MakeShapeFromShapeTensor(0, &out).error_message()); 1127 } 1128 } 1129 1130 TEST_F(ShapeInferenceTest, MakeShapeFromPartialTensorShape) { 1131 NodeDef def; 1132 std::vector<ShapeHandle> empty; 1133 InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}); 1134 1135 // With an unknown rank. 1136 ShapeHandle out; 1137 TF_ASSERT_OK(c.MakeShapeFromPartialTensorShape(PartialTensorShape(), &out)); 1138 EXPECT_EQ("?", c.DebugString(out)); 1139 1140 // With a known rank. 1141 TF_ASSERT_OK( 1142 c.MakeShapeFromPartialTensorShape(PartialTensorShape({0}), &out)); 1143 EXPECT_EQ("[0]", c.DebugString(out)); 1144 TF_ASSERT_OK(c.MakeShapeFromPartialTensorShape( 1145 PartialTensorShape({0, -1, 1000}), &out)); 1146 EXPECT_EQ("[0,?,1000]", c.DebugString(out)); 1147 } 1148 1149 TEST_F(ShapeInferenceTest, MakeShapeFromTensorShape) { 1150 NodeDef def; 1151 std::vector<ShapeHandle> empty; 1152 InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}); 1153 1154 ShapeHandle out; 1155 TF_ASSERT_OK(c.MakeShapeFromTensorShape(TensorShape(), &out)); 1156 EXPECT_EQ("[]", c.DebugString(out)); 1157 TF_ASSERT_OK(c.MakeShapeFromTensorShape(TensorShape({0}), &out)); 1158 EXPECT_EQ("[0]", c.DebugString(out)); 1159 TF_ASSERT_OK(c.MakeShapeFromTensorShape(TensorShape({0, 7, 1000}), &out)); 1160 EXPECT_EQ("[0,7,1000]", c.DebugString(out)); 1161 } 1162 1163 TEST_F(ShapeInferenceTest, MakeShapeFromShapeProto) { 1164 NodeDef def; 1165 std::vector<ShapeHandle> empty; 1166 InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}); 1167 TensorShapeProto proto; 1168 1169 // With a set unknown rank. 1170 ShapeHandle out; 1171 proto.set_unknown_rank(true); 1172 EXPECT_TRUE(c.MakeShapeFromShapeProto(proto, &out).ok()); 1173 EXPECT_EQ("?", c.DebugString(out)); 1174 proto.add_dim()->set_size(0); 1175 EXPECT_TRUE( 1176 StringPiece(c.MakeShapeFromShapeProto(proto, &out).error_message()) 1177 .contains("An unknown shape must not have any dimensions set.")); 1178 EXPECT_FALSE(IsSet(out)); 1179 1180 // With known rank. 1181 proto.set_unknown_rank(false); 1182 EXPECT_TRUE(c.MakeShapeFromShapeProto(proto, &out).ok()); 1183 EXPECT_EQ("[0]", c.DebugString(out)); 1184 proto.add_dim()->set_size(-1); 1185 proto.add_dim()->set_size(1000); 1186 EXPECT_TRUE(c.MakeShapeFromShapeProto(proto, &out).ok()); 1187 EXPECT_EQ("[0,?,1000]", c.DebugString(out)); 1188 1189 // With invalid dimension value. 1190 proto.add_dim()->set_size(-2); 1191 EXPECT_TRUE( 1192 StringPiece(c.MakeShapeFromShapeProto(proto, &out).error_message()) 1193 .contains("Shape [0,?,1000,-2] has dimensions with values below -1 " 1194 "(where -1 means unknown)")); 1195 1196 EXPECT_FALSE(IsSet(out)); 1197 } 1198 1199 TEST_F(ShapeInferenceTest, MakeDim) { 1200 NodeDef def; 1201 std::vector<ShapeHandle> empty; 1202 InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}); 1203 1204 auto d0 = c.MakeDim(1); 1205 auto d1 = c.MakeDim(1); 1206 auto d2 = c.MakeDim(2); 1207 EXPECT_EQ("1", c.DebugString(d0)); 1208 EXPECT_EQ("1", c.DebugString(d1)); 1209 EXPECT_FALSE(SameHandle(d0, d1)); 1210 EXPECT_EQ("2", c.DebugString(d2)); 1211 } 1212 1213 TEST_F(ShapeInferenceTest, UnknownDim) { 1214 NodeDef def; 1215 std::vector<ShapeHandle> empty; 1216 InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}); 1217 1218 auto d0 = c.UnknownDim(); 1219 auto d1 = c.UnknownDim(); 1220 EXPECT_EQ("?", c.DebugString(d0)); 1221 EXPECT_EQ("?", c.DebugString(d1)); 1222 EXPECT_FALSE(SameHandle(d0, d1)); 1223 } 1224 1225 TEST_F(ShapeInferenceTest, UnknownShapeOfRank) { 1226 NodeDef def; 1227 std::vector<ShapeHandle> empty; 1228 InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}); 1229 1230 auto unknown_shape_of_rank_3 = c.UnknownShapeOfRank(3); 1231 EXPECT_EQ("[?,?,?]", c.DebugString(unknown_shape_of_rank_3)); 1232 1233 auto unknown_shape_of_rank_0 = c.UnknownShapeOfRank(0); 1234 EXPECT_EQ("[]", c.DebugString(unknown_shape_of_rank_0)); 1235 } 1236 1237 TEST_F(ShapeInferenceTest, InputTensors) { 1238 const Tensor t1 = tensorflow::test::AsTensor<float>({10}); 1239 const Tensor t2 = tensorflow::test::AsTensor<float>({20, 30}); 1240 NodeDef def; 1241 InferenceContext c(kVersion, &def, MakeOpDef(3, 2), {S({1}), S({2}), S({3})}, 1242 {&t1, &t2}, {}, {}); 1243 1244 EXPECT_TRUE(c.input_tensor(0) == &t1); 1245 EXPECT_TRUE(c.input_tensor(1) == &t2); 1246 EXPECT_TRUE(c.input_tensor(2) == nullptr); 1247 } 1248 1249 TEST_F(ShapeInferenceTest, MakeDimForScalarInput) { 1250 Tensor t1 = tensorflow::test::AsScalar<int32>(20); 1251 Tensor t2 = tensorflow::test::AsScalar<int32>(-1); 1252 NodeDef def; 1253 InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({}), S({})}, 1254 {&t1, &t2}, {}, {}); 1255 1256 DimensionHandle d; 1257 EXPECT_TRUE(c.MakeDimForScalarInput(0, &d).ok()); 1258 EXPECT_EQ("20", c.DebugString(d)); 1259 1260 EXPECT_TRUE(StringPiece(c.MakeDimForScalarInput(1, &d).error_message()) 1261 .contains("Dimension size, given by scalar input 1, must " 1262 "be non-negative but is -1")); 1263 1264 // Same tests, with int64 values. 1265 t1 = tensorflow::test::AsScalar<int64>(20); 1266 t2 = tensorflow::test::AsScalar<int64>(-1); 1267 EXPECT_TRUE(c.MakeDimForScalarInput(0, &d).ok()); 1268 EXPECT_EQ("20", c.DebugString(d)); 1269 1270 EXPECT_TRUE(StringPiece(c.MakeDimForScalarInput(1, &d).error_message()) 1271 .contains("Dimension size, given by scalar input 1, must " 1272 "be non-negative but is -1")); 1273 } 1274 1275 TEST_F(ShapeInferenceTest, GetAttr) { 1276 OpRegistrationData op_reg_data; 1277 op_reg_data.op_def = MakeOpDef(0, 2); 1278 NodeDef def; 1279 CHECK(NodeDefBuilder("dummy", &op_reg_data.op_def) 1280 .Attr("foo", "bar") 1281 .Finalize(&def) 1282 .ok()); 1283 1284 std::vector<ShapeHandle> empty; 1285 InferenceContext c(kVersion, &def, op_reg_data.op_def, empty, {}, {}, {}); 1286 string value; 1287 EXPECT_TRUE(c.GetAttr("foo", &value).ok()); 1288 EXPECT_EQ("bar", value); 1289 } 1290 1291 TEST_F(ShapeInferenceTest, Divide) { 1292 NodeDef def; 1293 InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 1, 2, 0})}, {}, 1294 {}, {}); 1295 1296 auto s = c.input(0); 1297 auto d_6 = c.Dim(s, 0); 1298 auto d_unknown = c.Dim(s, 1); 1299 auto d_1 = c.Dim(s, 2); 1300 auto d_2 = c.Dim(s, 3); 1301 auto d_0 = c.Dim(s, 4); 1302 bool evenly_divisible = true; 1303 1304 // Dividing unknown by non-1 gives new unknown. 1305 DimensionHandle out; 1306 EXPECT_TRUE(c.Divide(d_unknown, 2, evenly_divisible, &out).ok()); 1307 EXPECT_EQ("?", c.DebugString(out)); 1308 EXPECT_FALSE(SameHandle(out, d_unknown)); 1309 1310 // Dividing anything by 1 returns the input. 1311 EXPECT_TRUE(c.Divide(d_unknown, 1, evenly_divisible, &out).ok()); 1312 EXPECT_TRUE(SameHandle(out, d_unknown)); 1313 EXPECT_TRUE(c.Divide(d_6, 1, evenly_divisible, &out).ok()); 1314 EXPECT_TRUE(SameHandle(out, d_6)); 1315 EXPECT_TRUE(c.Divide(d_unknown, d_1, evenly_divisible, &out).ok()); 1316 EXPECT_TRUE(SameHandle(out, d_unknown)); 1317 EXPECT_TRUE(c.Divide(d_6, d_1, evenly_divisible, &out).ok()); 1318 EXPECT_TRUE(SameHandle(out, d_6)); 1319 1320 EXPECT_TRUE(c.Divide(d_6, 2, evenly_divisible, &out).ok()); 1321 EXPECT_EQ("3", c.DebugString(out)); 1322 EXPECT_TRUE(c.Divide(d_6, d_2, evenly_divisible, &out).ok()); 1323 EXPECT_EQ("3", c.DebugString(out)); 1324 1325 EXPECT_TRUE( 1326 StringPiece(c.Divide(d_6, 5, evenly_divisible, &out).error_message()) 1327 .contains("Dimension size must be evenly divisible by 5 but is 6")); 1328 1329 EXPECT_TRUE( 1330 StringPiece(c.Divide(d_6, 0, evenly_divisible, &out).error_message()) 1331 .contains("Divisor must be positive but is 0")); 1332 EXPECT_TRUE( 1333 StringPiece(c.Divide(d_6, d_0, evenly_divisible, &out).error_message()) 1334 .contains("Divisor must be positive but is 0")); 1335 1336 EXPECT_TRUE( 1337 StringPiece(c.Divide(d_6, -1, evenly_divisible, &out).error_message()) 1338 .contains("Divisor must be positive but is -1")); 1339 1340 // Repeat error cases above with evenly_divisible=false. 1341 evenly_divisible = false; 1342 EXPECT_TRUE(c.Divide(d_6, 5, evenly_divisible, &out).ok()); 1343 EXPECT_EQ("1", c.DebugString(out)); 1344 1345 EXPECT_TRUE( 1346 StringPiece(c.Divide(d_6, 0, evenly_divisible, &out).error_message()) 1347 .contains("Divisor must be positive but is 0")); 1348 1349 EXPECT_TRUE( 1350 StringPiece(c.Divide(d_6, -1, evenly_divisible, &out).error_message()) 1351 .contains("Divisor must be positive but is -1")); 1352 } 1353 1354 TEST_F(ShapeInferenceTest, Add) { 1355 NodeDef def; 1356 InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 0})}, {}, {}, 1357 {}); 1358 1359 auto s = c.input(0); 1360 auto d_6 = c.Dim(s, 0); 1361 auto d_unknown = c.Dim(s, 1); 1362 auto d_0 = c.Dim(s, 2); 1363 1364 // Adding non-zero to unknown gives new unknown. 1365 DimensionHandle out; 1366 EXPECT_TRUE(c.Add(d_unknown, 1, &out).ok()); 1367 EXPECT_EQ("?", c.DebugString(out)); 1368 EXPECT_FALSE(SameHandle(out, d_unknown)); 1369 1370 // Adding 0 to anything gives input. 1371 EXPECT_TRUE(c.Add(d_unknown, 0, &out).ok()); 1372 EXPECT_TRUE(SameHandle(out, d_unknown)); 1373 EXPECT_TRUE(c.Add(d_6, 0, &out).ok()); 1374 EXPECT_TRUE(SameHandle(out, d_6)); 1375 1376 // Adding dimension with value 0 to anything gives input. 1377 EXPECT_TRUE(c.Add(d_unknown, c.MakeDim(0ll), &out).ok()); 1378 EXPECT_TRUE(SameHandle(out, d_unknown)); 1379 EXPECT_TRUE(c.Add(d_6, c.MakeDim(0ll), &out).ok()); 1380 EXPECT_TRUE(SameHandle(out, d_6)); 1381 1382 // Test addition. 1383 EXPECT_TRUE(c.Add(d_6, 2, &out).ok()); 1384 EXPECT_EQ("8", c.DebugString(out)); 1385 EXPECT_TRUE(c.Add(d_6, std::numeric_limits<int64>::max() - 6, &out).ok()); 1386 EXPECT_EQ(std::numeric_limits<int64>::max(), c.Value(out)); 1387 1388 // Test addition using dimension as second value. 1389 EXPECT_TRUE(c.Add(d_6, c.MakeDim(2), &out).ok()); 1390 EXPECT_EQ("8", c.DebugString(out)); 1391 EXPECT_TRUE( 1392 c.Add(d_6, c.MakeDim(std::numeric_limits<int64>::max() - 6), &out).ok()); 1393 EXPECT_EQ(std::numeric_limits<int64>::max(), c.Value(out)); 1394 EXPECT_TRUE(c.Add(d_6, c.UnknownDim(), &out).ok()); 1395 EXPECT_EQ("?", c.DebugString(out)); 1396 EXPECT_TRUE(c.Add(d_0, d_6, &out).ok()); 1397 EXPECT_TRUE(SameHandle(out, d_6)); 1398 1399 EXPECT_TRUE( 1400 StringPiece(c.Add(d_6, std::numeric_limits<int64>::max() - 5, &out) 1401 .error_message()) 1402 .contains( 1403 "Dimension size overflow from adding 6 and 9223372036854775802")); 1404 } 1405 1406 TEST_F(ShapeInferenceTest, Subtract) { 1407 NodeDef def; 1408 InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 0, 5})}, {}, 1409 {}, {}); 1410 1411 auto s = c.input(0); 1412 auto d_6 = c.Dim(s, 0); 1413 auto d_unknown = c.Dim(s, 1); 1414 auto d_0 = c.Dim(s, 2); 1415 auto d_5 = c.Dim(s, 3); 1416 1417 // Subtracting non-zero from unknown gives new unknown. 1418 DimensionHandle out; 1419 EXPECT_TRUE(c.Subtract(d_unknown, 1, &out).ok()); 1420 EXPECT_EQ("?", c.DebugString(out)); 1421 EXPECT_FALSE(SameHandle(out, d_unknown)); 1422 1423 // Subtracting 0 from anything gives input. 1424 EXPECT_TRUE(c.Subtract(d_unknown, 0ll, &out).ok()); 1425 EXPECT_TRUE(SameHandle(out, d_unknown)); 1426 EXPECT_TRUE(c.Subtract(d_6, 0ll, &out).ok()); 1427 EXPECT_TRUE(SameHandle(out, d_6)); 1428 1429 // Subtracting dimension with value 0 from anything gives input. 1430 EXPECT_TRUE(c.Subtract(d_unknown, c.MakeDim(0ll), &out).ok()); 1431 EXPECT_TRUE(SameHandle(out, d_unknown)); 1432 EXPECT_TRUE(c.Subtract(d_6, c.MakeDim(0ll), &out).ok()); 1433 EXPECT_TRUE(SameHandle(out, d_6)); 1434 1435 // Test subtraction. 1436 EXPECT_TRUE(c.Subtract(d_6, 2, &out).ok()); 1437 EXPECT_EQ("4", c.DebugString(out)); 1438 EXPECT_TRUE(c.Subtract(d_6, 6, &out).ok()); 1439 EXPECT_EQ("0", c.DebugString(out)); 1440 1441 // Test subtraction using dimension as second value. 1442 EXPECT_TRUE(c.Subtract(d_6, c.MakeDim(2), &out).ok()); 1443 EXPECT_EQ("4", c.DebugString(out)); 1444 EXPECT_TRUE(c.Subtract(d_6, d_5, &out).ok()); 1445 EXPECT_EQ("1", c.DebugString(out)); 1446 EXPECT_TRUE(c.Subtract(d_6, c.UnknownDim(), &out).ok()); 1447 EXPECT_EQ("?", c.DebugString(out)); 1448 EXPECT_TRUE(c.Subtract(d_6, d_0, &out).ok()); 1449 EXPECT_TRUE(SameHandle(out, d_6)); 1450 1451 EXPECT_TRUE( 1452 StringPiece(c.Subtract(d_5, d_6, &out).error_message()) 1453 .contains("Negative dimension size caused by subtracting 6 from 5")); 1454 } 1455 1456 TEST_F(ShapeInferenceTest, Multiply) { 1457 NodeDef def; 1458 InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 0, 1})}, {}, 1459 {}, {}); 1460 1461 auto s = c.input(0); 1462 auto d_6 = c.Dim(s, 0); 1463 auto d_unknown = c.Dim(s, 1); 1464 auto d_0 = c.Dim(s, 2); 1465 auto d_1 = c.Dim(s, 3); 1466 1467 // Multiplying non-zero to unknown gives new unknown. 1468 DimensionHandle out; 1469 EXPECT_TRUE(c.Multiply(d_unknown, 2, &out).ok()); 1470 EXPECT_EQ("?", c.DebugString(out)); 1471 1472 // Multiplying 0 to anything gives 0. 1473 EXPECT_TRUE(c.Multiply(d_unknown, 0, &out).ok()); 1474 EXPECT_EQ("0", c.DebugString(out)); 1475 EXPECT_TRUE(c.Multiply(d_unknown, d_0, &out).ok()); 1476 EXPECT_EQ("0", c.DebugString(out)); 1477 EXPECT_TRUE(c.Multiply(d_0, d_unknown, &out).ok()); 1478 EXPECT_EQ("0", c.DebugString(out)); 1479 1480 // Multiplying 1 to anything gives the original. 1481 // (unknown -> unknown) 1482 EXPECT_TRUE(c.Multiply(d_unknown, 1, &out).ok()); 1483 EXPECT_TRUE(SameHandle(d_unknown, out)); 1484 EXPECT_TRUE(c.Multiply(d_unknown, d_1, &out).ok()); 1485 EXPECT_TRUE(SameHandle(d_unknown, out)); 1486 EXPECT_TRUE(c.Multiply(d_1, d_unknown, &out).ok()); 1487 EXPECT_TRUE(SameHandle(d_unknown, out)); 1488 // (known -> known) 1489 EXPECT_TRUE(c.Multiply(d_6, 1, &out).ok()); 1490 EXPECT_TRUE(SameHandle(d_6, out)); 1491 EXPECT_TRUE(c.Multiply(d_6, d_1, &out).ok()); 1492 EXPECT_TRUE(SameHandle(d_6, out)); 1493 EXPECT_TRUE(c.Multiply(d_1, d_6, &out).ok()); 1494 EXPECT_TRUE(SameHandle(d_6, out)); 1495 1496 // Test multiplication. 1497 EXPECT_TRUE(c.Multiply(d_6, 2, &out).ok()); 1498 EXPECT_EQ("12", c.DebugString(out)); 1499 EXPECT_TRUE(c.Multiply(d_6, 6, &out).ok()); 1500 EXPECT_EQ("36", c.DebugString(out)); 1501 1502 // Test multiplication using dimension as second value. 1503 EXPECT_TRUE(c.Multiply(d_6, c.MakeDim(2), &out).ok()); 1504 EXPECT_EQ("12", c.DebugString(out)); 1505 EXPECT_TRUE(c.Multiply(d_6, c.UnknownDim(), &out).ok()); 1506 EXPECT_EQ("?", c.DebugString(out)); 1507 } 1508 1509 TEST_F(ShapeInferenceTest, FullyDefined) { 1510 NodeDef def; 1511 std::vector<ShapeHandle> empty; 1512 InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}); 1513 1514 // No rank or missing dimension information should return false. 1515 EXPECT_FALSE(c.FullyDefined(c.UnknownShape())); 1516 EXPECT_FALSE(c.FullyDefined(c.Matrix(c.MakeDim(1), c.UnknownDim()))); 1517 1518 // Return true if all information exists. 1519 EXPECT_TRUE(c.FullyDefined(c.Matrix(c.MakeDim(1), c.MakeDim(2)))); 1520 EXPECT_TRUE(c.FullyDefined(c.Scalar())); 1521 } 1522 1523 TEST_F(ShapeInferenceTest, Min) { 1524 NodeDef def; 1525 InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, -1, 0})}, {}, 1526 {}, {}); 1527 1528 auto s = c.input(0); 1529 auto d_1 = c.Dim(s, 0); 1530 auto d_2 = c.Dim(s, 1); 1531 auto d_unknown = c.Dim(s, 2); 1532 auto d_0 = c.Dim(s, 3); 1533 1534 // Minimum involving zero and unknown returns zero. 1535 DimensionHandle out; 1536 EXPECT_TRUE(c.Min(d_0, d_unknown, &out).ok()); 1537 EXPECT_TRUE(SameHandle(d_0, out)); 1538 EXPECT_TRUE(c.Min(d_unknown, d_0, &out).ok()); 1539 EXPECT_TRUE(SameHandle(d_0, out)); 1540 EXPECT_TRUE(c.Min(c.MakeDim(0ll), d_unknown, &out).ok()); 1541 EXPECT_EQ("0", c.DebugString(out)); 1542 EXPECT_TRUE(c.Min(d_unknown, 0ll, &out).ok()); 1543 EXPECT_EQ("0", c.DebugString(out)); 1544 1545 // Minimum involving unknowns and non-zeros gives new unknown. 1546 EXPECT_TRUE(c.Min(d_unknown, d_unknown, &out).ok()); 1547 EXPECT_EQ("?", c.DebugString(out)); 1548 EXPECT_TRUE(c.Min(d_unknown, 1, &out).ok()); 1549 EXPECT_EQ("?", c.DebugString(out)); 1550 EXPECT_TRUE(c.Min(d_1, d_unknown, &out).ok()); 1551 EXPECT_EQ("?", c.DebugString(out)); 1552 1553 // Minimum with constant second arg. 1554 EXPECT_TRUE(c.Min(d_1, 1, &out).ok()); 1555 EXPECT_TRUE(SameHandle(d_1, out)); 1556 EXPECT_TRUE(c.Min(d_1, 3, &out).ok()); 1557 EXPECT_TRUE(SameHandle(d_1, out)); 1558 EXPECT_TRUE(c.Min(d_2, 1, &out).ok()); 1559 EXPECT_EQ("1", c.DebugString(out)); 1560 1561 // Minimum with two dimensions. 1562 EXPECT_TRUE(c.Min(d_1, d_1, &out).ok()); 1563 EXPECT_TRUE(SameHandle(d_1, out)); 1564 EXPECT_TRUE(c.Min(d_1, d_2, &out).ok()); 1565 EXPECT_TRUE(SameHandle(d_1, out)); 1566 EXPECT_TRUE(c.Min(d_2, d_1, &out).ok()); 1567 EXPECT_TRUE(SameHandle(d_1, out)); 1568 EXPECT_TRUE(c.Min(d_2, d_2, &out).ok()); 1569 EXPECT_TRUE(SameHandle(d_2, out)); 1570 } 1571 1572 TEST_F(ShapeInferenceTest, Max) { 1573 NodeDef def; 1574 InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, -1})}, {}, {}, 1575 {}); 1576 1577 auto s = c.input(0); 1578 auto d_1 = c.Dim(s, 0); 1579 auto d_2 = c.Dim(s, 1); 1580 auto d_unknown = c.Dim(s, 2); 1581 1582 // Maximum involving unknowns gives new unknown. 1583 DimensionHandle out; 1584 EXPECT_TRUE(c.Max(d_unknown, d_unknown, &out).ok()); 1585 EXPECT_EQ("?", c.DebugString(out)); 1586 EXPECT_TRUE(c.Max(d_unknown, 1, &out).ok()); 1587 EXPECT_EQ("?", c.DebugString(out)); 1588 EXPECT_TRUE(c.Max(d_1, d_unknown, &out).ok()); 1589 EXPECT_EQ("?", c.DebugString(out)); 1590 1591 // Maximum with constant second arg. 1592 EXPECT_TRUE(c.Max(d_1, 1, &out).ok()); 1593 EXPECT_TRUE(SameHandle(d_1, out)); 1594 EXPECT_TRUE(c.Max(d_2, 1, &out).ok()); 1595 EXPECT_TRUE(SameHandle(d_2, out)); 1596 EXPECT_TRUE(c.Max(d_2, 3, &out).ok()); 1597 EXPECT_EQ("3", c.DebugString(out)); 1598 1599 // Maximum with two dimensions. 1600 EXPECT_TRUE(c.Max(d_1, d_1, &out).ok()); 1601 EXPECT_TRUE(SameHandle(d_1, out)); 1602 EXPECT_TRUE(c.Max(d_1, d_2, &out).ok()); 1603 EXPECT_TRUE(SameHandle(d_2, out)); 1604 EXPECT_TRUE(c.Max(d_2, d_1, &out).ok()); 1605 EXPECT_TRUE(SameHandle(d_2, out)); 1606 EXPECT_TRUE(c.Max(d_2, d_2, &out).ok()); 1607 EXPECT_TRUE(SameHandle(d_2, out)); 1608 } 1609 1610 void ShapeInferenceTest::TestMergeHandles(bool input_not_output) { 1611 NodeDef def; 1612 InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({}), S({})}, {}, {}, 1613 {}); 1614 auto make_shape = [&c](std::initializer_list<int64> dim_sizes) { 1615 ShapeHandle s; 1616 TF_CHECK_OK(c.MakeShapeFromPartialTensorShape(S(dim_sizes), &s)); 1617 return s; 1618 }; 1619 auto get_shapes_and_types_from_context = [&](int idx) { 1620 if (input_not_output) { 1621 return c.input_handle_shapes_and_types(idx); 1622 } else { 1623 return c.output_handle_shapes_and_types(idx); 1624 } 1625 }; 1626 auto merge_shapes_and_types_to_context = 1627 [&](int idx, const std::vector<ShapeAndType>& shapes_and_types) { 1628 if (input_not_output) { 1629 return c.MergeInputHandleShapesAndTypes(idx, shapes_and_types); 1630 } else { 1631 return c.MergeOutputHandleShapesAndTypes(idx, shapes_and_types); 1632 } 1633 }; 1634 1635 EXPECT_TRUE(get_shapes_and_types_from_context(0) == nullptr); 1636 EXPECT_TRUE(get_shapes_and_types_from_context(1) == nullptr); 1637 1638 // First merge will take the input completely. 1639 std::vector<ShapeAndType> t{{make_shape({1, 2, 3}), DT_FLOAT}, 1640 {c.UnknownShape(), DT_INVALID}, 1641 {make_shape({4, 3, 2, 1}), DT_INT32}}; 1642 ASSERT_TRUE(merge_shapes_and_types_to_context(0, t)); 1643 ASSERT_TRUE(get_shapes_and_types_from_context(0) != nullptr); 1644 std::vector<ShapeAndType> v = *get_shapes_and_types_from_context(0); 1645 ASSERT_EQ(3, v.size()); 1646 for (int i = 0; i < v.size(); ++i) { 1647 EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i; 1648 EXPECT_EQ(t[i].dtype, v[i].dtype); 1649 } 1650 1651 // Merge that fails because wrong number of values passed. 1652 // Fails, and no changes made. 1653 ASSERT_FALSE(merge_shapes_and_types_to_context( 1654 0, std::vector<ShapeAndType>{{make_shape({1, 2, 3}), DT_FLOAT}})); 1655 v = *get_shapes_and_types_from_context(0); 1656 ASSERT_EQ(3, v.size()); 1657 for (int i = 0; i < v.size(); ++i) { 1658 EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i; 1659 EXPECT_EQ(t[i].dtype, v[i].dtype); 1660 } 1661 1662 // Only difference is in a mismatched shape. That is ignored, 1663 // and there are no other changes, so nothing is done. 1664 // 1665 // TODO(cwhipkey): in mismatch cases, change Merge*HandleShapesAndTypes to 1666 // return an error (separate error from 'refined' output)? 1667 auto t2 = t; 1668 t2[2].shape = make_shape({4, 3, 4, 1}); 1669 ASSERT_FALSE(merge_shapes_and_types_to_context(0, t2)); 1670 v = *get_shapes_and_types_from_context(0); 1671 ASSERT_EQ(3, v.size()); 1672 for (int i = 0; i < v.size(); ++i) { 1673 EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i; 1674 EXPECT_EQ(t[i].dtype, v[i].dtype); 1675 } 1676 1677 // Only difference is in a mismatched dtype, but that cannot be 1678 // updated unless original dtype is DT_INVALID. 1679 t2 = t; 1680 t2[2].dtype = DT_FLOAT; 1681 ASSERT_FALSE(merge_shapes_and_types_to_context(0, t2)); 1682 v = *get_shapes_and_types_from_context(0); 1683 ASSERT_EQ(3, v.size()); 1684 for (int i = 0; i < v.size(); ++i) { 1685 EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i; 1686 EXPECT_EQ(t[i].dtype, v[i].dtype); 1687 } 1688 1689 // Difference is mergeable (new shape). 1690 t[1].shape = make_shape({1, 10}); 1691 ASSERT_TRUE(merge_shapes_and_types_to_context(0, t)); 1692 v = *get_shapes_and_types_from_context(0); 1693 ASSERT_EQ(3, v.size()); 1694 for (int i = 0; i < v.size(); ++i) { 1695 EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i; 1696 EXPECT_EQ(t[i].dtype, v[i].dtype); 1697 } 1698 1699 // Difference is mergeable (new type). 1700 t[1].dtype = DT_DOUBLE; 1701 ASSERT_TRUE(merge_shapes_and_types_to_context(0, t)); 1702 v = *get_shapes_and_types_from_context(0); 1703 ASSERT_EQ(3, v.size()); 1704 for (int i = 0; i < v.size(); ++i) { 1705 EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i; 1706 EXPECT_EQ(t[i].dtype, v[i].dtype); 1707 } 1708 1709 // No difference. 1710 ASSERT_FALSE(merge_shapes_and_types_to_context(0, t)); 1711 } 1712 1713 TEST_F(ShapeInferenceTest, MergeInputHandleShapesAndTypes) { 1714 TestMergeHandles(true /* input_not_output */); 1715 } 1716 1717 TEST_F(ShapeInferenceTest, MergeOutputHandleShapesAndTypes) { 1718 TestMergeHandles(false /* input_not_output */); 1719 } 1720 1721 void ShapeInferenceTest::TestRelaxHandles(bool input_not_output) { 1722 NodeDef def; 1723 InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({}), S({})}, {}, {}, 1724 {}); 1725 auto make_shape = [&c](std::initializer_list<int64> dim_sizes) { 1726 ShapeHandle s; 1727 TF_CHECK_OK(c.MakeShapeFromPartialTensorShape(S(dim_sizes), &s)); 1728 return s; 1729 }; 1730 auto get_shapes_and_types_from_context = [&](int idx) { 1731 if (input_not_output) { 1732 return c.input_handle_shapes_and_types(idx); 1733 } else { 1734 return c.output_handle_shapes_and_types(idx); 1735 } 1736 }; 1737 auto relax_shapes_and_types_to_context = 1738 [&](int idx, const std::vector<ShapeAndType>& shapes_and_types) { 1739 if (input_not_output) { 1740 return c.RelaxInputHandleShapesAndMergeTypes(idx, shapes_and_types); 1741 } else { 1742 return c.RelaxOutputHandleShapesAndMergeTypes(idx, shapes_and_types); 1743 } 1744 }; 1745 1746 EXPECT_TRUE(get_shapes_and_types_from_context(0) == nullptr); 1747 EXPECT_TRUE(get_shapes_and_types_from_context(1) == nullptr); 1748 1749 // First relax will take the input completely. 1750 std::vector<ShapeAndType> t{{make_shape({1, 2, 3}), DT_FLOAT}, 1751 {c.UnknownShape(), DT_INVALID}, 1752 {make_shape({4, 3, 2, 1}), DT_INT32}}; 1753 ASSERT_TRUE(relax_shapes_and_types_to_context(0, t)); 1754 ASSERT_TRUE(get_shapes_and_types_from_context(0) != nullptr); 1755 std::vector<ShapeAndType> v = *get_shapes_and_types_from_context(0); 1756 ASSERT_EQ(3, v.size()); 1757 for (int i = 0; i < v.size(); ++i) { 1758 EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i; 1759 EXPECT_EQ(t[i].dtype, v[i].dtype); 1760 } 1761 1762 // Relax that fails because wrong number of values passed. 1763 // Fails, and no changes made. 1764 ASSERT_FALSE(relax_shapes_and_types_to_context( 1765 0, std::vector<ShapeAndType>{{make_shape({1, 2, 3}), DT_FLOAT}})); 1766 v = *get_shapes_and_types_from_context(0); 1767 ASSERT_EQ(3, v.size()); 1768 for (int i = 0; i < v.size(); ++i) { 1769 EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i; 1770 EXPECT_EQ(t[i].dtype, v[i].dtype); 1771 } 1772 1773 // Only difference is in a mismatched shape. This should replace 1774 // the mismatched dimension with an UnknownDim. 1775 auto t2 = t; 1776 t2[2].shape = make_shape({4, 3, 4, 1}); 1777 ASSERT_TRUE(relax_shapes_and_types_to_context(0, t2)); 1778 v = *get_shapes_and_types_from_context(0); 1779 EXPECT_EQ("[4,3,?,1]", c.DebugString(v[2].shape)); 1780 for (int i = 0; i < v.size(); ++i) { 1781 EXPECT_EQ(t[i].dtype, v[i].dtype); 1782 } 1783 1784 // Only difference is in a mismatched dtype, but that cannot be 1785 // updated unless original dtype is DT_INVALID. 1786 t2 = t; 1787 t2[2].dtype = DT_FLOAT; 1788 ASSERT_FALSE(relax_shapes_and_types_to_context(0, t2)); 1789 v = *get_shapes_and_types_from_context(0); 1790 ASSERT_EQ(3, v.size()); 1791 for (int i = 0; i < v.size(); ++i) { 1792 EXPECT_EQ(t[i].dtype, v[i].dtype); 1793 } 1794 1795 // Difference is a new shape, which will result in a new UnknownShape. 1796 t[1].shape = make_shape({1, 10}); 1797 ASSERT_TRUE(relax_shapes_and_types_to_context(0, t)); 1798 v = *get_shapes_and_types_from_context(0); 1799 ASSERT_EQ(3, v.size()); 1800 EXPECT_FALSE(SameHandle(t[1].shape, v[1].shape)); 1801 EXPECT_EQ("?", c.DebugString(v[1].shape)); 1802 for (int i = 0; i < v.size(); ++i) { 1803 EXPECT_EQ(t[i].dtype, v[i].dtype); 1804 } 1805 1806 // Difference is relaxable (new type). 1807 t[1].dtype = DT_DOUBLE; 1808 ASSERT_TRUE(relax_shapes_and_types_to_context(0, t)); 1809 v = *get_shapes_and_types_from_context(0); 1810 EXPECT_EQ(t[1].dtype, v[1].dtype); 1811 } 1812 1813 TEST_F(ShapeInferenceTest, RelaxInputHandleShapesAndTypes) { 1814 TestRelaxHandles(true /* input_not_output */); 1815 } 1816 1817 TEST_F(ShapeInferenceTest, RelaxOutputHandleShapesAndTypes) { 1818 TestRelaxHandles(false /* input_not_output */); 1819 } 1820 1821 } // namespace shape_inference 1822 } // namespace tensorflow 1823