1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include "tensorflow/core/framework/shape_inference.h" 16 17 #include "tensorflow/core/framework/fake_input.h" 18 #include "tensorflow/core/framework/node_def_builder.h" 19 #include "tensorflow/core/framework/op_def_builder.h" 20 #include "tensorflow/core/framework/tensor_shape.pb.h" 21 #include "tensorflow/core/framework/tensor_testutil.h" 22 #include "tensorflow/core/framework/types.pb.h" 23 #include "tensorflow/core/lib/core/status_test_util.h" 24 #include "tensorflow/core/lib/strings/str_util.h" 25 #include "tensorflow/core/lib/strings/strcat.h" 26 #include "tensorflow/core/platform/test.h" 27 28 namespace tensorflow { 29 namespace shape_inference { 30 namespace { 31 32 OpDef MakeOpDefWithLists() { 33 OpRegistrationData op_reg_data; 34 OpDefBuilder b("dummy"); 35 b.Input(strings::StrCat("input: N * float")); 36 b.Output(strings::StrCat("output: N * float")); 37 CHECK(b.Attr("N:int >= 1").Finalize(&op_reg_data).ok()); 38 return op_reg_data.op_def; 39 } 40 41 PartialTensorShape S(std::initializer_list<int64> dims) { 42 return PartialTensorShape(dims); 43 } 44 45 PartialTensorShape Unknown() { return PartialTensorShape(); } 46 47 } // namespace 48 49 class ShapeInferenceTest : public ::testing::Test { 50 protected: 51 // These give access to private functions of DimensionHandle and ShapeHandle. 52 bool SameHandle(DimensionHandle a, DimensionHandle b) { 53 return a.SameHandle(b); 54 } 55 bool SameHandle(ShapeHandle a, ShapeHandle b) { return a.SameHandle(b); } 56 bool IsSet(DimensionHandle d) { return d.IsSet(); } 57 bool IsSet(ShapeHandle s) { return s.IsSet(); } 58 void Relax(InferenceContext* c, DimensionHandle d0, DimensionHandle d1, 59 DimensionHandle* out) { 60 c->Relax(d0, d1, out); 61 } 62 void Relax(InferenceContext* c, ShapeHandle s0, ShapeHandle s1, 63 ShapeHandle* out) { 64 c->Relax(s0, s1, out); 65 } 66 void TestMergeHandles(bool input_not_output); 67 void TestRelaxHandles(bool input_not_output); 68 69 static const int kVersion = 0; // used for graph-def version. 70 }; 71 72 TEST_F(ShapeInferenceTest, InputOutputByName) { 73 // Setup test to contain an input tensor list of size 3. 74 OpDef op_def = MakeOpDefWithLists(); 75 NodeDef def; 76 auto s = NodeDefBuilder("dummy", &op_def) 77 .Attr("N", 3) 78 .Input(FakeInput(DT_FLOAT)) 79 .Finalize(&def); 80 InferenceContext c(kVersion, &def, op_def, {S({1, 5}), S({2, 5}), S({1, 3})}, 81 {}, {}, {}); 82 83 EXPECT_EQ("5", c.DebugString(c.NumElements(c.input(0)))); 84 EXPECT_EQ("10", c.DebugString(c.NumElements(c.input(1)))); 85 EXPECT_EQ("3", c.DebugString(c.NumElements(c.input(2)))); 86 // Test getters. 87 std::vector<ShapeHandle> shapes; 88 EXPECT_FALSE(c.input("nonexistent", &shapes).ok()); 89 TF_EXPECT_OK(c.input("input", &shapes)); 90 EXPECT_EQ("[1,5]", c.DebugString(shapes[0])); 91 EXPECT_EQ("[2,5]", c.DebugString(shapes[1])); 92 EXPECT_EQ("[1,3]", c.DebugString(shapes[2])); 93 94 // Test setters. 95 EXPECT_FALSE(c.set_output("nonexistent", shapes).ok()); 96 TF_EXPECT_OK(c.set_output("output", shapes)); 97 EXPECT_EQ("5", c.DebugString(c.NumElements(c.output(0)))); 98 EXPECT_EQ("10", c.DebugString(c.NumElements(c.output(1)))); 99 EXPECT_EQ("3", c.DebugString(c.NumElements(c.output(2)))); 100 } 101 102 static OpDef MakeOpDef(int num_inputs, int num_outputs) { 103 OpRegistrationData op_reg_data; 104 OpDefBuilder b("dummy"); 105 for (int i = 0; i < num_inputs; ++i) { 106 b.Input(strings::StrCat("i", i, ": float")); 107 } 108 for (int i = 0; i < num_outputs; ++i) { 109 b.Output(strings::StrCat("o", i, ": float")); 110 } 111 CHECK(b.Attr("foo:string").Finalize(&op_reg_data).ok()); 112 return op_reg_data.op_def; 113 } 114 115 TEST_F(ShapeInferenceTest, DimensionOrConstant) { 116 NodeDef def; 117 InferenceContext c(kVersion, &def, MakeOpDef(1, 1), {Unknown()}, {}, {}, {}); 118 EXPECT_EQ(InferenceContext::kUnknownDim, 119 c.Value(InferenceContext::kUnknownDim)); 120 EXPECT_EQ(1, c.Value(1)); 121 122 #ifndef NDEBUG 123 // Only run death test if DCHECKS are enabled. 124 EXPECT_DEATH(c.Value(-7), "Dimension must be non\\-negative or equal to"); 125 #endif 126 } 127 128 TEST_F(ShapeInferenceTest, Run) { 129 NodeDef def; 130 def.set_name("foo"); 131 def.set_op("foo_op"); 132 InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1})}, {}, {}, {}); 133 TF_ASSERT_OK(c.construction_status()); 134 135 { 136 auto fn = [](InferenceContext* c) { 137 ShapeHandle h; 138 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 6, &h)); 139 c->set_output(0, c->input(0)); 140 c->set_output(1, c->input(0)); 141 return Status::OK(); 142 }; 143 TF_ASSERT_OK(c.Run(fn)); 144 } 145 146 { 147 auto fn = [](InferenceContext* c) { 148 ShapeHandle h; 149 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &h)); 150 c->set_output(0, c->input(0)); 151 c->set_output(1, c->input(0)); 152 return Status::OK(); 153 }; 154 Status s = c.Run(fn); 155 // Extra error message is attached when Run fails. 156 EXPECT_TRUE(str_util::StrContains( 157 s.ToString(), 158 "Shape must be at most rank 0 but is rank 1 for 'foo' (op: 'foo_op')")) 159 << s; 160 } 161 } 162 163 // Tests different context data added when Run returns error. 164 TEST_F(ShapeInferenceTest, AttachContext) { 165 NodeDef def; 166 def.set_name("foo"); 167 def.set_op("foo_op"); 168 // Error when no constant tensors were requested. 169 { 170 InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, 3})}, {}, {}, 171 {}); 172 TF_ASSERT_OK(c.construction_status()); 173 auto fn = [](InferenceContext* c) { 174 ShapeHandle h; 175 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &h)); 176 c->set_output(0, c->input(0)); 177 return Status::OK(); 178 }; 179 EXPECT_EQ( 180 "Invalid argument: Shape must be at most rank 0 but is rank 3 for " 181 "'foo' (op: 'foo_op') with input shapes: [1,2,3].", 182 c.Run(fn).ToString()); 183 } 184 185 // Error when a constant tensor value was requested. 186 { 187 Tensor input_t = 188 ::tensorflow::test::AsTensor<float>({1.1, 2.2, 3.3, 4.4, 5.5}); 189 InferenceContext c(kVersion, &def, MakeOpDef(2, 2), 190 {S({1, 2, 3}), S({4, 5})}, {nullptr, &input_t}, {}, {}); 191 TF_ASSERT_OK(c.construction_status()); 192 auto fn = [](InferenceContext* c) { 193 c->input_tensor(0); // get this one, but it's null - won't be in error. 194 c->input_tensor(1); // get this one, will now be in error. 195 ShapeHandle h; 196 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &h)); 197 c->set_output(0, c->input(0)); 198 return Status::OK(); 199 }; 200 EXPECT_EQ( 201 "Invalid argument: Shape must be at most rank 0 but is rank 3 for " 202 "'foo' (op: 'foo_op') with input shapes: [1,2,3], [4,5] and with " 203 "computed input tensors: input[1] = <1.1 2.2 3.3 4.4 5.5>.", 204 c.Run(fn).ToString()); 205 } 206 207 // Error when a constant tensor value as shape was requested, but no partial 208 // shapes provided. 209 { 210 Tensor input_t = ::tensorflow::test::AsTensor<int32>({1, 2, 3, 4, 5}); 211 InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({3}), S({4})}, 212 {nullptr, &input_t}, {}, {}); 213 TF_ASSERT_OK(c.construction_status()); 214 auto fn = [](InferenceContext* c) { 215 ShapeHandle s; 216 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s)); 217 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s)); 218 ShapeHandle h; 219 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &h)); 220 c->set_output(0, c->input(0)); 221 return Status::OK(); 222 }; 223 EXPECT_EQ( 224 "Invalid argument: Shape must be at most rank 0 but is rank 1 for " 225 "'foo' (op: 'foo_op') with input shapes: [3], [4] and with computed " 226 "input tensors: input[1] = <1 2 3 4 5>.", 227 c.Run(fn).ToString()); 228 } 229 230 // Error when a constant tensor value as shape was requested, and a partial 231 // shape was provided. 232 { 233 Tensor input_t = ::tensorflow::test::AsTensor<int32>({1, 2, 3, 4, 5}); 234 InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({3}), S({4})}, 235 {nullptr, &input_t}, {S({10, -1, 5}), Unknown()}, {}); 236 TF_ASSERT_OK(c.construction_status()); 237 auto fn = [](InferenceContext* c) { 238 ShapeHandle s; 239 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s)); 240 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s)); 241 ShapeHandle h; 242 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &h)); 243 c->set_output(0, c->input(0)); 244 return Status::OK(); 245 }; 246 EXPECT_EQ( 247 "Invalid argument: Shape must be at most rank 0 but is rank 1 for " 248 "'foo' (op: 'foo_op') with input shapes: [3], [4] and with computed " 249 "input tensors: input[1] = <1 2 3 4 5> and with input tensors computed " 250 "as partial shapes: input[0] = [10,?,5].", 251 c.Run(fn).ToString()); 252 } 253 } 254 255 TEST_F(ShapeInferenceTest, RankAndDimInspection) { 256 NodeDef def; 257 InferenceContext c(kVersion, &def, MakeOpDef(3, 2), 258 {Unknown(), S({1, -1, 3}), S({})}, {}, {}, {}); 259 EXPECT_EQ(3, c.num_inputs()); 260 EXPECT_EQ(2, c.num_outputs()); 261 262 auto in0 = c.input(0); 263 EXPECT_EQ("?", c.DebugString(in0)); 264 EXPECT_FALSE(c.RankKnown(in0)); 265 EXPECT_EQ(InferenceContext::kUnknownRank, c.Rank(in0)); 266 EXPECT_EQ("?", c.DebugString(c.Dim(in0, 0))); 267 EXPECT_EQ("?", c.DebugString(c.Dim(in0, -1))); 268 EXPECT_EQ("?", c.DebugString(c.Dim(in0, 1000))); 269 270 auto in1 = c.input(1); 271 EXPECT_EQ("[1,?,3]", c.DebugString(in1)); 272 EXPECT_TRUE(c.RankKnown(in1)); 273 EXPECT_EQ(3, c.Rank(in1)); 274 auto d = c.Dim(in1, 0); 275 EXPECT_EQ(1, c.Value(d)); 276 EXPECT_TRUE(SameHandle(d, c.Dim(in1, -3))); 277 EXPECT_TRUE(c.ValueKnown(d)); 278 EXPECT_EQ("1", c.DebugString(d)); 279 d = c.Dim(in1, 1); 280 EXPECT_EQ(InferenceContext::kUnknownDim, c.Value(d)); 281 EXPECT_FALSE(c.ValueKnown(d)); 282 EXPECT_TRUE(SameHandle(d, c.Dim(in1, -2))); 283 EXPECT_EQ("?", c.DebugString(d)); 284 d = c.Dim(in1, 2); 285 EXPECT_EQ(3, c.Value(d)); 286 EXPECT_TRUE(SameHandle(d, c.Dim(in1, -1))); 287 EXPECT_TRUE(c.ValueKnown(d)); 288 EXPECT_EQ("3", c.DebugString(d)); 289 290 auto in2 = c.input(2); 291 EXPECT_EQ("[]", c.DebugString(in2)); 292 EXPECT_TRUE(c.RankKnown(in2)); 293 EXPECT_EQ(0, c.Rank(in2)); 294 } 295 296 TEST_F(ShapeInferenceTest, NumElements) { 297 NodeDef def; 298 InferenceContext c(kVersion, &def, MakeOpDef(3, 2), 299 {Unknown(), S({1, -1, 3}), S({5, 4, 3, 2})}, {}, {}, {}); 300 301 EXPECT_EQ("?", c.DebugString(c.NumElements(c.input(0)))); 302 EXPECT_EQ("?", c.DebugString(c.NumElements(c.input(1)))); 303 304 // Different handles (not the same unknown value). 305 EXPECT_FALSE(SameHandle(c.Dim(c.input(1), 1), c.NumElements(c.input(1)))); 306 307 EXPECT_EQ("120", c.DebugString(c.NumElements(c.input(2)))); 308 } 309 310 TEST_F(ShapeInferenceTest, WithRank) { 311 NodeDef def; 312 InferenceContext c(kVersion, &def, MakeOpDef(2, 2), 313 {Unknown(), S({1, -1, 3})}, {}, {}, {}); 314 315 auto in0 = c.input(0); 316 auto in1 = c.input(1); 317 ShapeHandle s1; 318 ShapeHandle s2; 319 320 // WithRank on a shape with unknown dimensionality always succeeds. 321 EXPECT_TRUE(c.WithRank(in0, 1, &s1).ok()); 322 EXPECT_EQ("[?]", c.DebugString(s1)); 323 324 EXPECT_TRUE(c.WithRank(in0, 2, &s2).ok()); 325 EXPECT_EQ("[?,?]", c.DebugString(s2)); 326 EXPECT_FALSE(SameHandle(s1, s2)); 327 EXPECT_FALSE(SameHandle(c.Dim(s2, 0), c.Dim(s2, 1))); 328 329 EXPECT_TRUE(c.WithRank(in0, 1, &s2).ok()); 330 EXPECT_EQ("[?]", c.DebugString(s2)); 331 EXPECT_FALSE(SameHandle(s1, s2)); 332 333 EXPECT_TRUE(c.WithRank(in0, 0, &s1).ok()); 334 EXPECT_EQ("[]", c.DebugString(s1)); 335 336 // WithRank on shape with known dimensionality. 337 s1 = in1; 338 EXPECT_EQ("Invalid argument: Shape must be rank 2 but is rank 3", 339 c.WithRank(in1, 2, &s1).ToString()); 340 EXPECT_FALSE(IsSet(s1)); 341 EXPECT_TRUE(c.WithRank(in1, 3, &s1).ok()); 342 EXPECT_TRUE(SameHandle(s1, in1)); 343 344 // Inputs are unchanged. 345 EXPECT_EQ("?", c.DebugString(in0)); 346 EXPECT_EQ("[1,?,3]", c.DebugString(in1)); 347 } 348 349 TEST_F(ShapeInferenceTest, WithRankAtMost) { 350 NodeDef def; 351 InferenceContext c(kVersion, &def, MakeOpDef(2, 2), 352 {Unknown(), S({1, -1, 3})}, {}, {}, {}); 353 354 auto in0 = c.input(0); 355 auto in1 = c.input(1); 356 ShapeHandle s1; 357 ShapeHandle s2; 358 359 // WithRankAtMost on a shape with unknown dimensionality always succeeds. 360 EXPECT_TRUE(c.WithRankAtMost(in0, 1, &s1).ok()); 361 EXPECT_EQ("?", c.DebugString(s1)); 362 EXPECT_TRUE(SameHandle(in0, s1)); 363 364 EXPECT_TRUE(c.WithRankAtMost(in0, 2, &s2).ok()); 365 EXPECT_EQ("?", c.DebugString(s2)); 366 EXPECT_TRUE(SameHandle(s1, s2)); 367 368 // WithRankAtMost on shape with known dimensionality. 369 s1 = in1; 370 EXPECT_TRUE(str_util::StrContains( 371 c.WithRankAtMost(in1, 2, &s1).ToString(), 372 "Invalid argument: Shape must be at most rank 2 but is rank 3")); 373 374 EXPECT_FALSE(IsSet(s1)); 375 EXPECT_TRUE(c.WithRankAtMost(in1, 3, &s1).ok()); 376 EXPECT_TRUE(SameHandle(s1, in1)); 377 EXPECT_TRUE(c.WithRankAtMost(in1, 4, &s1).ok()); 378 EXPECT_TRUE(SameHandle(s1, in1)); 379 EXPECT_TRUE(c.WithRankAtMost(in1, 5, &s1).ok()); 380 EXPECT_TRUE(SameHandle(s1, in1)); 381 382 // Inputs are unchanged. 383 EXPECT_EQ("?", c.DebugString(in0)); 384 EXPECT_EQ("[1,?,3]", c.DebugString(in1)); 385 } 386 387 TEST_F(ShapeInferenceTest, WithRankAtLeast) { 388 NodeDef def; 389 InferenceContext c(kVersion, &def, MakeOpDef(2, 2), 390 {Unknown(), S({1, -1, 3})}, {}, {}, {}); 391 392 auto in0 = c.input(0); 393 auto in1 = c.input(1); 394 ShapeHandle s1; 395 ShapeHandle s2; 396 397 // WithRankAtLeast on a shape with unknown dimensionality always succeeds. 398 EXPECT_TRUE(c.WithRankAtLeast(in0, 1, &s1).ok()); 399 EXPECT_EQ("?", c.DebugString(s1)); 400 EXPECT_TRUE(SameHandle(in0, s1)); 401 402 EXPECT_TRUE(c.WithRankAtLeast(in0, 2, &s2).ok()); 403 EXPECT_EQ("?", c.DebugString(s2)); 404 EXPECT_TRUE(SameHandle(s1, s2)); 405 406 // WithRankAtLeast on shape with known dimensionality. 407 s1 = in1; 408 EXPECT_TRUE(str_util::StrContains( 409 c.WithRankAtLeast(in1, 4, &s1).ToString(), 410 "Invalid argument: Shape must be at least rank 4 but is rank 3")); 411 412 EXPECT_FALSE(IsSet(s1)); 413 EXPECT_TRUE(c.WithRankAtLeast(in1, 3, &s1).ok()); 414 EXPECT_TRUE(SameHandle(s1, in1)); 415 EXPECT_TRUE(c.WithRankAtLeast(in1, 2, &s1).ok()); 416 EXPECT_TRUE(SameHandle(s1, in1)); 417 EXPECT_TRUE(c.WithRankAtLeast(in1, 0, &s1).ok()); 418 EXPECT_TRUE(SameHandle(s1, in1)); 419 420 // Inputs are unchanged. 421 EXPECT_EQ("?", c.DebugString(in0)); 422 EXPECT_EQ("[1,?,3]", c.DebugString(in1)); 423 } 424 425 TEST_F(ShapeInferenceTest, WithValue) { 426 NodeDef def; 427 InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, -1})}, {}, {}, {}); 428 429 auto d0 = c.Dim(c.input(0), 0); 430 auto d1 = c.Dim(c.input(0), 1); 431 DimensionHandle out1; 432 DimensionHandle out2; 433 434 // WithValue on a dimension with unknown value always succeeds. 435 EXPECT_TRUE(c.WithValue(d1, 1, &out1).ok()); 436 EXPECT_EQ(1, c.Value(out1)); 437 438 EXPECT_TRUE(c.WithValue(d1, 2, &out2).ok()); 439 EXPECT_EQ(2, c.Value(out2)); 440 EXPECT_FALSE(SameHandle(out1, out2)); 441 EXPECT_FALSE(SameHandle(out1, d1)); 442 443 EXPECT_TRUE(c.WithValue(d1, 1, &out2).ok()); 444 EXPECT_EQ(1, c.Value(out2)); 445 EXPECT_FALSE(SameHandle(out1, out2)); 446 447 // WithValue on dimension with known size. 448 out1 = d0; 449 450 EXPECT_TRUE( 451 str_util::StrContains(c.WithValue(d0, 0, &out1).ToString(), 452 "Invalid argument: Dimension must be 0 but is 1")); 453 EXPECT_FALSE(IsSet(out1)); 454 out1 = d0; 455 EXPECT_TRUE( 456 str_util::StrContains(c.WithValue(d0, 2, &out1).ToString(), 457 "Invalid argument: Dimension must be 2 but is 1")); 458 459 EXPECT_FALSE(IsSet(out1)); 460 EXPECT_TRUE(c.WithValue(d0, 1, &out1).ok()); 461 EXPECT_TRUE(SameHandle(d0, out1)); 462 463 // Inputs are unchanged. 464 EXPECT_EQ("1", c.DebugString(d0)); 465 EXPECT_EQ("?", c.DebugString(d1)); 466 } 467 468 TEST_F(ShapeInferenceTest, MergeDim) { 469 NodeDef def; 470 InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({2, -1, 2, 1, -1})}, 471 {}, {}, {}); 472 473 auto d2 = c.Dim(c.input(0), 0); 474 auto d_unknown = c.Dim(c.input(0), 1); 475 auto d2_b = c.Dim(c.input(0), 2); 476 auto d1 = c.Dim(c.input(0), 3); 477 auto d_unknown_b = c.Dim(c.input(0), 4); 478 DimensionHandle out; 479 480 // Merging anything with unknown returns the same pointer. 481 EXPECT_TRUE(c.Merge(d2, d_unknown, &out).ok()); 482 EXPECT_TRUE(SameHandle(d2, out)); 483 EXPECT_TRUE(c.Merge(d_unknown, d2, &out).ok()); 484 EXPECT_TRUE(SameHandle(d2, out)); 485 EXPECT_TRUE(c.Merge(d_unknown, d_unknown_b, &out).ok()); 486 EXPECT_TRUE(SameHandle(d_unknown, out)); 487 488 auto merged_dims = c.MergedDims(); 489 ASSERT_EQ(3, merged_dims.size()); 490 EXPECT_TRUE(merged_dims[0].first.SameHandle(d2)); 491 EXPECT_TRUE(merged_dims[0].second.SameHandle(d_unknown)); 492 EXPECT_TRUE(merged_dims[1].first.SameHandle(d_unknown)); 493 EXPECT_TRUE(merged_dims[1].second.SameHandle(d2)); 494 EXPECT_TRUE(merged_dims[2].first.SameHandle(d_unknown)); 495 EXPECT_TRUE(merged_dims[2].second.SameHandle(d_unknown_b)); 496 497 // Merging with self is a no-op and returns self. 498 EXPECT_TRUE(c.Merge(d2, d2, &out).ok()); 499 EXPECT_TRUE(SameHandle(d2, out)); 500 EXPECT_TRUE(c.Merge(d_unknown, d_unknown, &out).ok()); 501 EXPECT_TRUE(SameHandle(d_unknown, out)); 502 503 merged_dims = c.MergedDims(); 504 EXPECT_EQ(3, merged_dims.size()); 505 506 // Merging equal values is a no op and returns first one. 507 EXPECT_TRUE(c.Merge(d2, d2_b, &out).ok()); 508 EXPECT_TRUE(SameHandle(d2, out)); 509 EXPECT_TRUE(c.Merge(d2_b, d2, &out).ok()); 510 EXPECT_TRUE(SameHandle(d2_b, out)); 511 512 merged_dims = c.MergedDims(); 513 EXPECT_EQ(3, merged_dims.size()); 514 515 // Merging unequal values is an error. 516 EXPECT_TRUE(str_util::StrContains( 517 c.Merge(d2, d1, &out).ToString(), 518 "Invalid argument: Dimensions must be equal, but are 2 and 1")); 519 520 EXPECT_FALSE(IsSet(out)); 521 EXPECT_TRUE(str_util::StrContains( 522 c.Merge(d1, d2, &out).ToString(), 523 "Invalid argument: Dimensions must be equal, but are 1 and 2")); 524 525 EXPECT_FALSE(IsSet(out)); 526 527 merged_dims = c.MergedDims(); 528 EXPECT_EQ(3, merged_dims.size()); 529 } 530 531 TEST_F(ShapeInferenceTest, RelaxDim) { 532 NodeDef def; 533 InferenceContext c(kVersion, &def, MakeOpDef(1, 2), 534 {S({2, InferenceContext::kUnknownDim, 2, 1, 535 InferenceContext::kUnknownDim})}, 536 {}, {}, {}); 537 538 auto d2 = c.Dim(c.input(0), 0); 539 auto d_unknown = c.Dim(c.input(0), 1); 540 auto d2_b = c.Dim(c.input(0), 2); 541 auto d1 = c.Dim(c.input(0), 3); 542 auto d_unknown_b = c.Dim(c.input(0), 4); 543 DimensionHandle out; 544 545 // Relaxing anything with unknown returns a new unknown or the existing 546 // unknown. 547 Relax(&c, d2, d_unknown, &out); 548 EXPECT_TRUE(SameHandle(d_unknown, out)); 549 EXPECT_FALSE(SameHandle(d_unknown_b, out)); 550 EXPECT_EQ(InferenceContext::kUnknownDim, c.Value(out)); 551 Relax(&c, d_unknown, d2, &out); 552 EXPECT_FALSE(SameHandle(d_unknown, out)); 553 EXPECT_EQ(InferenceContext::kUnknownDim, c.Value(out)); 554 Relax(&c, d_unknown, d_unknown_b, &out); 555 EXPECT_FALSE(SameHandle(d_unknown, out)); 556 EXPECT_TRUE(SameHandle(d_unknown_b, out)); 557 EXPECT_EQ(InferenceContext::kUnknownDim, c.Value(out)); 558 559 // Relaxing with self returns self. 560 Relax(&c, d2, d2, &out); 561 EXPECT_TRUE(SameHandle(d2, out)); 562 Relax(&c, d_unknown, d_unknown, &out); 563 EXPECT_TRUE(SameHandle(d_unknown, out)); 564 565 // Relaxing equal values returns first one. 566 Relax(&c, d2, d2_b, &out); 567 EXPECT_TRUE(SameHandle(d2, out)); 568 Relax(&c, d2_b, d2, &out); 569 EXPECT_TRUE(SameHandle(d2_b, out)); 570 571 // Relaxing unequal values returns a new unknown. 572 Relax(&c, d2, d1, &out); 573 EXPECT_EQ(InferenceContext::kUnknownDim, c.Value(out)); 574 Relax(&c, d1, d2, &out); 575 EXPECT_EQ(InferenceContext::kUnknownDim, c.Value(out)); 576 } 577 578 TEST_F(ShapeInferenceTest, RelaxShape) { 579 NodeDef def; 580 InferenceContext c( 581 kVersion, &def, MakeOpDef(7, 2), 582 {Unknown(), S({1, 2}), S({InferenceContext::kUnknownDim, 2}), 583 S({1, InferenceContext::kUnknownDim}), S({1, 3}), Unknown(), S({1})}, 584 {}, {}, {}); 585 586 auto s_unknown = c.input(0); 587 auto s_1_2 = c.input(1); 588 auto s_u_2 = c.input(2); 589 auto s_1_u = c.input(3); 590 auto s_1_3 = c.input(4); 591 auto s_unknown_b = c.input(5); 592 auto s_1 = c.input(6); 593 ShapeHandle out; 594 595 // Relaxing any shape with unknown returns a new unknown. 596 Relax(&c, s_unknown, s_1_2, &out); 597 EXPECT_FALSE(SameHandle(s_u_2, s_unknown)); 598 EXPECT_EQ("?", c.DebugString(out)); 599 Relax(&c, s_u_2, s_unknown, &out); 600 EXPECT_FALSE(SameHandle(s_u_2, out)); 601 EXPECT_EQ("?", c.DebugString(out)); 602 Relax(&c, s_unknown, s_unknown_b, &out); 603 EXPECT_FALSE(SameHandle(s_unknown, out)); 604 EXPECT_TRUE(SameHandle(s_unknown_b, out)); 605 EXPECT_EQ("?", c.DebugString(out)); 606 607 // Relaxing with self returns self. 608 Relax(&c, s_1_2, s_1_2, &out); 609 EXPECT_TRUE(SameHandle(out, s_1_2)); 610 611 // Relaxing where one of the inputs has less information. 612 out = ShapeHandle(); 613 Relax(&c, s_1_2, s_u_2, &out); 614 EXPECT_FALSE(SameHandle(s_u_2, out)); 615 EXPECT_EQ("[?,2]", c.DebugString(out)); 616 out = ShapeHandle(); 617 Relax(&c, s_u_2, s_1_2, &out); 618 EXPECT_FALSE(SameHandle(s_u_2, out)); 619 EXPECT_EQ("[?,2]", c.DebugString(out)); 620 621 // Relaxing where each input has one distinct unknown dimension. 622 Relax(&c, s_u_2, s_1_u, &out); 623 EXPECT_EQ("[?,?]", c.DebugString(out)); 624 EXPECT_FALSE(SameHandle(c.Dim(s_u_2, 0), c.Dim(out, 0))); 625 EXPECT_TRUE(SameHandle(c.Dim(s_1_u, 1), c.Dim(out, 1))); 626 auto s_u1 = c.UnknownShapeOfRank(1); 627 auto s_u2 = c.UnknownShapeOfRank(1); 628 Relax(&c, s_u1, s_u2, &out); 629 EXPECT_FALSE(SameHandle(s_u1, out)); 630 631 // Relaxing with mismatched values in a dimension returns a shape with that 632 // dimension unknown. 633 out = s_unknown; 634 Relax(&c, s_u_2, s_1_3, &out); 635 EXPECT_FALSE(SameHandle(c.Dim(s_u_2, 0), c.Dim(out, 0))); 636 EXPECT_EQ("[?,?]", c.DebugString(out)); 637 out = s_unknown; 638 Relax(&c, s_1_3, s_u_2, &out); 639 EXPECT_TRUE(SameHandle(c.Dim(s_u_2, 0), c.Dim(out, 0))); 640 EXPECT_EQ("[?,?]", c.DebugString(out)); 641 out = s_unknown; 642 643 // Relaxing with mismatched ranks returns a new unknown. 644 Relax(&c, s_1, s_1_2, &out); 645 EXPECT_EQ("?", c.DebugString(out)); 646 } 647 648 TEST_F(ShapeInferenceTest, MergeShape) { 649 NodeDef def; 650 InferenceContext c(kVersion, &def, MakeOpDef(7, 2), 651 {Unknown(), S({1, 2}), S({-1, 2}), S({1, -1}), S({1, 3}), 652 Unknown(), S({1})}, 653 {}, {}, {}); 654 655 auto s_unknown = c.input(0); 656 auto s_1_2 = c.input(1); 657 auto s_u_2 = c.input(2); 658 auto s_1_u = c.input(3); 659 auto s_1_3 = c.input(4); 660 auto s_unknown_b = c.input(5); 661 auto s_1 = c.input(6); 662 ShapeHandle out; 663 664 // Merging any shape with unknown returns the shape. 665 EXPECT_TRUE(c.Merge(s_unknown, s_1_2, &out).ok()); 666 EXPECT_TRUE(SameHandle(s_1_2, out)); 667 EXPECT_TRUE(c.Merge(s_u_2, s_unknown, &out).ok()); 668 EXPECT_TRUE(SameHandle(s_u_2, out)); 669 EXPECT_TRUE(c.Merge(s_unknown, s_unknown_b, &out).ok()); 670 EXPECT_TRUE(SameHandle(s_unknown, out)); 671 672 auto merged_shapes = c.MergedShapes(); 673 ASSERT_EQ(3, merged_shapes.size()); 674 EXPECT_TRUE(merged_shapes[0].first.SameHandle(s_unknown)); 675 EXPECT_TRUE(merged_shapes[0].second.SameHandle(s_1_2)); 676 EXPECT_TRUE(merged_shapes[1].first.SameHandle(s_u_2)); 677 EXPECT_TRUE(merged_shapes[1].second.SameHandle(s_unknown)); 678 EXPECT_TRUE(merged_shapes[2].first.SameHandle(s_unknown)); 679 EXPECT_TRUE(merged_shapes[2].second.SameHandle(s_unknown_b)); 680 681 // Merging with self returns self. 682 EXPECT_TRUE(c.Merge(s_1_2, s_1_2, &out).ok()); 683 EXPECT_TRUE(SameHandle(out, s_1_2)); 684 685 merged_shapes = c.MergedShapes(); 686 EXPECT_EQ(3, merged_shapes.size()); 687 688 // Merging where one of the inputs is the right answer - return that input. 689 out = ShapeHandle(); 690 EXPECT_TRUE(c.Merge(s_1_2, s_u_2, &out).ok()); 691 EXPECT_TRUE(SameHandle(s_1_2, out)); 692 out = ShapeHandle(); 693 EXPECT_TRUE(c.Merge(s_u_2, s_1_2, &out).ok()); 694 EXPECT_TRUE(SameHandle(s_1_2, out)); 695 696 merged_shapes = c.MergedShapes(); 697 ASSERT_EQ(5, merged_shapes.size()); 698 EXPECT_TRUE(merged_shapes[3].first.SameHandle(s_1_2)); 699 EXPECT_TRUE(merged_shapes[3].second.SameHandle(s_u_2)); 700 EXPECT_TRUE(merged_shapes[4].first.SameHandle(s_u_2)); 701 EXPECT_TRUE(merged_shapes[4].second.SameHandle(s_1_2)); 702 703 // Merging where neither input is the right answer. 704 EXPECT_TRUE(c.Merge(s_u_2, s_1_u, &out).ok()); 705 EXPECT_FALSE(SameHandle(out, s_u_2)); 706 EXPECT_FALSE(SameHandle(out, s_1_u)); 707 EXPECT_EQ("[1,2]", c.DebugString(out)); 708 EXPECT_TRUE(SameHandle(c.Dim(s_1_u, 0), c.Dim(out, 0))); 709 EXPECT_TRUE(SameHandle(c.Dim(s_u_2, 1), c.Dim(out, 1))); 710 711 merged_shapes = c.MergedShapes(); 712 ASSERT_EQ(7, merged_shapes.size()); 713 EXPECT_TRUE(merged_shapes[5].first.SameHandle(s_u_2)); 714 EXPECT_TRUE(merged_shapes[5].second.SameHandle(s_1_u)); 715 EXPECT_TRUE(merged_shapes[6].first.SameHandle(s_u_2)); 716 EXPECT_TRUE(merged_shapes[6].second.SameHandle(out)); 717 718 auto s_u1 = c.UnknownShapeOfRank(1); 719 auto s_u2 = c.UnknownShapeOfRank(1); 720 TF_EXPECT_OK(c.Merge(s_u1, s_u2, &out)); 721 EXPECT_TRUE(SameHandle(s_u1, out)); 722 723 merged_shapes = c.MergedShapes(); 724 ASSERT_EQ(8, merged_shapes.size()); 725 EXPECT_TRUE(merged_shapes[7].first.SameHandle(s_u1)); 726 EXPECT_TRUE(merged_shapes[7].second.SameHandle(s_u2)); 727 728 // Incompatible merges give errors and set out to nullptr. 729 out = s_unknown; 730 EXPECT_TRUE(str_util::StrContains( 731 c.Merge(s_u_2, s_1_3, &out).ToString(), 732 "Invalid argument: Dimension 1 in both shapes must be equal, but " 733 "are 2 and 3")); 734 735 EXPECT_FALSE(IsSet(out)); 736 out = s_unknown; 737 EXPECT_TRUE(str_util::StrContains( 738 c.Merge(s_1_3, s_u_2, &out).ToString(), 739 "Invalid argument: Dimension 1 in both shapes must be equal, but " 740 "are 3 and 2")); 741 742 EXPECT_FALSE(IsSet(out)); 743 out = s_unknown; 744 EXPECT_TRUE(str_util::StrContains( 745 c.Merge(s_1, s_1_2, &out).ToString(), 746 "Invalid argument: Shapes must be equal rank, but are 1 and 2")); 747 748 EXPECT_FALSE(IsSet(out)); 749 750 merged_shapes = c.MergedShapes(); 751 EXPECT_EQ(8, merged_shapes.size()); 752 } 753 754 TEST_F(ShapeInferenceTest, MergePrefix) { 755 NodeDef def; 756 InferenceContext c(kVersion, &def, MakeOpDef(4, 2), 757 { 758 Unknown(), 759 S({-1, 2}), 760 S({1, -1, 3}), 761 S({2, 4}), 762 }, 763 {}, {}, {}); 764 765 auto s_unknown = c.input(0); 766 auto s_u_2 = c.input(1); 767 auto s_1_u_3 = c.input(2); 768 auto s_2_4 = c.input(3); 769 770 ShapeHandle s_out; 771 ShapeHandle s_prefix_out; 772 773 // Merging with unknown returns the inputs. 774 EXPECT_TRUE(c.MergePrefix(s_unknown, s_u_2, &s_out, &s_prefix_out).ok()); 775 EXPECT_TRUE(SameHandle(s_out, s_unknown)); 776 EXPECT_TRUE(SameHandle(s_prefix_out, s_u_2)); 777 EXPECT_TRUE(c.MergePrefix(s_1_u_3, s_unknown, &s_out, &s_prefix_out).ok()); 778 EXPECT_TRUE(SameHandle(s_out, s_1_u_3)); 779 EXPECT_TRUE(SameHandle(s_prefix_out, s_unknown)); 780 781 EXPECT_TRUE(c.MergePrefix(s_1_u_3, s_u_2, &s_out, &s_prefix_out).ok()); 782 EXPECT_FALSE(SameHandle(s_out, s_1_u_3)); 783 EXPECT_EQ("[1,2]", c.DebugString(s_prefix_out)); 784 EXPECT_EQ("[1,2,3]", c.DebugString(s_out)); 785 EXPECT_TRUE(SameHandle(c.Dim(s_prefix_out, 0), c.Dim(s_out, 0))); 786 EXPECT_TRUE(SameHandle(c.Dim(s_out, 0), c.Dim(s_1_u_3, 0))); 787 EXPECT_TRUE(SameHandle(c.Dim(s_prefix_out, 1), c.Dim(s_out, 1))); 788 EXPECT_TRUE(SameHandle(c.Dim(s_prefix_out, 1), c.Dim(s_u_2, 1))); 789 790 // Incompatible merges give errors and set outs to nullptr. 791 s_out = s_unknown; 792 s_prefix_out = s_unknown; 793 EXPECT_TRUE(str_util::StrContains( 794 c.MergePrefix(s_1_u_3, s_2_4, &s_out, &s_prefix_out).ToString(), 795 "Invalid argument: Dimensions must be equal, but are 1 and 2")); 796 797 EXPECT_FALSE(IsSet(s_out)); 798 EXPECT_FALSE(IsSet(s_prefix_out)); 799 800 s_out = s_unknown; 801 s_prefix_out = s_unknown; 802 EXPECT_TRUE(str_util::StrContains( 803 c.MergePrefix(s_2_4, s_1_u_3, &s_out, &s_prefix_out).ToString(), 804 "Invalid argument: Shape must be at least rank 3 but is rank 2")); 805 EXPECT_FALSE(IsSet(s_out)); 806 EXPECT_FALSE(IsSet(s_prefix_out)); 807 } 808 809 TEST_F(ShapeInferenceTest, Subshape) { 810 NodeDef def; 811 InferenceContext c(kVersion, &def, MakeOpDef(2, 2), 812 {S({1, 2, 3, -1, 5}), Unknown()}, {}, {}, {}); 813 814 ShapeHandle unknown = c.input(1); 815 ShapeHandle out; 816 EXPECT_TRUE(c.Subshape(unknown, 0, &out).ok()); 817 EXPECT_EQ("?", c.DebugString(out)); 818 EXPECT_TRUE(SameHandle(out, unknown)); 819 EXPECT_TRUE(c.Subshape(unknown, 1, &out).ok()); 820 EXPECT_EQ("?", c.DebugString(out)); 821 EXPECT_FALSE(SameHandle(out, unknown)); 822 EXPECT_TRUE(c.Subshape(unknown, 200, &out).ok()); 823 EXPECT_EQ("?", c.DebugString(out)); 824 EXPECT_FALSE(SameHandle(out, unknown)); 825 826 const int kFullRank = 5; 827 ShapeHandle out_arr[4]; 828 auto in0 = c.input(0); 829 EXPECT_TRUE(c.Subshape(in0, 0, &out).ok()); 830 EXPECT_EQ("[1,2,3,?,5]", c.DebugString(out)); 831 EXPECT_TRUE(SameHandle(out, in0)); 832 EXPECT_EQ(kFullRank, c.Rank(out)); 833 for (int start = 0; start <= kFullRank + 1; ++start) { 834 for (int end = start; end <= kFullRank + 1; ++end) { 835 // Get subshapes using different start and end values that give the same 836 // range. 837 const int neg_start = 838 start >= kFullRank ? kFullRank : (start - kFullRank); 839 const int neg_end = end >= kFullRank ? kFullRank : (end - kFullRank); 840 ASSERT_TRUE(c.Subshape(in0, start, end, &out_arr[0]).ok()); 841 ASSERT_TRUE(c.Subshape(in0, neg_start, end, &out_arr[1]).ok()); 842 ASSERT_TRUE(c.Subshape(in0, start, neg_end, &out_arr[2]).ok()); 843 ASSERT_TRUE(c.Subshape(in0, neg_start, neg_end, &out_arr[3]).ok()); 844 845 // Verify all computed subshapes. 846 for (int arr_idx = 0; arr_idx < 4; ++arr_idx) { 847 out = out_arr[arr_idx]; 848 ASSERT_EQ(std::min(kFullRank, end) - std::min(kFullRank, start), 849 c.Rank(out)) 850 << "start: " << start << " end: " << end << " arr_idx: " << arr_idx 851 << " in0: " << c.DebugString(in0) << " out: " << c.DebugString(out); 852 for (int d = 0; d < c.Rank(out); ++d) { 853 EXPECT_TRUE(SameHandle(c.Dim(in0, start + d), c.Dim(out, d))) 854 << "arr_idx: " << arr_idx; 855 } 856 } 857 } 858 } 859 860 // Errors. 861 out = unknown; 862 EXPECT_TRUE(str_util::StrContains( 863 c.Subshape(in0, 6, -3, &out).ToString(), 864 "Invalid argument: Subshape must have computed start <= end, but is 5 " 865 "and 2 (computed from start 6 and end -3 over shape with rank 5)")); 866 EXPECT_FALSE(IsSet(out)); 867 out = unknown; 868 EXPECT_TRUE(str_util::StrContains(c.Subshape(in0, -50, 100, &out).ToString(), 869 "Invalid argument: Subshape start out of " 870 "bounds: -50, for shape with rank 5")); 871 872 EXPECT_FALSE(IsSet(out)); 873 out = unknown; 874 EXPECT_TRUE(str_util::StrContains(c.Subshape(in0, 0, -50, &out).ToString(), 875 "Invalid argument: Subshape end out of " 876 "bounds: -50, for shape with rank 5")); 877 878 EXPECT_FALSE(IsSet(out)); 879 } 880 881 TEST_F(ShapeInferenceTest, Concatenate) { 882 NodeDef def; 883 InferenceContext c(kVersion, &def, MakeOpDef(3, 2), 884 {S({1, -1, 3}), S({4, 5}), Unknown()}, {}, {}, {}); 885 886 auto in0 = c.input(0); 887 auto in1 = c.input(1); 888 ShapeHandle unknown = c.input(2); 889 ShapeHandle out; 890 EXPECT_TRUE(c.Concatenate(unknown, unknown, &out).ok()); 891 EXPECT_EQ("?", c.DebugString(out)); 892 EXPECT_FALSE(SameHandle(out, unknown)); 893 EXPECT_TRUE(c.Concatenate(unknown, in0, &out).ok()); 894 EXPECT_EQ("?", c.DebugString(out)); 895 EXPECT_FALSE(SameHandle(out, unknown)); 896 897 EXPECT_TRUE(c.Concatenate(in0, in1, &out).ok()); 898 EXPECT_EQ("[1,?,3,4,5]", c.DebugString(out)); 899 int out_i = 0; 900 for (int i = 0; i < c.Rank(in0); ++i, ++out_i) { 901 EXPECT_TRUE(SameHandle(c.Dim(in0, i), c.Dim(out, out_i))); 902 } 903 for (int i = 0; i < c.Rank(in1); ++i, ++out_i) { 904 EXPECT_TRUE(SameHandle(c.Dim(in1, i), c.Dim(out, out_i))); 905 } 906 } 907 908 TEST_F(ShapeInferenceTest, ReplaceDim) { 909 NodeDef def; 910 InferenceContext c(kVersion, &def, MakeOpDef(2, 0), {S({1, 2, 3}), Unknown()}, 911 {}, {}, {}); 912 913 auto in = c.input(0); 914 auto unknown = c.input(1); 915 916 ShapeHandle replaced; 917 EXPECT_TRUE(c.ReplaceDim(in, 0, c.Dim(in, 1), &replaced).ok()); 918 EXPECT_EQ("[2,2,3]", c.DebugString(replaced)); 919 EXPECT_TRUE(c.ReplaceDim(in, 2, c.Dim(in, 1), &replaced).ok()); 920 EXPECT_EQ("[1,2,2]", c.DebugString(replaced)); 921 EXPECT_TRUE(c.ReplaceDim(in, 1, c.Dim(in, 2), &replaced).ok()); 922 EXPECT_EQ("[1,3,3]", c.DebugString(replaced)); 923 EXPECT_TRUE(c.ReplaceDim(unknown, 0, c.Dim(in, 1), &replaced).ok()); 924 EXPECT_EQ("?", c.DebugString(replaced)); 925 926 // Negative indexing. 927 EXPECT_TRUE(c.ReplaceDim(in, -1, c.Dim(in, 1), &replaced).ok()); 928 EXPECT_EQ("[1,2,2]", c.DebugString(replaced)); 929 EXPECT_TRUE(c.ReplaceDim(unknown, -1, c.Dim(in, 1), &replaced).ok()); 930 EXPECT_EQ("?", c.DebugString(replaced)); 931 932 // out of range indexing. 933 EXPECT_FALSE(c.ReplaceDim(in, 3, c.Dim(in, 1), &replaced).ok()); 934 EXPECT_FALSE(IsSet(replaced)); 935 replaced = in; 936 EXPECT_FALSE(c.ReplaceDim(in, -4, c.Dim(in, 1), &replaced).ok()); 937 EXPECT_FALSE(IsSet(replaced)); 938 } 939 940 TEST_F(ShapeInferenceTest, MakeShape) { 941 NodeDef def; 942 InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, 3, -1, 5})}, {}, 943 {}, {}); 944 945 std::vector<DimensionHandle> dims; 946 auto in0 = c.input(0); 947 const int rank = c.Rank(in0); 948 dims.reserve(rank); 949 for (int i = 0; i < rank; ++i) { 950 dims.push_back(c.Dim(in0, rank - i - 1)); 951 } 952 953 auto s = c.MakeShape(dims); 954 EXPECT_EQ("[5,?,3,2,1]", c.DebugString(s)); 955 EXPECT_TRUE(SameHandle(c.Dim(s, 0), c.Dim(in0, rank - 1))); 956 957 auto s2 = c.MakeShape(dims); 958 EXPECT_FALSE(SameHandle(s, s2)); 959 EXPECT_TRUE(SameHandle(c.Dim(s2, 0), c.Dim(in0, rank - 1))); 960 961 auto s3 = c.MakeShape({1, 2, dims[2]}); 962 EXPECT_FALSE(SameHandle(s, s3)); 963 EXPECT_EQ("[1,2,3]", c.DebugString(s3)); 964 } 965 966 TEST_F(ShapeInferenceTest, UnknownShape) { 967 NodeDef def; 968 std::vector<ShapeHandle> empty; 969 InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}); 970 971 auto u0 = c.UnknownShape(); 972 auto u1 = c.UnknownShape(); 973 EXPECT_EQ("?", c.DebugString(u0)); 974 EXPECT_EQ("?", c.DebugString(u1)); 975 EXPECT_FALSE(SameHandle(u0, u1)); 976 } 977 978 TEST_F(ShapeInferenceTest, KnownShapeToProto) { 979 NodeDef def; 980 std::vector<ShapeHandle> empty; 981 InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}); 982 983 auto s = c.MakeShape({1, 2, 3}); 984 TensorShapeProto proto; 985 c.ShapeHandleToProto(s, &proto); 986 987 EXPECT_FALSE(proto.unknown_rank()); 988 EXPECT_EQ(3, proto.dim_size()); 989 EXPECT_EQ(1, proto.dim(0).size()); 990 } 991 992 TEST_F(ShapeInferenceTest, UnknownShapeToProto) { 993 NodeDef def; 994 std::vector<ShapeHandle> empty; 995 InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}); 996 997 auto u0 = c.UnknownShape(); 998 TensorShapeProto proto; 999 c.ShapeHandleToProto(u0, &proto); 1000 1001 EXPECT_TRUE(proto.unknown_rank()); 1002 EXPECT_EQ(0, proto.dim_size()); 1003 } 1004 1005 TEST_F(ShapeInferenceTest, Scalar) { 1006 NodeDef def; 1007 std::vector<ShapeHandle> empty; 1008 InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}); 1009 1010 auto s0 = c.Scalar(); 1011 EXPECT_EQ("[]", c.DebugString(s0)); 1012 auto s1 = c.Scalar(); 1013 EXPECT_EQ("[]", c.DebugString(s1)); 1014 } 1015 1016 TEST_F(ShapeInferenceTest, Vector) { 1017 NodeDef def; 1018 std::vector<ShapeHandle> empty; 1019 InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}); 1020 1021 auto s0 = c.Vector(1); 1022 EXPECT_EQ("[1]", c.DebugString(s0)); 1023 auto s1 = c.Vector(InferenceContext::kUnknownDim); 1024 EXPECT_EQ("[?]", c.DebugString(s1)); 1025 1026 auto d1 = c.UnknownDim(); 1027 auto s2 = c.Vector(d1); 1028 EXPECT_EQ("[?]", c.DebugString(s2)); 1029 EXPECT_TRUE(SameHandle(d1, c.Dim(s2, 0))); 1030 } 1031 1032 TEST_F(ShapeInferenceTest, Matrix) { 1033 NodeDef def; 1034 std::vector<ShapeHandle> empty; 1035 InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}); 1036 1037 auto s0 = c.Matrix(1, 2); 1038 EXPECT_EQ("[1,2]", c.DebugString(s0)); 1039 auto s1 = c.Matrix(0, InferenceContext::kUnknownDim); 1040 EXPECT_EQ("[0,?]", c.DebugString(s1)); 1041 1042 auto d1 = c.UnknownDim(); 1043 auto d2 = c.UnknownDim(); 1044 auto s2 = c.Matrix(d1, d2); 1045 EXPECT_EQ("[?,?]", c.DebugString(s2)); 1046 EXPECT_TRUE(SameHandle(d1, c.Dim(s2, 0))); 1047 EXPECT_TRUE(SameHandle(d2, c.Dim(s2, 1))); 1048 1049 auto s3 = c.Matrix(d1, 100); 1050 EXPECT_EQ("[?,100]", c.DebugString(s3)); 1051 EXPECT_TRUE(SameHandle(d1, c.Dim(s2, 0))); 1052 } 1053 1054 TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) { 1055 auto create = [&](Tensor* t) { 1056 NodeDef def; 1057 InferenceContext c(kVersion, &def, MakeOpDef(1, 0), {Unknown()}, {t}, {}, 1058 {}); 1059 ShapeHandle out; 1060 Status s = c.MakeShapeFromShapeTensor(0, &out); 1061 if (s.ok()) { 1062 return c.DebugString(out); 1063 } else { 1064 EXPECT_FALSE(IsSet(out)); 1065 return s.error_message(); 1066 } 1067 }; 1068 1069 Tensor t; 1070 EXPECT_EQ("?", create(nullptr)); 1071 1072 t = ::tensorflow::test::AsTensor<int32>({1, 2, 3}); 1073 EXPECT_EQ("[1,2,3]", create(&t)); 1074 1075 t = ::tensorflow::test::AsTensor<int64>({3, 2, 1}); 1076 EXPECT_EQ("[3,2,1]", create(&t)); 1077 1078 t = ::tensorflow::test::AsTensor<int64>({3, -1, 1}); 1079 EXPECT_EQ("[3,?,1]", create(&t)); 1080 1081 t = ::tensorflow::test::AsTensor<int64>({}); 1082 EXPECT_EQ("[]", create(&t)); 1083 1084 // Test negative scalar 1085 t = ::tensorflow::test::AsScalar<int32>(-1); 1086 EXPECT_EQ("?", create(&t)); 1087 1088 t = ::tensorflow::test::AsTensor<float>({1, 2, 3}); 1089 EXPECT_TRUE(str_util::StrContains( 1090 create(&t), "Input tensor must be int32 or int64, but was float")); 1091 1092 t = ::tensorflow::test::AsScalar<int32>(1); 1093 auto s_scalar = create(&t); 1094 EXPECT_TRUE(str_util::StrContains( 1095 s_scalar, 1096 "Input tensor must be rank 1, or if its rank 0 it must have value -1")) 1097 << s_scalar; 1098 1099 t = ::tensorflow::test::AsTensor<int32>({1, 2}, TensorShape{2, 1}); 1100 auto s_matrix = create(&t); 1101 EXPECT_TRUE(str_util::StrContains( 1102 s_matrix, "Input tensor must be rank 1, but was rank 2")) 1103 << s_matrix; 1104 1105 // Test negative values for the dims. 1106 t = ::tensorflow::test::AsTensor<int64>({3, -2, 1}); 1107 EXPECT_TRUE(str_util::StrContains( 1108 create(&t), "Invalid value in tensor used for shape: -2")); 1109 1110 // Test negative values for the dims. 1111 t = ::tensorflow::test::AsTensor<int32>({3, -2, 1}); 1112 EXPECT_TRUE(str_util::StrContains( 1113 create(&t), "Invalid value in tensor used for shape: -2")); 1114 1115 // Test when the input shape is wrong. 1116 { 1117 NodeDef def; 1118 InferenceContext c(kVersion, &def, MakeOpDef(1, 0), {S({1, -1})}, {nullptr}, 1119 {}, {}); 1120 ShapeHandle out; 1121 EXPECT_EQ("Shape must be rank 1 but is rank 2", 1122 c.MakeShapeFromShapeTensor(0, &out).error_message()); 1123 } 1124 } 1125 1126 TEST_F(ShapeInferenceTest, MakeShapeFromPartialTensorShape) { 1127 NodeDef def; 1128 std::vector<ShapeHandle> empty; 1129 InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}); 1130 1131 // With an unknown rank. 1132 ShapeHandle out; 1133 TF_ASSERT_OK(c.MakeShapeFromPartialTensorShape(PartialTensorShape(), &out)); 1134 EXPECT_EQ("?", c.DebugString(out)); 1135 1136 // With a known rank. 1137 TF_ASSERT_OK( 1138 c.MakeShapeFromPartialTensorShape(PartialTensorShape({0}), &out)); 1139 EXPECT_EQ("[0]", c.DebugString(out)); 1140 TF_ASSERT_OK(c.MakeShapeFromPartialTensorShape( 1141 PartialTensorShape({0, -1, 1000}), &out)); 1142 EXPECT_EQ("[0,?,1000]", c.DebugString(out)); 1143 } 1144 1145 TEST_F(ShapeInferenceTest, MakeShapeFromTensorShape) { 1146 NodeDef def; 1147 std::vector<ShapeHandle> empty; 1148 InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}); 1149 1150 ShapeHandle out; 1151 TF_ASSERT_OK(c.MakeShapeFromTensorShape(TensorShape(), &out)); 1152 EXPECT_EQ("[]", c.DebugString(out)); 1153 TF_ASSERT_OK(c.MakeShapeFromTensorShape(TensorShape({0}), &out)); 1154 EXPECT_EQ("[0]", c.DebugString(out)); 1155 TF_ASSERT_OK(c.MakeShapeFromTensorShape(TensorShape({0, 7, 1000}), &out)); 1156 EXPECT_EQ("[0,7,1000]", c.DebugString(out)); 1157 } 1158 1159 TEST_F(ShapeInferenceTest, MakeShapeFromShapeProto) { 1160 NodeDef def; 1161 std::vector<ShapeHandle> empty; 1162 InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}); 1163 TensorShapeProto proto; 1164 1165 // With a set unknown rank. 1166 ShapeHandle out; 1167 proto.set_unknown_rank(true); 1168 EXPECT_TRUE(c.MakeShapeFromShapeProto(proto, &out).ok()); 1169 EXPECT_EQ("?", c.DebugString(out)); 1170 proto.add_dim()->set_size(0); 1171 EXPECT_TRUE(str_util::StrContains( 1172 c.MakeShapeFromShapeProto(proto, &out).error_message(), 1173 "An unknown shape must not have any dimensions set.")); 1174 EXPECT_FALSE(IsSet(out)); 1175 1176 // With known rank. 1177 proto.set_unknown_rank(false); 1178 EXPECT_TRUE(c.MakeShapeFromShapeProto(proto, &out).ok()); 1179 EXPECT_EQ("[0]", c.DebugString(out)); 1180 proto.add_dim()->set_size(-1); 1181 proto.add_dim()->set_size(1000); 1182 EXPECT_TRUE(c.MakeShapeFromShapeProto(proto, &out).ok()); 1183 EXPECT_EQ("[0,?,1000]", c.DebugString(out)); 1184 1185 // With invalid dimension value. 1186 proto.add_dim()->set_size(-2); 1187 EXPECT_TRUE(str_util::StrContains( 1188 c.MakeShapeFromShapeProto(proto, &out).error_message(), 1189 "Shape [0,?,1000,-2] has dimensions with values below -1 " 1190 "(where -1 means unknown)")); 1191 1192 EXPECT_FALSE(IsSet(out)); 1193 } 1194 1195 TEST_F(ShapeInferenceTest, MakeDim) { 1196 NodeDef def; 1197 std::vector<ShapeHandle> empty; 1198 InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}); 1199 1200 auto d0 = c.MakeDim(1); 1201 auto d1 = c.MakeDim(1); 1202 auto d2 = c.MakeDim(2); 1203 EXPECT_EQ("1", c.DebugString(d0)); 1204 EXPECT_EQ("1", c.DebugString(d1)); 1205 EXPECT_FALSE(SameHandle(d0, d1)); 1206 EXPECT_EQ("2", c.DebugString(d2)); 1207 } 1208 1209 TEST_F(ShapeInferenceTest, UnknownDim) { 1210 NodeDef def; 1211 std::vector<ShapeHandle> empty; 1212 InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}); 1213 1214 auto d0 = c.UnknownDim(); 1215 auto d1 = c.UnknownDim(); 1216 EXPECT_EQ("?", c.DebugString(d0)); 1217 EXPECT_EQ("?", c.DebugString(d1)); 1218 EXPECT_FALSE(SameHandle(d0, d1)); 1219 } 1220 1221 TEST_F(ShapeInferenceTest, UnknownShapeOfRank) { 1222 NodeDef def; 1223 std::vector<ShapeHandle> empty; 1224 InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}); 1225 1226 auto unknown_shape_of_rank_3 = c.UnknownShapeOfRank(3); 1227 EXPECT_EQ("[?,?,?]", c.DebugString(unknown_shape_of_rank_3)); 1228 1229 auto unknown_shape_of_rank_0 = c.UnknownShapeOfRank(0); 1230 EXPECT_EQ("[]", c.DebugString(unknown_shape_of_rank_0)); 1231 } 1232 1233 TEST_F(ShapeInferenceTest, InputTensors) { 1234 const Tensor t1 = tensorflow::test::AsTensor<float>({10}); 1235 const Tensor t2 = tensorflow::test::AsTensor<float>({20, 30}); 1236 NodeDef def; 1237 InferenceContext c(kVersion, &def, MakeOpDef(3, 2), {S({1}), S({2}), S({3})}, 1238 {&t1, &t2}, {}, {}); 1239 1240 EXPECT_TRUE(c.input_tensor(0) == &t1); 1241 EXPECT_TRUE(c.input_tensor(1) == &t2); 1242 EXPECT_TRUE(c.input_tensor(2) == nullptr); 1243 } 1244 1245 TEST_F(ShapeInferenceTest, MakeDimForScalarInput) { 1246 Tensor t1 = tensorflow::test::AsScalar<int32>(20); 1247 Tensor t2 = tensorflow::test::AsScalar<int32>(-1); 1248 NodeDef def; 1249 InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({}), S({})}, 1250 {&t1, &t2}, {}, {}); 1251 1252 DimensionHandle d; 1253 EXPECT_TRUE(c.MakeDimForScalarInput(0, &d).ok()); 1254 EXPECT_EQ("20", c.DebugString(d)); 1255 1256 EXPECT_TRUE( 1257 str_util::StrContains(c.MakeDimForScalarInput(1, &d).error_message(), 1258 "Dimension size, given by scalar input 1, must be " 1259 "non-negative but is -1")); 1260 1261 // Same tests, with int64 values. 1262 t1 = tensorflow::test::AsScalar<int64>(20); 1263 t2 = tensorflow::test::AsScalar<int64>(-1); 1264 EXPECT_TRUE(c.MakeDimForScalarInput(0, &d).ok()); 1265 EXPECT_EQ("20", c.DebugString(d)); 1266 1267 EXPECT_TRUE( 1268 str_util::StrContains(c.MakeDimForScalarInput(1, &d).error_message(), 1269 "Dimension size, given by scalar input 1, must be " 1270 "non-negative but is -1")); 1271 } 1272 1273 TEST_F(ShapeInferenceTest, GetAttr) { 1274 OpRegistrationData op_reg_data; 1275 op_reg_data.op_def = MakeOpDef(0, 2); 1276 NodeDef def; 1277 CHECK(NodeDefBuilder("dummy", &op_reg_data.op_def) 1278 .Attr("foo", "bar") 1279 .Finalize(&def) 1280 .ok()); 1281 1282 std::vector<ShapeHandle> empty; 1283 InferenceContext c(kVersion, &def, op_reg_data.op_def, empty, {}, {}, {}); 1284 string value; 1285 EXPECT_TRUE(c.GetAttr("foo", &value).ok()); 1286 EXPECT_EQ("bar", value); 1287 } 1288 1289 TEST_F(ShapeInferenceTest, Divide) { 1290 NodeDef def; 1291 InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 1, 2, 0})}, {}, 1292 {}, {}); 1293 1294 auto s = c.input(0); 1295 auto d_6 = c.Dim(s, 0); 1296 auto d_unknown = c.Dim(s, 1); 1297 auto d_1 = c.Dim(s, 2); 1298 auto d_2 = c.Dim(s, 3); 1299 auto d_0 = c.Dim(s, 4); 1300 bool evenly_divisible = true; 1301 1302 // Dividing unknown by non-1 gives new unknown. 1303 DimensionHandle out; 1304 EXPECT_TRUE(c.Divide(d_unknown, 2, evenly_divisible, &out).ok()); 1305 EXPECT_EQ("?", c.DebugString(out)); 1306 EXPECT_FALSE(SameHandle(out, d_unknown)); 1307 1308 // Dividing anything by 1 returns the input. 1309 EXPECT_TRUE(c.Divide(d_unknown, 1, evenly_divisible, &out).ok()); 1310 EXPECT_TRUE(SameHandle(out, d_unknown)); 1311 EXPECT_TRUE(c.Divide(d_6, 1, evenly_divisible, &out).ok()); 1312 EXPECT_TRUE(SameHandle(out, d_6)); 1313 EXPECT_TRUE(c.Divide(d_unknown, d_1, evenly_divisible, &out).ok()); 1314 EXPECT_TRUE(SameHandle(out, d_unknown)); 1315 EXPECT_TRUE(c.Divide(d_6, d_1, evenly_divisible, &out).ok()); 1316 EXPECT_TRUE(SameHandle(out, d_6)); 1317 1318 EXPECT_TRUE(c.Divide(d_6, 2, evenly_divisible, &out).ok()); 1319 EXPECT_EQ("3", c.DebugString(out)); 1320 EXPECT_TRUE(c.Divide(d_6, d_2, evenly_divisible, &out).ok()); 1321 EXPECT_EQ("3", c.DebugString(out)); 1322 1323 EXPECT_TRUE(str_util::StrContains( 1324 c.Divide(d_6, 5, evenly_divisible, &out).error_message(), 1325 "Dimension size must be evenly divisible by 5 but is 6")); 1326 1327 EXPECT_TRUE(str_util::StrContains( 1328 c.Divide(d_6, 0, evenly_divisible, &out).error_message(), 1329 "Divisor must be positive but is 0")); 1330 EXPECT_TRUE(str_util::StrContains( 1331 c.Divide(d_6, d_0, evenly_divisible, &out).error_message(), 1332 "Divisor must be positive but is 0")); 1333 1334 EXPECT_TRUE(str_util::StrContains( 1335 c.Divide(d_6, -1, evenly_divisible, &out).error_message(), 1336 "Divisor must be positive but is -1")); 1337 1338 // Repeat error cases above with evenly_divisible=false. 1339 evenly_divisible = false; 1340 EXPECT_TRUE(c.Divide(d_6, 5, evenly_divisible, &out).ok()); 1341 EXPECT_EQ("1", c.DebugString(out)); 1342 1343 EXPECT_TRUE(str_util::StrContains( 1344 c.Divide(d_6, 0, evenly_divisible, &out).error_message(), 1345 "Divisor must be positive but is 0")); 1346 1347 EXPECT_TRUE(str_util::StrContains( 1348 c.Divide(d_6, -1, evenly_divisible, &out).error_message(), 1349 "Divisor must be positive but is -1")); 1350 } 1351 1352 TEST_F(ShapeInferenceTest, Add) { 1353 NodeDef def; 1354 InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 0})}, {}, {}, 1355 {}); 1356 1357 auto s = c.input(0); 1358 auto d_6 = c.Dim(s, 0); 1359 auto d_unknown = c.Dim(s, 1); 1360 auto d_0 = c.Dim(s, 2); 1361 1362 // Adding non-zero to unknown gives new unknown. 1363 DimensionHandle out; 1364 EXPECT_TRUE(c.Add(d_unknown, 1, &out).ok()); 1365 EXPECT_EQ("?", c.DebugString(out)); 1366 EXPECT_FALSE(SameHandle(out, d_unknown)); 1367 1368 // Adding 0 to anything gives input. 1369 EXPECT_TRUE(c.Add(d_unknown, 0, &out).ok()); 1370 EXPECT_TRUE(SameHandle(out, d_unknown)); 1371 EXPECT_TRUE(c.Add(d_6, 0, &out).ok()); 1372 EXPECT_TRUE(SameHandle(out, d_6)); 1373 1374 // Adding dimension with value 0 to anything gives input. 1375 EXPECT_TRUE(c.Add(d_unknown, c.MakeDim(0ll), &out).ok()); 1376 EXPECT_TRUE(SameHandle(out, d_unknown)); 1377 EXPECT_TRUE(c.Add(d_6, c.MakeDim(0ll), &out).ok()); 1378 EXPECT_TRUE(SameHandle(out, d_6)); 1379 1380 // Test addition. 1381 EXPECT_TRUE(c.Add(d_6, 2, &out).ok()); 1382 EXPECT_EQ("8", c.DebugString(out)); 1383 EXPECT_TRUE(c.Add(d_6, std::numeric_limits<int64>::max() - 6, &out).ok()); 1384 EXPECT_EQ(std::numeric_limits<int64>::max(), c.Value(out)); 1385 1386 // Test addition using dimension as second value. 1387 EXPECT_TRUE(c.Add(d_6, c.MakeDim(2), &out).ok()); 1388 EXPECT_EQ("8", c.DebugString(out)); 1389 EXPECT_TRUE( 1390 c.Add(d_6, c.MakeDim(std::numeric_limits<int64>::max() - 6), &out).ok()); 1391 EXPECT_EQ(std::numeric_limits<int64>::max(), c.Value(out)); 1392 EXPECT_TRUE(c.Add(d_6, c.UnknownDim(), &out).ok()); 1393 EXPECT_EQ("?", c.DebugString(out)); 1394 EXPECT_TRUE(c.Add(d_0, d_6, &out).ok()); 1395 EXPECT_TRUE(SameHandle(out, d_6)); 1396 1397 EXPECT_TRUE(str_util::StrContains( 1398 c.Add(d_6, std::numeric_limits<int64>::max() - 5, &out).error_message(), 1399 "Dimension size overflow from adding 6 and 9223372036854775802")); 1400 } 1401 1402 TEST_F(ShapeInferenceTest, Subtract) { 1403 NodeDef def; 1404 InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 0, 5})}, {}, 1405 {}, {}); 1406 1407 auto s = c.input(0); 1408 auto d_6 = c.Dim(s, 0); 1409 auto d_unknown = c.Dim(s, 1); 1410 auto d_0 = c.Dim(s, 2); 1411 auto d_5 = c.Dim(s, 3); 1412 1413 // Subtracting non-zero from unknown gives new unknown. 1414 DimensionHandle out; 1415 EXPECT_TRUE(c.Subtract(d_unknown, 1, &out).ok()); 1416 EXPECT_EQ("?", c.DebugString(out)); 1417 EXPECT_FALSE(SameHandle(out, d_unknown)); 1418 1419 // Subtracting 0 from anything gives input. 1420 EXPECT_TRUE(c.Subtract(d_unknown, 0ll, &out).ok()); 1421 EXPECT_TRUE(SameHandle(out, d_unknown)); 1422 EXPECT_TRUE(c.Subtract(d_6, 0ll, &out).ok()); 1423 EXPECT_TRUE(SameHandle(out, d_6)); 1424 1425 // Subtracting dimension with value 0 from anything gives input. 1426 EXPECT_TRUE(c.Subtract(d_unknown, c.MakeDim(0ll), &out).ok()); 1427 EXPECT_TRUE(SameHandle(out, d_unknown)); 1428 EXPECT_TRUE(c.Subtract(d_6, c.MakeDim(0ll), &out).ok()); 1429 EXPECT_TRUE(SameHandle(out, d_6)); 1430 1431 // Test subtraction. 1432 EXPECT_TRUE(c.Subtract(d_6, 2, &out).ok()); 1433 EXPECT_EQ("4", c.DebugString(out)); 1434 EXPECT_TRUE(c.Subtract(d_6, 6, &out).ok()); 1435 EXPECT_EQ("0", c.DebugString(out)); 1436 1437 // Test subtraction using dimension as second value. 1438 EXPECT_TRUE(c.Subtract(d_6, c.MakeDim(2), &out).ok()); 1439 EXPECT_EQ("4", c.DebugString(out)); 1440 EXPECT_TRUE(c.Subtract(d_6, d_5, &out).ok()); 1441 EXPECT_EQ("1", c.DebugString(out)); 1442 EXPECT_TRUE(c.Subtract(d_6, c.UnknownDim(), &out).ok()); 1443 EXPECT_EQ("?", c.DebugString(out)); 1444 EXPECT_TRUE(c.Subtract(d_6, d_0, &out).ok()); 1445 EXPECT_TRUE(SameHandle(out, d_6)); 1446 1447 EXPECT_TRUE(str_util::StrContains( 1448 c.Subtract(d_5, d_6, &out).error_message(), 1449 "Negative dimension size caused by subtracting 6 from 5")); 1450 } 1451 1452 TEST_F(ShapeInferenceTest, Multiply) { 1453 NodeDef def; 1454 InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 0, 1})}, {}, 1455 {}, {}); 1456 1457 auto s = c.input(0); 1458 auto d_6 = c.Dim(s, 0); 1459 auto d_unknown = c.Dim(s, 1); 1460 auto d_0 = c.Dim(s, 2); 1461 auto d_1 = c.Dim(s, 3); 1462 1463 // Multiplying non-zero to unknown gives new unknown. 1464 DimensionHandle out; 1465 EXPECT_TRUE(c.Multiply(d_unknown, 2, &out).ok()); 1466 EXPECT_EQ("?", c.DebugString(out)); 1467 1468 // Multiplying 0 to anything gives 0. 1469 EXPECT_TRUE(c.Multiply(d_unknown, 0, &out).ok()); 1470 EXPECT_EQ("0", c.DebugString(out)); 1471 EXPECT_TRUE(c.Multiply(d_unknown, d_0, &out).ok()); 1472 EXPECT_EQ("0", c.DebugString(out)); 1473 EXPECT_TRUE(c.Multiply(d_0, d_unknown, &out).ok()); 1474 EXPECT_EQ("0", c.DebugString(out)); 1475 1476 // Multiplying 1 to anything gives the original. 1477 // (unknown -> unknown) 1478 EXPECT_TRUE(c.Multiply(d_unknown, 1, &out).ok()); 1479 EXPECT_TRUE(SameHandle(d_unknown, out)); 1480 EXPECT_TRUE(c.Multiply(d_unknown, d_1, &out).ok()); 1481 EXPECT_TRUE(SameHandle(d_unknown, out)); 1482 EXPECT_TRUE(c.Multiply(d_1, d_unknown, &out).ok()); 1483 EXPECT_TRUE(SameHandle(d_unknown, out)); 1484 // (known -> known) 1485 EXPECT_TRUE(c.Multiply(d_6, 1, &out).ok()); 1486 EXPECT_TRUE(SameHandle(d_6, out)); 1487 EXPECT_TRUE(c.Multiply(d_6, d_1, &out).ok()); 1488 EXPECT_TRUE(SameHandle(d_6, out)); 1489 EXPECT_TRUE(c.Multiply(d_1, d_6, &out).ok()); 1490 EXPECT_TRUE(SameHandle(d_6, out)); 1491 1492 // Test multiplication. 1493 EXPECT_TRUE(c.Multiply(d_6, 2, &out).ok()); 1494 EXPECT_EQ("12", c.DebugString(out)); 1495 EXPECT_TRUE(c.Multiply(d_6, 6, &out).ok()); 1496 EXPECT_EQ("36", c.DebugString(out)); 1497 1498 // Test multiplication using dimension as second value. 1499 EXPECT_TRUE(c.Multiply(d_6, c.MakeDim(2), &out).ok()); 1500 EXPECT_EQ("12", c.DebugString(out)); 1501 EXPECT_TRUE(c.Multiply(d_6, c.UnknownDim(), &out).ok()); 1502 EXPECT_EQ("?", c.DebugString(out)); 1503 } 1504 1505 TEST_F(ShapeInferenceTest, FullyDefined) { 1506 NodeDef def; 1507 std::vector<ShapeHandle> empty; 1508 InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}); 1509 1510 // No rank or missing dimension information should return false. 1511 EXPECT_FALSE(c.FullyDefined(c.UnknownShape())); 1512 EXPECT_FALSE(c.FullyDefined(c.Matrix(c.MakeDim(1), c.UnknownDim()))); 1513 1514 // Return true if all information exists. 1515 EXPECT_TRUE(c.FullyDefined(c.Matrix(c.MakeDim(1), c.MakeDim(2)))); 1516 EXPECT_TRUE(c.FullyDefined(c.Scalar())); 1517 } 1518 1519 TEST_F(ShapeInferenceTest, Min) { 1520 NodeDef def; 1521 InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, -1, 0})}, {}, 1522 {}, {}); 1523 1524 auto s = c.input(0); 1525 auto d_1 = c.Dim(s, 0); 1526 auto d_2 = c.Dim(s, 1); 1527 auto d_unknown = c.Dim(s, 2); 1528 auto d_0 = c.Dim(s, 3); 1529 1530 // Minimum involving zero and unknown returns zero. 1531 DimensionHandle out; 1532 EXPECT_TRUE(c.Min(d_0, d_unknown, &out).ok()); 1533 EXPECT_TRUE(SameHandle(d_0, out)); 1534 EXPECT_TRUE(c.Min(d_unknown, d_0, &out).ok()); 1535 EXPECT_TRUE(SameHandle(d_0, out)); 1536 EXPECT_TRUE(c.Min(c.MakeDim(0ll), d_unknown, &out).ok()); 1537 EXPECT_EQ("0", c.DebugString(out)); 1538 EXPECT_TRUE(c.Min(d_unknown, 0ll, &out).ok()); 1539 EXPECT_EQ("0", c.DebugString(out)); 1540 1541 // Minimum involving unknowns and non-zeros gives new unknown. 1542 EXPECT_TRUE(c.Min(d_unknown, d_unknown, &out).ok()); 1543 EXPECT_EQ("?", c.DebugString(out)); 1544 EXPECT_TRUE(c.Min(d_unknown, 1, &out).ok()); 1545 EXPECT_EQ("?", c.DebugString(out)); 1546 EXPECT_TRUE(c.Min(d_1, d_unknown, &out).ok()); 1547 EXPECT_EQ("?", c.DebugString(out)); 1548 1549 // Minimum with constant second arg. 1550 EXPECT_TRUE(c.Min(d_1, 1, &out).ok()); 1551 EXPECT_TRUE(SameHandle(d_1, out)); 1552 EXPECT_TRUE(c.Min(d_1, 3, &out).ok()); 1553 EXPECT_TRUE(SameHandle(d_1, out)); 1554 EXPECT_TRUE(c.Min(d_2, 1, &out).ok()); 1555 EXPECT_EQ("1", c.DebugString(out)); 1556 1557 // Minimum with two dimensions. 1558 EXPECT_TRUE(c.Min(d_1, d_1, &out).ok()); 1559 EXPECT_TRUE(SameHandle(d_1, out)); 1560 EXPECT_TRUE(c.Min(d_1, d_2, &out).ok()); 1561 EXPECT_TRUE(SameHandle(d_1, out)); 1562 EXPECT_TRUE(c.Min(d_2, d_1, &out).ok()); 1563 EXPECT_TRUE(SameHandle(d_1, out)); 1564 EXPECT_TRUE(c.Min(d_2, d_2, &out).ok()); 1565 EXPECT_TRUE(SameHandle(d_2, out)); 1566 } 1567 1568 TEST_F(ShapeInferenceTest, Max) { 1569 NodeDef def; 1570 InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, -1})}, {}, {}, 1571 {}); 1572 1573 auto s = c.input(0); 1574 auto d_1 = c.Dim(s, 0); 1575 auto d_2 = c.Dim(s, 1); 1576 auto d_unknown = c.Dim(s, 2); 1577 1578 // Maximum involving unknowns gives new unknown. 1579 DimensionHandle out; 1580 EXPECT_TRUE(c.Max(d_unknown, d_unknown, &out).ok()); 1581 EXPECT_EQ("?", c.DebugString(out)); 1582 EXPECT_TRUE(c.Max(d_unknown, 1, &out).ok()); 1583 EXPECT_EQ("?", c.DebugString(out)); 1584 EXPECT_TRUE(c.Max(d_1, d_unknown, &out).ok()); 1585 EXPECT_EQ("?", c.DebugString(out)); 1586 1587 // Maximum with constant second arg. 1588 EXPECT_TRUE(c.Max(d_1, 1, &out).ok()); 1589 EXPECT_TRUE(SameHandle(d_1, out)); 1590 EXPECT_TRUE(c.Max(d_2, 1, &out).ok()); 1591 EXPECT_TRUE(SameHandle(d_2, out)); 1592 EXPECT_TRUE(c.Max(d_2, 3, &out).ok()); 1593 EXPECT_EQ("3", c.DebugString(out)); 1594 1595 // Maximum with two dimensions. 1596 EXPECT_TRUE(c.Max(d_1, d_1, &out).ok()); 1597 EXPECT_TRUE(SameHandle(d_1, out)); 1598 EXPECT_TRUE(c.Max(d_1, d_2, &out).ok()); 1599 EXPECT_TRUE(SameHandle(d_2, out)); 1600 EXPECT_TRUE(c.Max(d_2, d_1, &out).ok()); 1601 EXPECT_TRUE(SameHandle(d_2, out)); 1602 EXPECT_TRUE(c.Max(d_2, d_2, &out).ok()); 1603 EXPECT_TRUE(SameHandle(d_2, out)); 1604 } 1605 1606 void ShapeInferenceTest::TestMergeHandles(bool input_not_output) { 1607 NodeDef def; 1608 InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({}), S({})}, {}, {}, 1609 {}); 1610 auto make_shape = [&c](std::initializer_list<int64> dim_sizes) { 1611 ShapeHandle s; 1612 TF_CHECK_OK(c.MakeShapeFromPartialTensorShape(S(dim_sizes), &s)); 1613 return s; 1614 }; 1615 auto get_shapes_and_types_from_context = [&](int idx) { 1616 if (input_not_output) { 1617 return c.input_handle_shapes_and_types(idx); 1618 } else { 1619 return c.output_handle_shapes_and_types(idx); 1620 } 1621 }; 1622 auto merge_shapes_and_types_to_context = 1623 [&](int idx, const std::vector<ShapeAndType>& shapes_and_types) { 1624 if (input_not_output) { 1625 return c.MergeInputHandleShapesAndTypes(idx, shapes_and_types); 1626 } else { 1627 return c.MergeOutputHandleShapesAndTypes(idx, shapes_and_types); 1628 } 1629 }; 1630 1631 EXPECT_TRUE(get_shapes_and_types_from_context(0) == nullptr); 1632 EXPECT_TRUE(get_shapes_and_types_from_context(1) == nullptr); 1633 1634 // First merge will take the input completely. 1635 std::vector<ShapeAndType> t{{make_shape({1, 2, 3}), DT_FLOAT}, 1636 {c.UnknownShape(), DT_INVALID}, 1637 {make_shape({4, 3, 2, 1}), DT_INT32}}; 1638 ASSERT_TRUE(merge_shapes_and_types_to_context(0, t)); 1639 ASSERT_TRUE(get_shapes_and_types_from_context(0) != nullptr); 1640 std::vector<ShapeAndType> v = *get_shapes_and_types_from_context(0); 1641 ASSERT_EQ(3, v.size()); 1642 for (int i = 0; i < v.size(); ++i) { 1643 EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i; 1644 EXPECT_EQ(t[i].dtype, v[i].dtype); 1645 } 1646 1647 // Merge that fails because wrong number of values passed. 1648 // Fails, and no changes made. 1649 ASSERT_FALSE(merge_shapes_and_types_to_context( 1650 0, std::vector<ShapeAndType>{{make_shape({1, 2, 3}), DT_FLOAT}})); 1651 v = *get_shapes_and_types_from_context(0); 1652 ASSERT_EQ(3, v.size()); 1653 for (int i = 0; i < v.size(); ++i) { 1654 EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i; 1655 EXPECT_EQ(t[i].dtype, v[i].dtype); 1656 } 1657 1658 // Only difference is in a mismatched shape. That is ignored, 1659 // and there are no other changes, so nothing is done. 1660 // 1661 // TODO(cwhipkey): in mismatch cases, change Merge*HandleShapesAndTypes to 1662 // return an error (separate error from 'refined' output)? 1663 auto t2 = t; 1664 t2[2].shape = make_shape({4, 3, 4, 1}); 1665 ASSERT_FALSE(merge_shapes_and_types_to_context(0, t2)); 1666 v = *get_shapes_and_types_from_context(0); 1667 ASSERT_EQ(3, v.size()); 1668 for (int i = 0; i < v.size(); ++i) { 1669 EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i; 1670 EXPECT_EQ(t[i].dtype, v[i].dtype); 1671 } 1672 1673 // Only difference is in a mismatched dtype, but that cannot be 1674 // updated unless original dtype is DT_INVALID. 1675 t2 = t; 1676 t2[2].dtype = DT_FLOAT; 1677 ASSERT_FALSE(merge_shapes_and_types_to_context(0, t2)); 1678 v = *get_shapes_and_types_from_context(0); 1679 ASSERT_EQ(3, v.size()); 1680 for (int i = 0; i < v.size(); ++i) { 1681 EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i; 1682 EXPECT_EQ(t[i].dtype, v[i].dtype); 1683 } 1684 1685 // Difference is mergeable (new shape). 1686 t[1].shape = make_shape({1, 10}); 1687 ASSERT_TRUE(merge_shapes_and_types_to_context(0, t)); 1688 v = *get_shapes_and_types_from_context(0); 1689 ASSERT_EQ(3, v.size()); 1690 for (int i = 0; i < v.size(); ++i) { 1691 EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i; 1692 EXPECT_EQ(t[i].dtype, v[i].dtype); 1693 } 1694 1695 // Difference is mergeable (new type). 1696 t[1].dtype = DT_DOUBLE; 1697 ASSERT_TRUE(merge_shapes_and_types_to_context(0, t)); 1698 v = *get_shapes_and_types_from_context(0); 1699 ASSERT_EQ(3, v.size()); 1700 for (int i = 0; i < v.size(); ++i) { 1701 EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i; 1702 EXPECT_EQ(t[i].dtype, v[i].dtype); 1703 } 1704 1705 // No difference. 1706 ASSERT_FALSE(merge_shapes_and_types_to_context(0, t)); 1707 } 1708 1709 TEST_F(ShapeInferenceTest, MergeInputHandleShapesAndTypes) { 1710 TestMergeHandles(true /* input_not_output */); 1711 } 1712 1713 TEST_F(ShapeInferenceTest, MergeOutputHandleShapesAndTypes) { 1714 TestMergeHandles(false /* input_not_output */); 1715 } 1716 1717 void ShapeInferenceTest::TestRelaxHandles(bool input_not_output) { 1718 NodeDef def; 1719 InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({}), S({})}, {}, {}, 1720 {}); 1721 auto make_shape = [&c](std::initializer_list<int64> dim_sizes) { 1722 ShapeHandle s; 1723 TF_CHECK_OK(c.MakeShapeFromPartialTensorShape(S(dim_sizes), &s)); 1724 return s; 1725 }; 1726 auto get_shapes_and_types_from_context = [&](int idx) { 1727 if (input_not_output) { 1728 return c.input_handle_shapes_and_types(idx); 1729 } else { 1730 return c.output_handle_shapes_and_types(idx); 1731 } 1732 }; 1733 auto relax_shapes_and_types_to_context = 1734 [&](int idx, const std::vector<ShapeAndType>& shapes_and_types) { 1735 if (input_not_output) { 1736 return c.RelaxInputHandleShapesAndMergeTypes(idx, shapes_and_types); 1737 } else { 1738 return c.RelaxOutputHandleShapesAndMergeTypes(idx, shapes_and_types); 1739 } 1740 }; 1741 1742 EXPECT_TRUE(get_shapes_and_types_from_context(0) == nullptr); 1743 EXPECT_TRUE(get_shapes_and_types_from_context(1) == nullptr); 1744 1745 // First relax will take the input completely. 1746 std::vector<ShapeAndType> t{{make_shape({1, 2, 3}), DT_FLOAT}, 1747 {c.UnknownShape(), DT_INVALID}, 1748 {make_shape({4, 3, 2, 1}), DT_INT32}}; 1749 ASSERT_TRUE(relax_shapes_and_types_to_context(0, t)); 1750 ASSERT_TRUE(get_shapes_and_types_from_context(0) != nullptr); 1751 std::vector<ShapeAndType> v = *get_shapes_and_types_from_context(0); 1752 ASSERT_EQ(3, v.size()); 1753 for (int i = 0; i < v.size(); ++i) { 1754 EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i; 1755 EXPECT_EQ(t[i].dtype, v[i].dtype); 1756 } 1757 1758 // Relax that fails because wrong number of values passed. 1759 // Fails, and no changes made. 1760 ASSERT_FALSE(relax_shapes_and_types_to_context( 1761 0, std::vector<ShapeAndType>{{make_shape({1, 2, 3}), DT_FLOAT}})); 1762 v = *get_shapes_and_types_from_context(0); 1763 ASSERT_EQ(3, v.size()); 1764 for (int i = 0; i < v.size(); ++i) { 1765 EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i; 1766 EXPECT_EQ(t[i].dtype, v[i].dtype); 1767 } 1768 1769 // Only difference is in a mismatched shape. This should replace 1770 // the mismatched dimension with an UnknownDim. 1771 auto t2 = t; 1772 t2[2].shape = make_shape({4, 3, 4, 1}); 1773 ASSERT_TRUE(relax_shapes_and_types_to_context(0, t2)); 1774 v = *get_shapes_and_types_from_context(0); 1775 EXPECT_EQ("[4,3,?,1]", c.DebugString(v[2].shape)); 1776 for (int i = 0; i < v.size(); ++i) { 1777 EXPECT_EQ(t[i].dtype, v[i].dtype); 1778 } 1779 1780 // Only difference is in a mismatched dtype, but that cannot be 1781 // updated unless original dtype is DT_INVALID. 1782 t2 = t; 1783 t2[2].dtype = DT_FLOAT; 1784 ASSERT_FALSE(relax_shapes_and_types_to_context(0, t2)); 1785 v = *get_shapes_and_types_from_context(0); 1786 ASSERT_EQ(3, v.size()); 1787 for (int i = 0; i < v.size(); ++i) { 1788 EXPECT_EQ(t[i].dtype, v[i].dtype); 1789 } 1790 1791 // Difference is a new shape, which will result in a new UnknownShape. 1792 t[1].shape = make_shape({1, 10}); 1793 ASSERT_TRUE(relax_shapes_and_types_to_context(0, t)); 1794 v = *get_shapes_and_types_from_context(0); 1795 ASSERT_EQ(3, v.size()); 1796 EXPECT_FALSE(SameHandle(t[1].shape, v[1].shape)); 1797 EXPECT_EQ("?", c.DebugString(v[1].shape)); 1798 for (int i = 0; i < v.size(); ++i) { 1799 EXPECT_EQ(t[i].dtype, v[i].dtype); 1800 } 1801 1802 // Difference is relaxable (new type). 1803 t[1].dtype = DT_DOUBLE; 1804 ASSERT_TRUE(relax_shapes_and_types_to_context(0, t)); 1805 v = *get_shapes_and_types_from_context(0); 1806 EXPECT_EQ(t[1].dtype, v[1].dtype); 1807 } 1808 1809 TEST_F(ShapeInferenceTest, RelaxInputHandleShapesAndTypes) { 1810 TestRelaxHandles(true /* input_not_output */); 1811 } 1812 1813 TEST_F(ShapeInferenceTest, RelaxOutputHandleShapesAndTypes) { 1814 TestRelaxHandles(false /* input_not_output */); 1815 } 1816 1817 } // namespace shape_inference 1818 } // namespace tensorflow 1819