Home | History | Annotate | Download | only in framework
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #include "tensorflow/core/framework/shape_inference.h"
     16 
     17 #include "tensorflow/core/framework/fake_input.h"
     18 #include "tensorflow/core/framework/node_def_builder.h"
     19 #include "tensorflow/core/framework/op_def_builder.h"
     20 #include "tensorflow/core/framework/tensor_shape.pb.h"
     21 #include "tensorflow/core/framework/tensor_testutil.h"
     22 #include "tensorflow/core/framework/types.pb.h"
     23 #include "tensorflow/core/lib/core/status_test_util.h"
     24 #include "tensorflow/core/lib/strings/strcat.h"
     25 #include "tensorflow/core/platform/test.h"
     26 
     27 namespace tensorflow {
     28 namespace shape_inference {
     29 namespace {
     30 
     31 OpDef MakeOpDefWithLists() {
     32   OpRegistrationData op_reg_data;
     33   OpDefBuilder b("dummy");
     34   b.Input(strings::StrCat("input: N * float"));
     35   b.Output(strings::StrCat("output: N * float"));
     36   CHECK(b.Attr("N:int >= 1").Finalize(&op_reg_data).ok());
     37   return op_reg_data.op_def;
     38 }
     39 
     40 PartialTensorShape S(std::initializer_list<int64> dims) {
     41   return PartialTensorShape(dims);
     42 }
     43 
     44 PartialTensorShape Unknown() { return PartialTensorShape(); }
     45 
     46 }  // namespace
     47 
     48 class ShapeInferenceTest : public ::testing::Test {
     49  protected:
     50   // These give access to private functions of DimensionHandle and ShapeHandle.
     51   bool SameHandle(DimensionHandle a, DimensionHandle b) {
     52     return a.SameHandle(b);
     53   }
     54   bool SameHandle(ShapeHandle a, ShapeHandle b) { return a.SameHandle(b); }
     55   bool IsSet(DimensionHandle d) { return d.IsSet(); }
     56   bool IsSet(ShapeHandle s) { return s.IsSet(); }
     57   void Relax(InferenceContext* c, DimensionHandle d0, DimensionHandle d1,
     58              DimensionHandle* out) {
     59     c->Relax(d0, d1, out);
     60   }
     61   void Relax(InferenceContext* c, ShapeHandle s0, ShapeHandle s1,
     62              ShapeHandle* out) {
     63     c->Relax(s0, s1, out);
     64   }
     65   void TestMergeHandles(bool input_not_output);
     66   void TestRelaxHandles(bool input_not_output);
     67 
     68   static const int kVersion = 0;  // used for graph-def version.
     69 };
     70 
     71 TEST_F(ShapeInferenceTest, InputOutputByName) {
     72   // Setup test to contain an input tensor list of size 3.
     73   OpDef op_def = MakeOpDefWithLists();
     74   NodeDef def;
     75   auto s = NodeDefBuilder("dummy", &op_def)
     76                .Attr("N", 3)
     77                .Input(FakeInput(DT_FLOAT))
     78                .Finalize(&def);
     79   InferenceContext c(kVersion, &def, op_def, {S({1, 5}), S({2, 5}), S({1, 3})},
     80                      {}, {}, {});
     81 
     82   EXPECT_EQ("5", c.DebugString(c.NumElements(c.input(0))));
     83   EXPECT_EQ("10", c.DebugString(c.NumElements(c.input(1))));
     84   EXPECT_EQ("3", c.DebugString(c.NumElements(c.input(2))));
     85   // Test getters.
     86   std::vector<ShapeHandle> shapes;
     87   EXPECT_FALSE(c.input("nonexistent", &shapes).ok());
     88   TF_EXPECT_OK(c.input("input", &shapes));
     89   EXPECT_EQ("[1,5]", c.DebugString(shapes[0]));
     90   EXPECT_EQ("[2,5]", c.DebugString(shapes[1]));
     91   EXPECT_EQ("[1,3]", c.DebugString(shapes[2]));
     92 
     93   // Test setters.
     94   EXPECT_FALSE(c.set_output("nonexistent", shapes).ok());
     95   TF_EXPECT_OK(c.set_output("output", shapes));
     96   EXPECT_EQ("5", c.DebugString(c.NumElements(c.output(0))));
     97   EXPECT_EQ("10", c.DebugString(c.NumElements(c.output(1))));
     98   EXPECT_EQ("3", c.DebugString(c.NumElements(c.output(2))));
     99 }
    100 
    101 static OpDef MakeOpDef(int num_inputs, int num_outputs) {
    102   OpRegistrationData op_reg_data;
    103   OpDefBuilder b("dummy");
    104   for (int i = 0; i < num_inputs; ++i) {
    105     b.Input(strings::StrCat("i", i, ": float"));
    106   }
    107   for (int i = 0; i < num_outputs; ++i) {
    108     b.Output(strings::StrCat("o", i, ": float"));
    109   }
    110   CHECK(b.Attr("foo:string").Finalize(&op_reg_data).ok());
    111   return op_reg_data.op_def;
    112 }
    113 
    114 TEST_F(ShapeInferenceTest, DimensionOrConstant) {
    115   NodeDef def;
    116   InferenceContext c(kVersion, &def, MakeOpDef(1, 1), {Unknown()}, {}, {}, {});
    117   EXPECT_EQ(InferenceContext::kUnknownDim,
    118             c.Value(InferenceContext::kUnknownDim));
    119   EXPECT_EQ(1, c.Value(1));
    120 
    121 #ifndef NDEBUG
    122   // Only run death test if DCHECKS are enabled.
    123   EXPECT_DEATH(c.Value(-7), "Dimension must be non\\-negative or equal to");
    124 #endif
    125 }
    126 
    127 TEST_F(ShapeInferenceTest, Run) {
    128   NodeDef def;
    129   def.set_name("foo");
    130   def.set_op("foo_op");
    131   InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1})}, {}, {}, {});
    132   TF_ASSERT_OK(c.construction_status());
    133 
    134   {
    135     auto fn = [](InferenceContext* c) {
    136       ShapeHandle h;
    137       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 6, &h));
    138       c->set_output(0, c->input(0));
    139       c->set_output(1, c->input(0));
    140       return Status::OK();
    141     };
    142     TF_ASSERT_OK(c.Run(fn));
    143   }
    144 
    145   {
    146     auto fn = [](InferenceContext* c) {
    147       ShapeHandle h;
    148       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &h));
    149       c->set_output(0, c->input(0));
    150       c->set_output(1, c->input(0));
    151       return Status::OK();
    152     };
    153     Status s = c.Run(fn);
    154     // Extra error message is attached when Run fails.
    155     EXPECT_TRUE(StringPiece(s.ToString())
    156                     .contains("Shape must be at most rank 0 but "
    157                               "is rank 1 for 'foo' (op: "
    158                               "'foo_op')"))
    159         << s;
    160   }
    161 }
    162 
    163 // Tests different context data added when Run returns error.
    164 TEST_F(ShapeInferenceTest, AttachContext) {
    165   NodeDef def;
    166   def.set_name("foo");
    167   def.set_op("foo_op");
    168   // Error when no constant tensors were requested.
    169   {
    170     InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, 3})}, {}, {},
    171                        {});
    172     TF_ASSERT_OK(c.construction_status());
    173     auto fn = [](InferenceContext* c) {
    174       ShapeHandle h;
    175       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &h));
    176       c->set_output(0, c->input(0));
    177       return Status::OK();
    178     };
    179     EXPECT_EQ(
    180         "Invalid argument: Shape must be at most rank 0 but is rank 3 for "
    181         "'foo' (op: 'foo_op') with input shapes: [1,2,3].",
    182         c.Run(fn).ToString());
    183   }
    184 
    185   // Error when a constant tensor value was requested.
    186   {
    187     Tensor input_t =
    188         ::tensorflow::test::AsTensor<float>({1.1, 2.2, 3.3, 4.4, 5.5});
    189     InferenceContext c(kVersion, &def, MakeOpDef(2, 2),
    190                        {S({1, 2, 3}), S({4, 5})}, {nullptr, &input_t}, {}, {});
    191     TF_ASSERT_OK(c.construction_status());
    192     auto fn = [](InferenceContext* c) {
    193       c->input_tensor(0);  // get this one, but it's null - won't be in error.
    194       c->input_tensor(1);  // get this one, will now be in error.
    195       ShapeHandle h;
    196       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &h));
    197       c->set_output(0, c->input(0));
    198       return Status::OK();
    199     };
    200     EXPECT_EQ(
    201         "Invalid argument: Shape must be at most rank 0 but is rank 3 for "
    202         "'foo' (op: 'foo_op') with input shapes: [1,2,3], [4,5] and with "
    203         "computed input tensors: input[1] = <1.1 2.2 3.3 4.4 5.5>.",
    204         c.Run(fn).ToString());
    205   }
    206 
    207   // Error when a constant tensor value as shape was requested, but no partial
    208   // shapes provided.
    209   {
    210     Tensor input_t = ::tensorflow::test::AsTensor<int32>({1, 2, 3, 4, 5});
    211     InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({3}), S({4})},
    212                        {nullptr, &input_t}, {}, {});
    213     TF_ASSERT_OK(c.construction_status());
    214     auto fn = [](InferenceContext* c) {
    215       ShapeHandle s;
    216       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
    217       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
    218       ShapeHandle h;
    219       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &h));
    220       c->set_output(0, c->input(0));
    221       return Status::OK();
    222     };
    223     EXPECT_EQ(
    224         "Invalid argument: Shape must be at most rank 0 but is rank 1 for "
    225         "'foo' (op: 'foo_op') with input shapes: [3], [4] and with computed "
    226         "input tensors: input[1] = <1 2 3 4 5>.",
    227         c.Run(fn).ToString());
    228   }
    229 
    230   // Error when a constant tensor value as shape was requested, and a partial
    231   // shape was provided.
    232   {
    233     Tensor input_t = ::tensorflow::test::AsTensor<int32>({1, 2, 3, 4, 5});
    234     InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({3}), S({4})},
    235                        {nullptr, &input_t}, {S({10, -1, 5}), Unknown()}, {});
    236     TF_ASSERT_OK(c.construction_status());
    237     auto fn = [](InferenceContext* c) {
    238       ShapeHandle s;
    239       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
    240       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
    241       ShapeHandle h;
    242       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &h));
    243       c->set_output(0, c->input(0));
    244       return Status::OK();
    245     };
    246     EXPECT_EQ(
    247         "Invalid argument: Shape must be at most rank 0 but is rank 1 for "
    248         "'foo' (op: 'foo_op') with input shapes: [3], [4] and with computed "
    249         "input tensors: input[1] = <1 2 3 4 5> and with input tensors computed "
    250         "as partial shapes: input[0] = [10,?,5].",
    251         c.Run(fn).ToString());
    252   }
    253 }
    254 
    255 TEST_F(ShapeInferenceTest, RankAndDimInspection) {
    256   NodeDef def;
    257   InferenceContext c(kVersion, &def, MakeOpDef(3, 2),
    258                      {Unknown(), S({1, -1, 3}), S({})}, {}, {}, {});
    259   EXPECT_EQ(3, c.num_inputs());
    260   EXPECT_EQ(2, c.num_outputs());
    261 
    262   auto in0 = c.input(0);
    263   EXPECT_EQ("?", c.DebugString(in0));
    264   EXPECT_FALSE(c.RankKnown(in0));
    265   EXPECT_EQ(InferenceContext::kUnknownRank, c.Rank(in0));
    266   EXPECT_EQ("?", c.DebugString(c.Dim(in0, 0)));
    267   EXPECT_EQ("?", c.DebugString(c.Dim(in0, -1)));
    268   EXPECT_EQ("?", c.DebugString(c.Dim(in0, 1000)));
    269 
    270   auto in1 = c.input(1);
    271   EXPECT_EQ("[1,?,3]", c.DebugString(in1));
    272   EXPECT_TRUE(c.RankKnown(in1));
    273   EXPECT_EQ(3, c.Rank(in1));
    274   auto d = c.Dim(in1, 0);
    275   EXPECT_EQ(1, c.Value(d));
    276   EXPECT_TRUE(SameHandle(d, c.Dim(in1, -3)));
    277   EXPECT_TRUE(c.ValueKnown(d));
    278   EXPECT_EQ("1", c.DebugString(d));
    279   d = c.Dim(in1, 1);
    280   EXPECT_EQ(InferenceContext::kUnknownDim, c.Value(d));
    281   EXPECT_FALSE(c.ValueKnown(d));
    282   EXPECT_TRUE(SameHandle(d, c.Dim(in1, -2)));
    283   EXPECT_EQ("?", c.DebugString(d));
    284   d = c.Dim(in1, 2);
    285   EXPECT_EQ(3, c.Value(d));
    286   EXPECT_TRUE(SameHandle(d, c.Dim(in1, -1)));
    287   EXPECT_TRUE(c.ValueKnown(d));
    288   EXPECT_EQ("3", c.DebugString(d));
    289 
    290   auto in2 = c.input(2);
    291   EXPECT_EQ("[]", c.DebugString(in2));
    292   EXPECT_TRUE(c.RankKnown(in2));
    293   EXPECT_EQ(0, c.Rank(in2));
    294 }
    295 
    296 TEST_F(ShapeInferenceTest, NumElements) {
    297   NodeDef def;
    298   InferenceContext c(kVersion, &def, MakeOpDef(3, 2),
    299                      {Unknown(), S({1, -1, 3}), S({5, 4, 3, 2})}, {}, {}, {});
    300 
    301   EXPECT_EQ("?", c.DebugString(c.NumElements(c.input(0))));
    302   EXPECT_EQ("?", c.DebugString(c.NumElements(c.input(1))));
    303 
    304   // Different handles (not the same unknown value).
    305   EXPECT_FALSE(SameHandle(c.Dim(c.input(1), 1), c.NumElements(c.input(1))));
    306 
    307   EXPECT_EQ("120", c.DebugString(c.NumElements(c.input(2))));
    308 }
    309 
    310 TEST_F(ShapeInferenceTest, WithRank) {
    311   NodeDef def;
    312   InferenceContext c(kVersion, &def, MakeOpDef(2, 2),
    313                      {Unknown(), S({1, -1, 3})}, {}, {}, {});
    314 
    315   auto in0 = c.input(0);
    316   auto in1 = c.input(1);
    317   ShapeHandle s1;
    318   ShapeHandle s2;
    319 
    320   // WithRank on a shape with unknown dimensionality always succeeds.
    321   EXPECT_TRUE(c.WithRank(in0, 1, &s1).ok());
    322   EXPECT_EQ("[?]", c.DebugString(s1));
    323 
    324   EXPECT_TRUE(c.WithRank(in0, 2, &s2).ok());
    325   EXPECT_EQ("[?,?]", c.DebugString(s2));
    326   EXPECT_FALSE(SameHandle(s1, s2));
    327   EXPECT_FALSE(SameHandle(c.Dim(s2, 0), c.Dim(s2, 1)));
    328 
    329   EXPECT_TRUE(c.WithRank(in0, 1, &s2).ok());
    330   EXPECT_EQ("[?]", c.DebugString(s2));
    331   EXPECT_FALSE(SameHandle(s1, s2));
    332 
    333   EXPECT_TRUE(c.WithRank(in0, 0, &s1).ok());
    334   EXPECT_EQ("[]", c.DebugString(s1));
    335 
    336   // WithRank on shape with known dimensionality.
    337   s1 = in1;
    338   EXPECT_EQ("Invalid argument: Shape must be rank 2 but is rank 3",
    339             c.WithRank(in1, 2, &s1).ToString());
    340   EXPECT_FALSE(IsSet(s1));
    341   EXPECT_TRUE(c.WithRank(in1, 3, &s1).ok());
    342   EXPECT_TRUE(SameHandle(s1, in1));
    343 
    344   // Inputs are unchanged.
    345   EXPECT_EQ("?", c.DebugString(in0));
    346   EXPECT_EQ("[1,?,3]", c.DebugString(in1));
    347 }
    348 
    349 TEST_F(ShapeInferenceTest, WithRankAtMost) {
    350   NodeDef def;
    351   InferenceContext c(kVersion, &def, MakeOpDef(2, 2),
    352                      {Unknown(), S({1, -1, 3})}, {}, {}, {});
    353 
    354   auto in0 = c.input(0);
    355   auto in1 = c.input(1);
    356   ShapeHandle s1;
    357   ShapeHandle s2;
    358 
    359   // WithRankAtMost on a shape with unknown dimensionality always succeeds.
    360   EXPECT_TRUE(c.WithRankAtMost(in0, 1, &s1).ok());
    361   EXPECT_EQ("?", c.DebugString(s1));
    362   EXPECT_TRUE(SameHandle(in0, s1));
    363 
    364   EXPECT_TRUE(c.WithRankAtMost(in0, 2, &s2).ok());
    365   EXPECT_EQ("?", c.DebugString(s2));
    366   EXPECT_TRUE(SameHandle(s1, s2));
    367 
    368   // WithRankAtMost on shape with known dimensionality.
    369   s1 = in1;
    370   EXPECT_TRUE(
    371       StringPiece(c.WithRankAtMost(in1, 2, &s1).ToString())
    372           .contains(
    373               "Invalid argument: Shape must be at most rank 2 but is rank 3"));
    374 
    375   EXPECT_FALSE(IsSet(s1));
    376   EXPECT_TRUE(c.WithRankAtMost(in1, 3, &s1).ok());
    377   EXPECT_TRUE(SameHandle(s1, in1));
    378   EXPECT_TRUE(c.WithRankAtMost(in1, 4, &s1).ok());
    379   EXPECT_TRUE(SameHandle(s1, in1));
    380   EXPECT_TRUE(c.WithRankAtMost(in1, 5, &s1).ok());
    381   EXPECT_TRUE(SameHandle(s1, in1));
    382 
    383   // Inputs are unchanged.
    384   EXPECT_EQ("?", c.DebugString(in0));
    385   EXPECT_EQ("[1,?,3]", c.DebugString(in1));
    386 }
    387 
    388 TEST_F(ShapeInferenceTest, WithRankAtLeast) {
    389   NodeDef def;
    390   InferenceContext c(kVersion, &def, MakeOpDef(2, 2),
    391                      {Unknown(), S({1, -1, 3})}, {}, {}, {});
    392 
    393   auto in0 = c.input(0);
    394   auto in1 = c.input(1);
    395   ShapeHandle s1;
    396   ShapeHandle s2;
    397 
    398   // WithRankAtLeast on a shape with unknown dimensionality always succeeds.
    399   EXPECT_TRUE(c.WithRankAtLeast(in0, 1, &s1).ok());
    400   EXPECT_EQ("?", c.DebugString(s1));
    401   EXPECT_TRUE(SameHandle(in0, s1));
    402 
    403   EXPECT_TRUE(c.WithRankAtLeast(in0, 2, &s2).ok());
    404   EXPECT_EQ("?", c.DebugString(s2));
    405   EXPECT_TRUE(SameHandle(s1, s2));
    406 
    407   // WithRankAtLeast on shape with known dimensionality.
    408   s1 = in1;
    409   EXPECT_TRUE(
    410       StringPiece(c.WithRankAtLeast(in1, 4, &s1).ToString())
    411           .contains(
    412               "Invalid argument: Shape must be at least rank 4 but is rank 3"));
    413 
    414   EXPECT_FALSE(IsSet(s1));
    415   EXPECT_TRUE(c.WithRankAtLeast(in1, 3, &s1).ok());
    416   EXPECT_TRUE(SameHandle(s1, in1));
    417   EXPECT_TRUE(c.WithRankAtLeast(in1, 2, &s1).ok());
    418   EXPECT_TRUE(SameHandle(s1, in1));
    419   EXPECT_TRUE(c.WithRankAtLeast(in1, 0, &s1).ok());
    420   EXPECT_TRUE(SameHandle(s1, in1));
    421 
    422   // Inputs are unchanged.
    423   EXPECT_EQ("?", c.DebugString(in0));
    424   EXPECT_EQ("[1,?,3]", c.DebugString(in1));
    425 }
    426 
    427 TEST_F(ShapeInferenceTest, WithValue) {
    428   NodeDef def;
    429   InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, -1})}, {}, {}, {});
    430 
    431   auto d0 = c.Dim(c.input(0), 0);
    432   auto d1 = c.Dim(c.input(0), 1);
    433   DimensionHandle out1;
    434   DimensionHandle out2;
    435 
    436   // WithValue on a dimension with unknown value always succeeds.
    437   EXPECT_TRUE(c.WithValue(d1, 1, &out1).ok());
    438   EXPECT_EQ(1, c.Value(out1));
    439 
    440   EXPECT_TRUE(c.WithValue(d1, 2, &out2).ok());
    441   EXPECT_EQ(2, c.Value(out2));
    442   EXPECT_FALSE(SameHandle(out1, out2));
    443   EXPECT_FALSE(SameHandle(out1, d1));
    444 
    445   EXPECT_TRUE(c.WithValue(d1, 1, &out2).ok());
    446   EXPECT_EQ(1, c.Value(out2));
    447   EXPECT_FALSE(SameHandle(out1, out2));
    448 
    449   // WithValue on dimension with known size.
    450   out1 = d0;
    451 
    452   EXPECT_TRUE(StringPiece(c.WithValue(d0, 0, &out1).ToString())
    453                   .contains("Invalid argument: Dimension must be 0 but is 1"));
    454   EXPECT_FALSE(IsSet(out1));
    455   out1 = d0;
    456   EXPECT_TRUE(StringPiece(c.WithValue(d0, 2, &out1).ToString())
    457                   .contains("Invalid argument: Dimension must be 2 but is 1"));
    458 
    459   EXPECT_FALSE(IsSet(out1));
    460   EXPECT_TRUE(c.WithValue(d0, 1, &out1).ok());
    461   EXPECT_TRUE(SameHandle(d0, out1));
    462 
    463   // Inputs are unchanged.
    464   EXPECT_EQ("1", c.DebugString(d0));
    465   EXPECT_EQ("?", c.DebugString(d1));
    466 }
    467 
    468 TEST_F(ShapeInferenceTest, MergeDim) {
    469   NodeDef def;
    470   InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({2, -1, 2, 1, -1})},
    471                      {}, {}, {});
    472 
    473   auto d2 = c.Dim(c.input(0), 0);
    474   auto d_unknown = c.Dim(c.input(0), 1);
    475   auto d2_b = c.Dim(c.input(0), 2);
    476   auto d1 = c.Dim(c.input(0), 3);
    477   auto d_unknown_b = c.Dim(c.input(0), 4);
    478   DimensionHandle out;
    479 
    480   // Merging anything with unknown returns the same pointer.
    481   EXPECT_TRUE(c.Merge(d2, d_unknown, &out).ok());
    482   EXPECT_TRUE(SameHandle(d2, out));
    483   EXPECT_TRUE(c.Merge(d_unknown, d2, &out).ok());
    484   EXPECT_TRUE(SameHandle(d2, out));
    485   EXPECT_TRUE(c.Merge(d_unknown, d_unknown_b, &out).ok());
    486   EXPECT_TRUE(SameHandle(d_unknown, out));
    487 
    488   auto merged_dims = c.MergedDims();
    489   ASSERT_EQ(3, merged_dims.size());
    490   EXPECT_TRUE(merged_dims[0].first.SameHandle(d2));
    491   EXPECT_TRUE(merged_dims[0].second.SameHandle(d_unknown));
    492   EXPECT_TRUE(merged_dims[1].first.SameHandle(d_unknown));
    493   EXPECT_TRUE(merged_dims[1].second.SameHandle(d2));
    494   EXPECT_TRUE(merged_dims[2].first.SameHandle(d_unknown));
    495   EXPECT_TRUE(merged_dims[2].second.SameHandle(d_unknown_b));
    496 
    497   // Merging with self is a no-op and returns self.
    498   EXPECT_TRUE(c.Merge(d2, d2, &out).ok());
    499   EXPECT_TRUE(SameHandle(d2, out));
    500   EXPECT_TRUE(c.Merge(d_unknown, d_unknown, &out).ok());
    501   EXPECT_TRUE(SameHandle(d_unknown, out));
    502 
    503   merged_dims = c.MergedDims();
    504   EXPECT_EQ(3, merged_dims.size());
    505 
    506   // Merging equal values is a no op and returns first one.
    507   EXPECT_TRUE(c.Merge(d2, d2_b, &out).ok());
    508   EXPECT_TRUE(SameHandle(d2, out));
    509   EXPECT_TRUE(c.Merge(d2_b, d2, &out).ok());
    510   EXPECT_TRUE(SameHandle(d2_b, out));
    511 
    512   merged_dims = c.MergedDims();
    513   EXPECT_EQ(3, merged_dims.size());
    514 
    515   // Merging unequal values is an error.
    516   EXPECT_TRUE(
    517       StringPiece(c.Merge(d2, d1, &out).ToString())
    518           .contains(
    519               "Invalid argument: Dimensions must be equal, but are 2 and 1"));
    520 
    521   EXPECT_FALSE(IsSet(out));
    522   EXPECT_TRUE(
    523       StringPiece(c.Merge(d1, d2, &out).ToString())
    524           .contains(
    525               "Invalid argument: Dimensions must be equal, but are 1 and 2"));
    526 
    527   EXPECT_FALSE(IsSet(out));
    528 
    529   merged_dims = c.MergedDims();
    530   EXPECT_EQ(3, merged_dims.size());
    531 }
    532 
    533 TEST_F(ShapeInferenceTest, RelaxDim) {
    534   NodeDef def;
    535   InferenceContext c(kVersion, &def, MakeOpDef(1, 2),
    536                      {S({2, InferenceContext::kUnknownDim, 2, 1,
    537                          InferenceContext::kUnknownDim})},
    538                      {}, {}, {});
    539 
    540   auto d2 = c.Dim(c.input(0), 0);
    541   auto d_unknown = c.Dim(c.input(0), 1);
    542   auto d2_b = c.Dim(c.input(0), 2);
    543   auto d1 = c.Dim(c.input(0), 3);
    544   auto d_unknown_b = c.Dim(c.input(0), 4);
    545   DimensionHandle out;
    546 
    547   // Relaxing anything with unknown returns a new unknown or the existing
    548   // unknown.
    549   Relax(&c, d2, d_unknown, &out);
    550   EXPECT_TRUE(SameHandle(d_unknown, out));
    551   EXPECT_FALSE(SameHandle(d_unknown_b, out));
    552   EXPECT_EQ(InferenceContext::kUnknownDim, c.Value(out));
    553   Relax(&c, d_unknown, d2, &out);
    554   EXPECT_FALSE(SameHandle(d_unknown, out));
    555   EXPECT_EQ(InferenceContext::kUnknownDim, c.Value(out));
    556   Relax(&c, d_unknown, d_unknown_b, &out);
    557   EXPECT_FALSE(SameHandle(d_unknown, out));
    558   EXPECT_TRUE(SameHandle(d_unknown_b, out));
    559   EXPECT_EQ(InferenceContext::kUnknownDim, c.Value(out));
    560 
    561   // Relaxing with self returns self.
    562   Relax(&c, d2, d2, &out);
    563   EXPECT_TRUE(SameHandle(d2, out));
    564   Relax(&c, d_unknown, d_unknown, &out);
    565   EXPECT_TRUE(SameHandle(d_unknown, out));
    566 
    567   // Relaxing equal values returns first one.
    568   Relax(&c, d2, d2_b, &out);
    569   EXPECT_TRUE(SameHandle(d2, out));
    570   Relax(&c, d2_b, d2, &out);
    571   EXPECT_TRUE(SameHandle(d2_b, out));
    572 
    573   // Relaxing unequal values returns a new unknown.
    574   Relax(&c, d2, d1, &out);
    575   EXPECT_EQ(InferenceContext::kUnknownDim, c.Value(out));
    576   Relax(&c, d1, d2, &out);
    577   EXPECT_EQ(InferenceContext::kUnknownDim, c.Value(out));
    578 }
    579 
    580 TEST_F(ShapeInferenceTest, RelaxShape) {
    581   NodeDef def;
    582   InferenceContext c(
    583       kVersion, &def, MakeOpDef(7, 2),
    584       {Unknown(), S({1, 2}), S({InferenceContext::kUnknownDim, 2}),
    585        S({1, InferenceContext::kUnknownDim}), S({1, 3}), Unknown(), S({1})},
    586       {}, {}, {});
    587 
    588   auto s_unknown = c.input(0);
    589   auto s_1_2 = c.input(1);
    590   auto s_u_2 = c.input(2);
    591   auto s_1_u = c.input(3);
    592   auto s_1_3 = c.input(4);
    593   auto s_unknown_b = c.input(5);
    594   auto s_1 = c.input(6);
    595   ShapeHandle out;
    596 
    597   // Relaxing any shape with unknown returns a new unknown.
    598   Relax(&c, s_unknown, s_1_2, &out);
    599   EXPECT_FALSE(SameHandle(s_u_2, s_unknown));
    600   EXPECT_EQ("?", c.DebugString(out));
    601   Relax(&c, s_u_2, s_unknown, &out);
    602   EXPECT_FALSE(SameHandle(s_u_2, out));
    603   EXPECT_EQ("?", c.DebugString(out));
    604   Relax(&c, s_unknown, s_unknown_b, &out);
    605   EXPECT_FALSE(SameHandle(s_unknown, out));
    606   EXPECT_TRUE(SameHandle(s_unknown_b, out));
    607   EXPECT_EQ("?", c.DebugString(out));
    608 
    609   // Relaxing with self returns self.
    610   Relax(&c, s_1_2, s_1_2, &out);
    611   EXPECT_TRUE(SameHandle(out, s_1_2));
    612 
    613   // Relaxing where one of the inputs has less information.
    614   out = ShapeHandle();
    615   Relax(&c, s_1_2, s_u_2, &out);
    616   EXPECT_FALSE(SameHandle(s_u_2, out));
    617   EXPECT_EQ("[?,2]", c.DebugString(out));
    618   out = ShapeHandle();
    619   Relax(&c, s_u_2, s_1_2, &out);
    620   EXPECT_FALSE(SameHandle(s_u_2, out));
    621   EXPECT_EQ("[?,2]", c.DebugString(out));
    622 
    623   // Relaxing where each input has one distinct unknown dimension.
    624   Relax(&c, s_u_2, s_1_u, &out);
    625   EXPECT_EQ("[?,?]", c.DebugString(out));
    626   EXPECT_FALSE(SameHandle(c.Dim(s_u_2, 0), c.Dim(out, 0)));
    627   EXPECT_TRUE(SameHandle(c.Dim(s_1_u, 1), c.Dim(out, 1)));
    628   auto s_u1 = c.UnknownShapeOfRank(1);
    629   auto s_u2 = c.UnknownShapeOfRank(1);
    630   Relax(&c, s_u1, s_u2, &out);
    631   EXPECT_FALSE(SameHandle(s_u1, out));
    632 
    633   // Relaxing with mismatched values in a dimension returns a shape with that
    634   // dimension unknown.
    635   out = s_unknown;
    636   Relax(&c, s_u_2, s_1_3, &out);
    637   EXPECT_FALSE(SameHandle(c.Dim(s_u_2, 0), c.Dim(out, 0)));
    638   EXPECT_EQ("[?,?]", c.DebugString(out));
    639   out = s_unknown;
    640   Relax(&c, s_1_3, s_u_2, &out);
    641   EXPECT_TRUE(SameHandle(c.Dim(s_u_2, 0), c.Dim(out, 0)));
    642   EXPECT_EQ("[?,?]", c.DebugString(out));
    643   out = s_unknown;
    644 
    645   // Relaxing with mismatched ranks returns a new unknown.
    646   Relax(&c, s_1, s_1_2, &out);
    647   EXPECT_EQ("?", c.DebugString(out));
    648 }
    649 
    650 TEST_F(ShapeInferenceTest, MergeShape) {
    651   NodeDef def;
    652   InferenceContext c(kVersion, &def, MakeOpDef(7, 2),
    653                      {Unknown(), S({1, 2}), S({-1, 2}), S({1, -1}), S({1, 3}),
    654                       Unknown(), S({1})},
    655                      {}, {}, {});
    656 
    657   auto s_unknown = c.input(0);
    658   auto s_1_2 = c.input(1);
    659   auto s_u_2 = c.input(2);
    660   auto s_1_u = c.input(3);
    661   auto s_1_3 = c.input(4);
    662   auto s_unknown_b = c.input(5);
    663   auto s_1 = c.input(6);
    664   ShapeHandle out;
    665 
    666   // Merging any shape with unknown returns the shape.
    667   EXPECT_TRUE(c.Merge(s_unknown, s_1_2, &out).ok());
    668   EXPECT_TRUE(SameHandle(s_1_2, out));
    669   EXPECT_TRUE(c.Merge(s_u_2, s_unknown, &out).ok());
    670   EXPECT_TRUE(SameHandle(s_u_2, out));
    671   EXPECT_TRUE(c.Merge(s_unknown, s_unknown_b, &out).ok());
    672   EXPECT_TRUE(SameHandle(s_unknown, out));
    673 
    674   auto merged_shapes = c.MergedShapes();
    675   ASSERT_EQ(3, merged_shapes.size());
    676   EXPECT_TRUE(merged_shapes[0].first.SameHandle(s_unknown));
    677   EXPECT_TRUE(merged_shapes[0].second.SameHandle(s_1_2));
    678   EXPECT_TRUE(merged_shapes[1].first.SameHandle(s_u_2));
    679   EXPECT_TRUE(merged_shapes[1].second.SameHandle(s_unknown));
    680   EXPECT_TRUE(merged_shapes[2].first.SameHandle(s_unknown));
    681   EXPECT_TRUE(merged_shapes[2].second.SameHandle(s_unknown_b));
    682 
    683   // Merging with self returns self.
    684   EXPECT_TRUE(c.Merge(s_1_2, s_1_2, &out).ok());
    685   EXPECT_TRUE(SameHandle(out, s_1_2));
    686 
    687   merged_shapes = c.MergedShapes();
    688   EXPECT_EQ(3, merged_shapes.size());
    689 
    690   // Merging where one of the inputs is the right answer - return that input.
    691   out = ShapeHandle();
    692   EXPECT_TRUE(c.Merge(s_1_2, s_u_2, &out).ok());
    693   EXPECT_TRUE(SameHandle(s_1_2, out));
    694   out = ShapeHandle();
    695   EXPECT_TRUE(c.Merge(s_u_2, s_1_2, &out).ok());
    696   EXPECT_TRUE(SameHandle(s_1_2, out));
    697 
    698   merged_shapes = c.MergedShapes();
    699   ASSERT_EQ(5, merged_shapes.size());
    700   EXPECT_TRUE(merged_shapes[3].first.SameHandle(s_1_2));
    701   EXPECT_TRUE(merged_shapes[3].second.SameHandle(s_u_2));
    702   EXPECT_TRUE(merged_shapes[4].first.SameHandle(s_u_2));
    703   EXPECT_TRUE(merged_shapes[4].second.SameHandle(s_1_2));
    704 
    705   // Merging where neither input is the right answer.
    706   EXPECT_TRUE(c.Merge(s_u_2, s_1_u, &out).ok());
    707   EXPECT_FALSE(SameHandle(out, s_u_2));
    708   EXPECT_FALSE(SameHandle(out, s_1_u));
    709   EXPECT_EQ("[1,2]", c.DebugString(out));
    710   EXPECT_TRUE(SameHandle(c.Dim(s_1_u, 0), c.Dim(out, 0)));
    711   EXPECT_TRUE(SameHandle(c.Dim(s_u_2, 1), c.Dim(out, 1)));
    712 
    713   merged_shapes = c.MergedShapes();
    714   ASSERT_EQ(7, merged_shapes.size());
    715   EXPECT_TRUE(merged_shapes[5].first.SameHandle(s_u_2));
    716   EXPECT_TRUE(merged_shapes[5].second.SameHandle(s_1_u));
    717   EXPECT_TRUE(merged_shapes[6].first.SameHandle(s_u_2));
    718   EXPECT_TRUE(merged_shapes[6].second.SameHandle(out));
    719 
    720   auto s_u1 = c.UnknownShapeOfRank(1);
    721   auto s_u2 = c.UnknownShapeOfRank(1);
    722   TF_EXPECT_OK(c.Merge(s_u1, s_u2, &out));
    723   EXPECT_TRUE(SameHandle(s_u1, out));
    724 
    725   merged_shapes = c.MergedShapes();
    726   ASSERT_EQ(8, merged_shapes.size());
    727   EXPECT_TRUE(merged_shapes[7].first.SameHandle(s_u1));
    728   EXPECT_TRUE(merged_shapes[7].second.SameHandle(s_u2));
    729 
    730   // Incompatible merges give errors and set out to nullptr.
    731   out = s_unknown;
    732   EXPECT_TRUE(
    733       StringPiece(c.Merge(s_u_2, s_1_3, &out).ToString())
    734           .contains(
    735               "Invalid argument: Dimension 1 in both shapes must be equal, but "
    736               "are 2 and 3"));
    737 
    738   EXPECT_FALSE(IsSet(out));
    739   out = s_unknown;
    740   EXPECT_TRUE(
    741       StringPiece(c.Merge(s_1_3, s_u_2, &out).ToString())
    742           .contains(
    743               "Invalid argument: Dimension 1 in both shapes must be equal, but "
    744               "are 3 and 2"));
    745 
    746   EXPECT_FALSE(IsSet(out));
    747   out = s_unknown;
    748   EXPECT_TRUE(
    749       StringPiece(c.Merge(s_1, s_1_2, &out).ToString())
    750           .contains(
    751               "Invalid argument: Shapes must be equal rank, but are 1 and 2"));
    752 
    753   EXPECT_FALSE(IsSet(out));
    754 
    755   merged_shapes = c.MergedShapes();
    756   EXPECT_EQ(8, merged_shapes.size());
    757 }
    758 
    759 TEST_F(ShapeInferenceTest, MergePrefix) {
    760   NodeDef def;
    761   InferenceContext c(kVersion, &def, MakeOpDef(4, 2),
    762                      {
    763                          Unknown(),
    764                          S({-1, 2}),
    765                          S({1, -1, 3}),
    766                          S({2, 4}),
    767                      },
    768                      {}, {}, {});
    769 
    770   auto s_unknown = c.input(0);
    771   auto s_u_2 = c.input(1);
    772   auto s_1_u_3 = c.input(2);
    773   auto s_2_4 = c.input(3);
    774 
    775   ShapeHandle s_out;
    776   ShapeHandle s_prefix_out;
    777 
    778   // Merging with unknown returns the inputs.
    779   EXPECT_TRUE(c.MergePrefix(s_unknown, s_u_2, &s_out, &s_prefix_out).ok());
    780   EXPECT_TRUE(SameHandle(s_out, s_unknown));
    781   EXPECT_TRUE(SameHandle(s_prefix_out, s_u_2));
    782   EXPECT_TRUE(c.MergePrefix(s_1_u_3, s_unknown, &s_out, &s_prefix_out).ok());
    783   EXPECT_TRUE(SameHandle(s_out, s_1_u_3));
    784   EXPECT_TRUE(SameHandle(s_prefix_out, s_unknown));
    785 
    786   EXPECT_TRUE(c.MergePrefix(s_1_u_3, s_u_2, &s_out, &s_prefix_out).ok());
    787   EXPECT_FALSE(SameHandle(s_out, s_1_u_3));
    788   EXPECT_EQ("[1,2]", c.DebugString(s_prefix_out));
    789   EXPECT_EQ("[1,2,3]", c.DebugString(s_out));
    790   EXPECT_TRUE(SameHandle(c.Dim(s_prefix_out, 0), c.Dim(s_out, 0)));
    791   EXPECT_TRUE(SameHandle(c.Dim(s_out, 0), c.Dim(s_1_u_3, 0)));
    792   EXPECT_TRUE(SameHandle(c.Dim(s_prefix_out, 1), c.Dim(s_out, 1)));
    793   EXPECT_TRUE(SameHandle(c.Dim(s_prefix_out, 1), c.Dim(s_u_2, 1)));
    794 
    795   // Incompatible merges give errors and set outs to nullptr.
    796   s_out = s_unknown;
    797   s_prefix_out = s_unknown;
    798   EXPECT_TRUE(
    799       StringPiece(
    800           c.MergePrefix(s_1_u_3, s_2_4, &s_out, &s_prefix_out).ToString())
    801           .contains(
    802               "Invalid argument: Dimensions must be equal, but are 1 and 2"));
    803 
    804   EXPECT_FALSE(IsSet(s_out));
    805   EXPECT_FALSE(IsSet(s_prefix_out));
    806 
    807   s_out = s_unknown;
    808   s_prefix_out = s_unknown;
    809   EXPECT_TRUE(
    810       StringPiece(
    811           c.MergePrefix(s_2_4, s_1_u_3, &s_out, &s_prefix_out).ToString())
    812           .contains(
    813               "Invalid argument: Shape must be at least rank 3 but is rank 2"));
    814   EXPECT_FALSE(IsSet(s_out));
    815   EXPECT_FALSE(IsSet(s_prefix_out));
    816 }
    817 
    818 TEST_F(ShapeInferenceTest, Subshape) {
    819   NodeDef def;
    820   InferenceContext c(kVersion, &def, MakeOpDef(2, 2),
    821                      {S({1, 2, 3, -1, 5}), Unknown()}, {}, {}, {});
    822 
    823   ShapeHandle unknown = c.input(1);
    824   ShapeHandle out;
    825   EXPECT_TRUE(c.Subshape(unknown, 0, &out).ok());
    826   EXPECT_EQ("?", c.DebugString(out));
    827   EXPECT_TRUE(SameHandle(out, unknown));
    828   EXPECT_TRUE(c.Subshape(unknown, 1, &out).ok());
    829   EXPECT_EQ("?", c.DebugString(out));
    830   EXPECT_FALSE(SameHandle(out, unknown));
    831   EXPECT_TRUE(c.Subshape(unknown, 200, &out).ok());
    832   EXPECT_EQ("?", c.DebugString(out));
    833   EXPECT_FALSE(SameHandle(out, unknown));
    834 
    835   const int kFullRank = 5;
    836   ShapeHandle out_arr[4];
    837   auto in0 = c.input(0);
    838   EXPECT_TRUE(c.Subshape(in0, 0, &out).ok());
    839   EXPECT_EQ("[1,2,3,?,5]", c.DebugString(out));
    840   EXPECT_TRUE(SameHandle(out, in0));
    841   EXPECT_EQ(kFullRank, c.Rank(out));
    842   for (int start = 0; start <= kFullRank + 1; ++start) {
    843     for (int end = start; end <= kFullRank + 1; ++end) {
    844       // Get subshapes using different start and end values that give the same
    845       // range.
    846       const int neg_start =
    847           start >= kFullRank ? kFullRank : (start - kFullRank);
    848       const int neg_end = end >= kFullRank ? kFullRank : (end - kFullRank);
    849       ASSERT_TRUE(c.Subshape(in0, start, end, &out_arr[0]).ok());
    850       ASSERT_TRUE(c.Subshape(in0, neg_start, end, &out_arr[1]).ok());
    851       ASSERT_TRUE(c.Subshape(in0, start, neg_end, &out_arr[2]).ok());
    852       ASSERT_TRUE(c.Subshape(in0, neg_start, neg_end, &out_arr[3]).ok());
    853 
    854       // Verify all computed subshapes.
    855       for (int arr_idx = 0; arr_idx < 4; ++arr_idx) {
    856         out = out_arr[arr_idx];
    857         ASSERT_EQ(std::min(kFullRank, end) - std::min(kFullRank, start),
    858                   c.Rank(out))
    859             << "start: " << start << " end: " << end << " arr_idx: " << arr_idx
    860             << " in0: " << c.DebugString(in0) << " out: " << c.DebugString(out);
    861         for (int d = 0; d < c.Rank(out); ++d) {
    862           EXPECT_TRUE(SameHandle(c.Dim(in0, start + d), c.Dim(out, d)))
    863               << "arr_idx: " << arr_idx;
    864         }
    865       }
    866     }
    867   }
    868 
    869   // Errors.
    870   out = unknown;
    871   EXPECT_TRUE(StringPiece(c.Subshape(in0, 6, -3, &out).ToString())
    872                   .contains("Invalid argument: Subshape must have computed "
    873                             "start <= end, but is 5 "
    874                             "and 2 (computed from start 6 and end -3 over "
    875                             "shape with rank 5)"));
    876   EXPECT_FALSE(IsSet(out));
    877   out = unknown;
    878   EXPECT_TRUE(StringPiece(c.Subshape(in0, -50, 100, &out).ToString())
    879                   .contains("Invalid argument: Subshape start out of "
    880                             "bounds: -50, for shape with "
    881                             "rank 5"));
    882 
    883   EXPECT_FALSE(IsSet(out));
    884   out = unknown;
    885   EXPECT_TRUE(StringPiece(c.Subshape(in0, 0, -50, &out).ToString())
    886                   .contains("Invalid argument: Subshape end out of bounds: "
    887                             "-50, for shape with rank "
    888                             "5"));
    889 
    890   EXPECT_FALSE(IsSet(out));
    891 }
    892 
    893 TEST_F(ShapeInferenceTest, Concatenate) {
    894   NodeDef def;
    895   InferenceContext c(kVersion, &def, MakeOpDef(3, 2),
    896                      {S({1, -1, 3}), S({4, 5}), Unknown()}, {}, {}, {});
    897 
    898   auto in0 = c.input(0);
    899   auto in1 = c.input(1);
    900   ShapeHandle unknown = c.input(2);
    901   ShapeHandle out;
    902   EXPECT_TRUE(c.Concatenate(unknown, unknown, &out).ok());
    903   EXPECT_EQ("?", c.DebugString(out));
    904   EXPECT_FALSE(SameHandle(out, unknown));
    905   EXPECT_TRUE(c.Concatenate(unknown, in0, &out).ok());
    906   EXPECT_EQ("?", c.DebugString(out));
    907   EXPECT_FALSE(SameHandle(out, unknown));
    908 
    909   EXPECT_TRUE(c.Concatenate(in0, in1, &out).ok());
    910   EXPECT_EQ("[1,?,3,4,5]", c.DebugString(out));
    911   int out_i = 0;
    912   for (int i = 0; i < c.Rank(in0); ++i, ++out_i) {
    913     EXPECT_TRUE(SameHandle(c.Dim(in0, i), c.Dim(out, out_i)));
    914   }
    915   for (int i = 0; i < c.Rank(in1); ++i, ++out_i) {
    916     EXPECT_TRUE(SameHandle(c.Dim(in1, i), c.Dim(out, out_i)));
    917   }
    918 }
    919 
    920 TEST_F(ShapeInferenceTest, ReplaceDim) {
    921   NodeDef def;
    922   InferenceContext c(kVersion, &def, MakeOpDef(2, 0), {S({1, 2, 3}), Unknown()},
    923                      {}, {}, {});
    924 
    925   auto in = c.input(0);
    926   auto unknown = c.input(1);
    927 
    928   ShapeHandle replaced;
    929   EXPECT_TRUE(c.ReplaceDim(in, 0, c.Dim(in, 1), &replaced).ok());
    930   EXPECT_EQ("[2,2,3]", c.DebugString(replaced));
    931   EXPECT_TRUE(c.ReplaceDim(in, 2, c.Dim(in, 1), &replaced).ok());
    932   EXPECT_EQ("[1,2,2]", c.DebugString(replaced));
    933   EXPECT_TRUE(c.ReplaceDim(in, 1, c.Dim(in, 2), &replaced).ok());
    934   EXPECT_EQ("[1,3,3]", c.DebugString(replaced));
    935   EXPECT_TRUE(c.ReplaceDim(unknown, 0, c.Dim(in, 1), &replaced).ok());
    936   EXPECT_EQ("?", c.DebugString(replaced));
    937 
    938   // Negative indexing.
    939   EXPECT_TRUE(c.ReplaceDim(in, -1, c.Dim(in, 1), &replaced).ok());
    940   EXPECT_EQ("[1,2,2]", c.DebugString(replaced));
    941   EXPECT_TRUE(c.ReplaceDim(unknown, -1, c.Dim(in, 1), &replaced).ok());
    942   EXPECT_EQ("?", c.DebugString(replaced));
    943 
    944   // out of range indexing.
    945   EXPECT_FALSE(c.ReplaceDim(in, 3, c.Dim(in, 1), &replaced).ok());
    946   EXPECT_FALSE(IsSet(replaced));
    947   replaced = in;
    948   EXPECT_FALSE(c.ReplaceDim(in, -4, c.Dim(in, 1), &replaced).ok());
    949   EXPECT_FALSE(IsSet(replaced));
    950 }
    951 
    952 TEST_F(ShapeInferenceTest, MakeShape) {
    953   NodeDef def;
    954   InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, 3, -1, 5})}, {},
    955                      {}, {});
    956 
    957   std::vector<DimensionHandle> dims;
    958   auto in0 = c.input(0);
    959   const int rank = c.Rank(in0);
    960   dims.reserve(rank);
    961   for (int i = 0; i < rank; ++i) {
    962     dims.push_back(c.Dim(in0, rank - i - 1));
    963   }
    964 
    965   auto s = c.MakeShape(dims);
    966   EXPECT_EQ("[5,?,3,2,1]", c.DebugString(s));
    967   EXPECT_TRUE(SameHandle(c.Dim(s, 0), c.Dim(in0, rank - 1)));
    968 
    969   auto s2 = c.MakeShape(dims);
    970   EXPECT_FALSE(SameHandle(s, s2));
    971   EXPECT_TRUE(SameHandle(c.Dim(s2, 0), c.Dim(in0, rank - 1)));
    972 
    973   auto s3 = c.MakeShape({1, 2, dims[2]});
    974   EXPECT_FALSE(SameHandle(s, s3));
    975   EXPECT_EQ("[1,2,3]", c.DebugString(s3));
    976 }
    977 
    978 TEST_F(ShapeInferenceTest, UnknownShape) {
    979   NodeDef def;
    980   std::vector<ShapeHandle> empty;
    981   InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
    982 
    983   auto u0 = c.UnknownShape();
    984   auto u1 = c.UnknownShape();
    985   EXPECT_EQ("?", c.DebugString(u0));
    986   EXPECT_EQ("?", c.DebugString(u1));
    987   EXPECT_FALSE(SameHandle(u0, u1));
    988 }
    989 
    990 TEST_F(ShapeInferenceTest, KnownShapeToProto) {
    991   NodeDef def;
    992   std::vector<ShapeHandle> empty;
    993   InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
    994 
    995   auto s = c.MakeShape({1, 2, 3});
    996   TensorShapeProto proto;
    997   c.ShapeHandleToProto(s, &proto);
    998 
    999   EXPECT_FALSE(proto.unknown_rank());
   1000   EXPECT_EQ(3, proto.dim_size());
   1001   EXPECT_EQ(1, proto.dim(0).size());
   1002 }
   1003 
   1004 TEST_F(ShapeInferenceTest, UnknownShapeToProto) {
   1005   NodeDef def;
   1006   std::vector<ShapeHandle> empty;
   1007   InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
   1008 
   1009   auto u0 = c.UnknownShape();
   1010   TensorShapeProto proto;
   1011   c.ShapeHandleToProto(u0, &proto);
   1012 
   1013   EXPECT_TRUE(proto.unknown_rank());
   1014   EXPECT_EQ(0, proto.dim_size());
   1015 }
   1016 
   1017 TEST_F(ShapeInferenceTest, Scalar) {
   1018   NodeDef def;
   1019   std::vector<ShapeHandle> empty;
   1020   InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
   1021 
   1022   auto s0 = c.Scalar();
   1023   EXPECT_EQ("[]", c.DebugString(s0));
   1024   auto s1 = c.Scalar();
   1025   EXPECT_EQ("[]", c.DebugString(s1));
   1026 }
   1027 
   1028 TEST_F(ShapeInferenceTest, Vector) {
   1029   NodeDef def;
   1030   std::vector<ShapeHandle> empty;
   1031   InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
   1032 
   1033   auto s0 = c.Vector(1);
   1034   EXPECT_EQ("[1]", c.DebugString(s0));
   1035   auto s1 = c.Vector(InferenceContext::kUnknownDim);
   1036   EXPECT_EQ("[?]", c.DebugString(s1));
   1037 
   1038   auto d1 = c.UnknownDim();
   1039   auto s2 = c.Vector(d1);
   1040   EXPECT_EQ("[?]", c.DebugString(s2));
   1041   EXPECT_TRUE(SameHandle(d1, c.Dim(s2, 0)));
   1042 }
   1043 
   1044 TEST_F(ShapeInferenceTest, Matrix) {
   1045   NodeDef def;
   1046   std::vector<ShapeHandle> empty;
   1047   InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
   1048 
   1049   auto s0 = c.Matrix(1, 2);
   1050   EXPECT_EQ("[1,2]", c.DebugString(s0));
   1051   auto s1 = c.Matrix(0, InferenceContext::kUnknownDim);
   1052   EXPECT_EQ("[0,?]", c.DebugString(s1));
   1053 
   1054   auto d1 = c.UnknownDim();
   1055   auto d2 = c.UnknownDim();
   1056   auto s2 = c.Matrix(d1, d2);
   1057   EXPECT_EQ("[?,?]", c.DebugString(s2));
   1058   EXPECT_TRUE(SameHandle(d1, c.Dim(s2, 0)));
   1059   EXPECT_TRUE(SameHandle(d2, c.Dim(s2, 1)));
   1060 
   1061   auto s3 = c.Matrix(d1, 100);
   1062   EXPECT_EQ("[?,100]", c.DebugString(s3));
   1063   EXPECT_TRUE(SameHandle(d1, c.Dim(s2, 0)));
   1064 }
   1065 
   1066 TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) {
   1067   auto create = [&](Tensor* t) {
   1068     NodeDef def;
   1069     InferenceContext c(kVersion, &def, MakeOpDef(1, 0), {Unknown()}, {t}, {},
   1070                        {});
   1071     ShapeHandle out;
   1072     Status s = c.MakeShapeFromShapeTensor(0, &out);
   1073     if (s.ok()) {
   1074       return c.DebugString(out);
   1075     } else {
   1076       EXPECT_FALSE(IsSet(out));
   1077       return s.error_message();
   1078     }
   1079   };
   1080 
   1081   Tensor t;
   1082   EXPECT_EQ("?", create(nullptr));
   1083 
   1084   t = ::tensorflow::test::AsTensor<int32>({1, 2, 3});
   1085   EXPECT_EQ("[1,2,3]", create(&t));
   1086 
   1087   t = ::tensorflow::test::AsTensor<int64>({3, 2, 1});
   1088   EXPECT_EQ("[3,2,1]", create(&t));
   1089 
   1090   t = ::tensorflow::test::AsTensor<int64>({3, -1, 1});
   1091   EXPECT_EQ("[3,?,1]", create(&t));
   1092 
   1093   t = ::tensorflow::test::AsTensor<int64>({});
   1094   EXPECT_EQ("[]", create(&t));
   1095 
   1096   t = ::tensorflow::test::AsTensor<float>({1, 2, 3});
   1097   EXPECT_TRUE(
   1098       StringPiece(create(&t))
   1099           .contains("Input tensor must be int32 or int64, but was float"));
   1100 
   1101   t = ::tensorflow::test::AsScalar<int32>(1);
   1102   EXPECT_TRUE(StringPiece(create(&t))
   1103                   .contains("Input tensor must be rank 1, but was rank 0"));
   1104 
   1105   t = ::tensorflow::test::AsTensor<int32>({1, 2}, TensorShape{2, 1});
   1106   EXPECT_TRUE(StringPiece(create(&t))
   1107                   .contains("Input tensor must be rank 1, but was rank 2"));
   1108 
   1109   // Test negative values for the dims.
   1110   t = ::tensorflow::test::AsTensor<int64>({3, -2, 1});
   1111   EXPECT_TRUE(StringPiece(create(&t))
   1112                   .contains("Invalid value in tensor used for shape: -2"));
   1113 
   1114   // Test negative values for the dims.
   1115   t = ::tensorflow::test::AsTensor<int32>({3, -2, 1});
   1116   EXPECT_TRUE(StringPiece(create(&t))
   1117                   .contains("Invalid value in tensor used for shape: -2"));
   1118 
   1119   // Test when the input shape is wrong.
   1120   {
   1121     NodeDef def;
   1122     InferenceContext c(kVersion, &def, MakeOpDef(1, 0), {S({1, -1})}, {nullptr},
   1123                        {}, {});
   1124     ShapeHandle out;
   1125     EXPECT_EQ("Shape must be rank 1 but is rank 2",
   1126               c.MakeShapeFromShapeTensor(0, &out).error_message());
   1127   }
   1128 }
   1129 
   1130 TEST_F(ShapeInferenceTest, MakeShapeFromPartialTensorShape) {
   1131   NodeDef def;
   1132   std::vector<ShapeHandle> empty;
   1133   InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
   1134 
   1135   // With an unknown rank.
   1136   ShapeHandle out;
   1137   TF_ASSERT_OK(c.MakeShapeFromPartialTensorShape(PartialTensorShape(), &out));
   1138   EXPECT_EQ("?", c.DebugString(out));
   1139 
   1140   // With a known rank.
   1141   TF_ASSERT_OK(
   1142       c.MakeShapeFromPartialTensorShape(PartialTensorShape({0}), &out));
   1143   EXPECT_EQ("[0]", c.DebugString(out));
   1144   TF_ASSERT_OK(c.MakeShapeFromPartialTensorShape(
   1145       PartialTensorShape({0, -1, 1000}), &out));
   1146   EXPECT_EQ("[0,?,1000]", c.DebugString(out));
   1147 }
   1148 
   1149 TEST_F(ShapeInferenceTest, MakeShapeFromTensorShape) {
   1150   NodeDef def;
   1151   std::vector<ShapeHandle> empty;
   1152   InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
   1153 
   1154   ShapeHandle out;
   1155   TF_ASSERT_OK(c.MakeShapeFromTensorShape(TensorShape(), &out));
   1156   EXPECT_EQ("[]", c.DebugString(out));
   1157   TF_ASSERT_OK(c.MakeShapeFromTensorShape(TensorShape({0}), &out));
   1158   EXPECT_EQ("[0]", c.DebugString(out));
   1159   TF_ASSERT_OK(c.MakeShapeFromTensorShape(TensorShape({0, 7, 1000}), &out));
   1160   EXPECT_EQ("[0,7,1000]", c.DebugString(out));
   1161 }
   1162 
   1163 TEST_F(ShapeInferenceTest, MakeShapeFromShapeProto) {
   1164   NodeDef def;
   1165   std::vector<ShapeHandle> empty;
   1166   InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
   1167   TensorShapeProto proto;
   1168 
   1169   // With a set unknown rank.
   1170   ShapeHandle out;
   1171   proto.set_unknown_rank(true);
   1172   EXPECT_TRUE(c.MakeShapeFromShapeProto(proto, &out).ok());
   1173   EXPECT_EQ("?", c.DebugString(out));
   1174   proto.add_dim()->set_size(0);
   1175   EXPECT_TRUE(
   1176       StringPiece(c.MakeShapeFromShapeProto(proto, &out).error_message())
   1177           .contains("An unknown shape must not have any dimensions set."));
   1178   EXPECT_FALSE(IsSet(out));
   1179 
   1180   // With known rank.
   1181   proto.set_unknown_rank(false);
   1182   EXPECT_TRUE(c.MakeShapeFromShapeProto(proto, &out).ok());
   1183   EXPECT_EQ("[0]", c.DebugString(out));
   1184   proto.add_dim()->set_size(-1);
   1185   proto.add_dim()->set_size(1000);
   1186   EXPECT_TRUE(c.MakeShapeFromShapeProto(proto, &out).ok());
   1187   EXPECT_EQ("[0,?,1000]", c.DebugString(out));
   1188 
   1189   // With invalid dimension value.
   1190   proto.add_dim()->set_size(-2);
   1191   EXPECT_TRUE(
   1192       StringPiece(c.MakeShapeFromShapeProto(proto, &out).error_message())
   1193           .contains("Shape [0,?,1000,-2] has dimensions with values below -1 "
   1194                     "(where -1 means unknown)"));
   1195 
   1196   EXPECT_FALSE(IsSet(out));
   1197 }
   1198 
   1199 TEST_F(ShapeInferenceTest, MakeDim) {
   1200   NodeDef def;
   1201   std::vector<ShapeHandle> empty;
   1202   InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
   1203 
   1204   auto d0 = c.MakeDim(1);
   1205   auto d1 = c.MakeDim(1);
   1206   auto d2 = c.MakeDim(2);
   1207   EXPECT_EQ("1", c.DebugString(d0));
   1208   EXPECT_EQ("1", c.DebugString(d1));
   1209   EXPECT_FALSE(SameHandle(d0, d1));
   1210   EXPECT_EQ("2", c.DebugString(d2));
   1211 }
   1212 
   1213 TEST_F(ShapeInferenceTest, UnknownDim) {
   1214   NodeDef def;
   1215   std::vector<ShapeHandle> empty;
   1216   InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
   1217 
   1218   auto d0 = c.UnknownDim();
   1219   auto d1 = c.UnknownDim();
   1220   EXPECT_EQ("?", c.DebugString(d0));
   1221   EXPECT_EQ("?", c.DebugString(d1));
   1222   EXPECT_FALSE(SameHandle(d0, d1));
   1223 }
   1224 
   1225 TEST_F(ShapeInferenceTest, UnknownShapeOfRank) {
   1226   NodeDef def;
   1227   std::vector<ShapeHandle> empty;
   1228   InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
   1229 
   1230   auto unknown_shape_of_rank_3 = c.UnknownShapeOfRank(3);
   1231   EXPECT_EQ("[?,?,?]", c.DebugString(unknown_shape_of_rank_3));
   1232 
   1233   auto unknown_shape_of_rank_0 = c.UnknownShapeOfRank(0);
   1234   EXPECT_EQ("[]", c.DebugString(unknown_shape_of_rank_0));
   1235 }
   1236 
   1237 TEST_F(ShapeInferenceTest, InputTensors) {
   1238   const Tensor t1 = tensorflow::test::AsTensor<float>({10});
   1239   const Tensor t2 = tensorflow::test::AsTensor<float>({20, 30});
   1240   NodeDef def;
   1241   InferenceContext c(kVersion, &def, MakeOpDef(3, 2), {S({1}), S({2}), S({3})},
   1242                      {&t1, &t2}, {}, {});
   1243 
   1244   EXPECT_TRUE(c.input_tensor(0) == &t1);
   1245   EXPECT_TRUE(c.input_tensor(1) == &t2);
   1246   EXPECT_TRUE(c.input_tensor(2) == nullptr);
   1247 }
   1248 
   1249 TEST_F(ShapeInferenceTest, MakeDimForScalarInput) {
   1250   Tensor t1 = tensorflow::test::AsScalar<int32>(20);
   1251   Tensor t2 = tensorflow::test::AsScalar<int32>(-1);
   1252   NodeDef def;
   1253   InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({}), S({})},
   1254                      {&t1, &t2}, {}, {});
   1255 
   1256   DimensionHandle d;
   1257   EXPECT_TRUE(c.MakeDimForScalarInput(0, &d).ok());
   1258   EXPECT_EQ("20", c.DebugString(d));
   1259 
   1260   EXPECT_TRUE(StringPiece(c.MakeDimForScalarInput(1, &d).error_message())
   1261                   .contains("Dimension size, given by scalar input 1, must "
   1262                             "be non-negative but is -1"));
   1263 
   1264   // Same tests, with int64 values.
   1265   t1 = tensorflow::test::AsScalar<int64>(20);
   1266   t2 = tensorflow::test::AsScalar<int64>(-1);
   1267   EXPECT_TRUE(c.MakeDimForScalarInput(0, &d).ok());
   1268   EXPECT_EQ("20", c.DebugString(d));
   1269 
   1270   EXPECT_TRUE(StringPiece(c.MakeDimForScalarInput(1, &d).error_message())
   1271                   .contains("Dimension size, given by scalar input 1, must "
   1272                             "be non-negative but is -1"));
   1273 }
   1274 
   1275 TEST_F(ShapeInferenceTest, GetAttr) {
   1276   OpRegistrationData op_reg_data;
   1277   op_reg_data.op_def = MakeOpDef(0, 2);
   1278   NodeDef def;
   1279   CHECK(NodeDefBuilder("dummy", &op_reg_data.op_def)
   1280             .Attr("foo", "bar")
   1281             .Finalize(&def)
   1282             .ok());
   1283 
   1284   std::vector<ShapeHandle> empty;
   1285   InferenceContext c(kVersion, &def, op_reg_data.op_def, empty, {}, {}, {});
   1286   string value;
   1287   EXPECT_TRUE(c.GetAttr("foo", &value).ok());
   1288   EXPECT_EQ("bar", value);
   1289 }
   1290 
   1291 TEST_F(ShapeInferenceTest, Divide) {
   1292   NodeDef def;
   1293   InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 1, 2, 0})}, {},
   1294                      {}, {});
   1295 
   1296   auto s = c.input(0);
   1297   auto d_6 = c.Dim(s, 0);
   1298   auto d_unknown = c.Dim(s, 1);
   1299   auto d_1 = c.Dim(s, 2);
   1300   auto d_2 = c.Dim(s, 3);
   1301   auto d_0 = c.Dim(s, 4);
   1302   bool evenly_divisible = true;
   1303 
   1304   // Dividing unknown by non-1 gives new unknown.
   1305   DimensionHandle out;
   1306   EXPECT_TRUE(c.Divide(d_unknown, 2, evenly_divisible, &out).ok());
   1307   EXPECT_EQ("?", c.DebugString(out));
   1308   EXPECT_FALSE(SameHandle(out, d_unknown));
   1309 
   1310   // Dividing anything by 1 returns the input.
   1311   EXPECT_TRUE(c.Divide(d_unknown, 1, evenly_divisible, &out).ok());
   1312   EXPECT_TRUE(SameHandle(out, d_unknown));
   1313   EXPECT_TRUE(c.Divide(d_6, 1, evenly_divisible, &out).ok());
   1314   EXPECT_TRUE(SameHandle(out, d_6));
   1315   EXPECT_TRUE(c.Divide(d_unknown, d_1, evenly_divisible, &out).ok());
   1316   EXPECT_TRUE(SameHandle(out, d_unknown));
   1317   EXPECT_TRUE(c.Divide(d_6, d_1, evenly_divisible, &out).ok());
   1318   EXPECT_TRUE(SameHandle(out, d_6));
   1319 
   1320   EXPECT_TRUE(c.Divide(d_6, 2, evenly_divisible, &out).ok());
   1321   EXPECT_EQ("3", c.DebugString(out));
   1322   EXPECT_TRUE(c.Divide(d_6, d_2, evenly_divisible, &out).ok());
   1323   EXPECT_EQ("3", c.DebugString(out));
   1324 
   1325   EXPECT_TRUE(
   1326       StringPiece(c.Divide(d_6, 5, evenly_divisible, &out).error_message())
   1327           .contains("Dimension size must be evenly divisible by 5 but is 6"));
   1328 
   1329   EXPECT_TRUE(
   1330       StringPiece(c.Divide(d_6, 0, evenly_divisible, &out).error_message())
   1331           .contains("Divisor must be positive but is 0"));
   1332   EXPECT_TRUE(
   1333       StringPiece(c.Divide(d_6, d_0, evenly_divisible, &out).error_message())
   1334           .contains("Divisor must be positive but is 0"));
   1335 
   1336   EXPECT_TRUE(
   1337       StringPiece(c.Divide(d_6, -1, evenly_divisible, &out).error_message())
   1338           .contains("Divisor must be positive but is -1"));
   1339 
   1340   // Repeat error cases above with evenly_divisible=false.
   1341   evenly_divisible = false;
   1342   EXPECT_TRUE(c.Divide(d_6, 5, evenly_divisible, &out).ok());
   1343   EXPECT_EQ("1", c.DebugString(out));
   1344 
   1345   EXPECT_TRUE(
   1346       StringPiece(c.Divide(d_6, 0, evenly_divisible, &out).error_message())
   1347           .contains("Divisor must be positive but is 0"));
   1348 
   1349   EXPECT_TRUE(
   1350       StringPiece(c.Divide(d_6, -1, evenly_divisible, &out).error_message())
   1351           .contains("Divisor must be positive but is -1"));
   1352 }
   1353 
   1354 TEST_F(ShapeInferenceTest, Add) {
   1355   NodeDef def;
   1356   InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 0})}, {}, {},
   1357                      {});
   1358 
   1359   auto s = c.input(0);
   1360   auto d_6 = c.Dim(s, 0);
   1361   auto d_unknown = c.Dim(s, 1);
   1362   auto d_0 = c.Dim(s, 2);
   1363 
   1364   // Adding non-zero to unknown gives new unknown.
   1365   DimensionHandle out;
   1366   EXPECT_TRUE(c.Add(d_unknown, 1, &out).ok());
   1367   EXPECT_EQ("?", c.DebugString(out));
   1368   EXPECT_FALSE(SameHandle(out, d_unknown));
   1369 
   1370   // Adding 0 to anything gives input.
   1371   EXPECT_TRUE(c.Add(d_unknown, 0, &out).ok());
   1372   EXPECT_TRUE(SameHandle(out, d_unknown));
   1373   EXPECT_TRUE(c.Add(d_6, 0, &out).ok());
   1374   EXPECT_TRUE(SameHandle(out, d_6));
   1375 
   1376   // Adding dimension with value 0 to anything gives input.
   1377   EXPECT_TRUE(c.Add(d_unknown, c.MakeDim(0ll), &out).ok());
   1378   EXPECT_TRUE(SameHandle(out, d_unknown));
   1379   EXPECT_TRUE(c.Add(d_6, c.MakeDim(0ll), &out).ok());
   1380   EXPECT_TRUE(SameHandle(out, d_6));
   1381 
   1382   // Test addition.
   1383   EXPECT_TRUE(c.Add(d_6, 2, &out).ok());
   1384   EXPECT_EQ("8", c.DebugString(out));
   1385   EXPECT_TRUE(c.Add(d_6, std::numeric_limits<int64>::max() - 6, &out).ok());
   1386   EXPECT_EQ(std::numeric_limits<int64>::max(), c.Value(out));
   1387 
   1388   // Test addition using dimension as second value.
   1389   EXPECT_TRUE(c.Add(d_6, c.MakeDim(2), &out).ok());
   1390   EXPECT_EQ("8", c.DebugString(out));
   1391   EXPECT_TRUE(
   1392       c.Add(d_6, c.MakeDim(std::numeric_limits<int64>::max() - 6), &out).ok());
   1393   EXPECT_EQ(std::numeric_limits<int64>::max(), c.Value(out));
   1394   EXPECT_TRUE(c.Add(d_6, c.UnknownDim(), &out).ok());
   1395   EXPECT_EQ("?", c.DebugString(out));
   1396   EXPECT_TRUE(c.Add(d_0, d_6, &out).ok());
   1397   EXPECT_TRUE(SameHandle(out, d_6));
   1398 
   1399   EXPECT_TRUE(
   1400       StringPiece(c.Add(d_6, std::numeric_limits<int64>::max() - 5, &out)
   1401                       .error_message())
   1402           .contains(
   1403               "Dimension size overflow from adding 6 and 9223372036854775802"));
   1404 }
   1405 
   1406 TEST_F(ShapeInferenceTest, Subtract) {
   1407   NodeDef def;
   1408   InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 0, 5})}, {},
   1409                      {}, {});
   1410 
   1411   auto s = c.input(0);
   1412   auto d_6 = c.Dim(s, 0);
   1413   auto d_unknown = c.Dim(s, 1);
   1414   auto d_0 = c.Dim(s, 2);
   1415   auto d_5 = c.Dim(s, 3);
   1416 
   1417   // Subtracting non-zero from unknown gives new unknown.
   1418   DimensionHandle out;
   1419   EXPECT_TRUE(c.Subtract(d_unknown, 1, &out).ok());
   1420   EXPECT_EQ("?", c.DebugString(out));
   1421   EXPECT_FALSE(SameHandle(out, d_unknown));
   1422 
   1423   // Subtracting 0 from anything gives input.
   1424   EXPECT_TRUE(c.Subtract(d_unknown, 0ll, &out).ok());
   1425   EXPECT_TRUE(SameHandle(out, d_unknown));
   1426   EXPECT_TRUE(c.Subtract(d_6, 0ll, &out).ok());
   1427   EXPECT_TRUE(SameHandle(out, d_6));
   1428 
   1429   // Subtracting dimension with value 0 from anything gives input.
   1430   EXPECT_TRUE(c.Subtract(d_unknown, c.MakeDim(0ll), &out).ok());
   1431   EXPECT_TRUE(SameHandle(out, d_unknown));
   1432   EXPECT_TRUE(c.Subtract(d_6, c.MakeDim(0ll), &out).ok());
   1433   EXPECT_TRUE(SameHandle(out, d_6));
   1434 
   1435   // Test subtraction.
   1436   EXPECT_TRUE(c.Subtract(d_6, 2, &out).ok());
   1437   EXPECT_EQ("4", c.DebugString(out));
   1438   EXPECT_TRUE(c.Subtract(d_6, 6, &out).ok());
   1439   EXPECT_EQ("0", c.DebugString(out));
   1440 
   1441   // Test subtraction using dimension as second value.
   1442   EXPECT_TRUE(c.Subtract(d_6, c.MakeDim(2), &out).ok());
   1443   EXPECT_EQ("4", c.DebugString(out));
   1444   EXPECT_TRUE(c.Subtract(d_6, d_5, &out).ok());
   1445   EXPECT_EQ("1", c.DebugString(out));
   1446   EXPECT_TRUE(c.Subtract(d_6, c.UnknownDim(), &out).ok());
   1447   EXPECT_EQ("?", c.DebugString(out));
   1448   EXPECT_TRUE(c.Subtract(d_6, d_0, &out).ok());
   1449   EXPECT_TRUE(SameHandle(out, d_6));
   1450 
   1451   EXPECT_TRUE(
   1452       StringPiece(c.Subtract(d_5, d_6, &out).error_message())
   1453           .contains("Negative dimension size caused by subtracting 6 from 5"));
   1454 }
   1455 
   1456 TEST_F(ShapeInferenceTest, Multiply) {
          // Exercises InferenceContext::Multiply on the dims of a single input
          // shape built from {6, -1, 0, 1}. The second operand is passed both as
          // a plain integer and as a DimensionHandle, and for the identity/zero
          // cases the operands are also swapped.
   1457   NodeDef def;
   1458   InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 0, 1})}, {},
   1459                      {}, {});
   1460 
          // One handle per dim of input 0; each is named after the value it was
          // built from (index 1, built from -1, is the one used as "unknown").
   1461   auto s = c.input(0);
   1462   auto d_6 = c.Dim(s, 0);
   1463   auto d_unknown = c.Dim(s, 1);
   1464   auto d_0 = c.Dim(s, 2);
   1465   auto d_1 = c.Dim(s, 3);
   1466 
   1467   // Multiplying non-zero to unknown gives new unknown.
          // `out` is reused as the result slot for every call below.
   1468   DimensionHandle out;
   1469   EXPECT_TRUE(c.Multiply(d_unknown, 2, &out).ok());
   1470   EXPECT_EQ("?", c.DebugString(out));
   1471 
   1472   // Multiplying 0 to anything gives 0.
          // Checked with the zero as an integer, as the second handle, and as the
          // first handle.
   1473   EXPECT_TRUE(c.Multiply(d_unknown, 0, &out).ok());
   1474   EXPECT_EQ("0", c.DebugString(out));
   1475   EXPECT_TRUE(c.Multiply(d_unknown, d_0, &out).ok());
   1476   EXPECT_EQ("0", c.DebugString(out));
   1477   EXPECT_TRUE(c.Multiply(d_0, d_unknown, &out).ok());
   1478   EXPECT_EQ("0", c.DebugString(out));
   1479 
   1480   // Multiplying 1 to anything gives the original.
          // "The original" is verified with SameHandle against the non-one
          // operand, not by comparing debug strings.
   1481   // (unknown -> unknown)
   1482   EXPECT_TRUE(c.Multiply(d_unknown, 1, &out).ok());
   1483   EXPECT_TRUE(SameHandle(d_unknown, out));
   1484   EXPECT_TRUE(c.Multiply(d_unknown, d_1, &out).ok());
   1485   EXPECT_TRUE(SameHandle(d_unknown, out));
   1486   EXPECT_TRUE(c.Multiply(d_1, d_unknown, &out).ok());
   1487   EXPECT_TRUE(SameHandle(d_unknown, out));
   1488   // (known -> known)
   1489   EXPECT_TRUE(c.Multiply(d_6, 1, &out).ok());
   1490   EXPECT_TRUE(SameHandle(d_6, out));
   1491   EXPECT_TRUE(c.Multiply(d_6, d_1, &out).ok());
   1492   EXPECT_TRUE(SameHandle(d_6, out));
   1493   EXPECT_TRUE(c.Multiply(d_1, d_6, &out).ok());
   1494   EXPECT_TRUE(SameHandle(d_6, out));
   1495 
   1496   // Test multiplication.
          // Known dim times an integer constant: 6*2 and 6*6.
   1497   EXPECT_TRUE(c.Multiply(d_6, 2, &out).ok());
   1498   EXPECT_EQ("12", c.DebugString(out));
   1499   EXPECT_TRUE(c.Multiply(d_6, 6, &out).ok());
   1500   EXPECT_EQ("36", c.DebugString(out));
   1501 
   1502   // Test multiplication using dimension as second value.
          // The second operands here come from c.MakeDim / c.UnknownDim rather
          // than from the input shape.
   1503   EXPECT_TRUE(c.Multiply(d_6, c.MakeDim(2), &out).ok());
   1504   EXPECT_EQ("12", c.DebugString(out));
   1505   EXPECT_TRUE(c.Multiply(d_6, c.UnknownDim(), &out).ok());
   1506   EXPECT_EQ("?", c.DebugString(out));
   1507 }
   1508 
   1509 TEST_F(ShapeInferenceTest, FullyDefined) {
          // Checks InferenceContext::FullyDefined on shapes created directly
          // through the context; the context itself is built with an empty list
          // of input shapes.
   1510   NodeDef def;
   1511   std::vector<ShapeHandle> empty;
   1512   InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
   1513 
   1514   // No rank or missing dimension information should return false.
          // Cases: a fully unknown shape, and a matrix with one unknown dim.
   1515   EXPECT_FALSE(c.FullyDefined(c.UnknownShape()));
   1516   EXPECT_FALSE(c.FullyDefined(c.Matrix(c.MakeDim(1), c.UnknownDim())));
   1517 
   1518   // Return true if all information exists.
          // Cases: a matrix with both dims given, and a scalar.
   1519   EXPECT_TRUE(c.FullyDefined(c.Matrix(c.MakeDim(1), c.MakeDim(2))));
   1520   EXPECT_TRUE(c.FullyDefined(c.Scalar()));
   1521 }
   1522 
   1523 TEST_F(ShapeInferenceTest, Min) {
          // Exercises InferenceContext::Min on the dims of a single input shape
          // built from {1, 2, -1, 0}, with the second operand given either as a
          // DimensionHandle or as an integer constant.
   1524   NodeDef def;
   1525   InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, -1, 0})}, {},
   1526                      {}, {});
   1527 
          // One handle per dim of input 0, named after the value it was built
          // from (index 2, built from -1, is the one used as "unknown").
   1528   auto s = c.input(0);
   1529   auto d_1 = c.Dim(s, 0);
   1530   auto d_2 = c.Dim(s, 1);
   1531   auto d_unknown = c.Dim(s, 2);
   1532   auto d_0 = c.Dim(s, 3);
   1533 
   1534   // Minimum involving zero and unknown returns zero.
          // With the existing d_0 handle the result is checked by handle; with a
          // freshly made or literal zero it is checked by its debug string.
          // NOTE(review): the `ll` suffix on the zero literals presumably keeps
          // a bare 0 from being treated as a null handle - confirm against the
          // MakeDim/Min signatures.
   1535   DimensionHandle out;
   1536   EXPECT_TRUE(c.Min(d_0, d_unknown, &out).ok());
   1537   EXPECT_TRUE(SameHandle(d_0, out));
   1538   EXPECT_TRUE(c.Min(d_unknown, d_0, &out).ok());
   1539   EXPECT_TRUE(SameHandle(d_0, out));
   1540   EXPECT_TRUE(c.Min(c.MakeDim(0ll), d_unknown, &out).ok());
   1541   EXPECT_EQ("0", c.DebugString(out));
   1542   EXPECT_TRUE(c.Min(d_unknown, 0ll, &out).ok());
   1543   EXPECT_EQ("0", c.DebugString(out));
   1544 
   1545   // Minimum involving unknowns and non-zeros gives new unknown.
   1546   EXPECT_TRUE(c.Min(d_unknown, d_unknown, &out).ok());
   1547   EXPECT_EQ("?", c.DebugString(out));
   1548   EXPECT_TRUE(c.Min(d_unknown, 1, &out).ok());
   1549   EXPECT_EQ("?", c.DebugString(out));
   1550   EXPECT_TRUE(c.Min(d_1, d_unknown, &out).ok());
   1551   EXPECT_EQ("?", c.DebugString(out));
   1552 
   1553   // Minimum with constant second arg.
          // When the first operand is already the minimum (or equal), the same
          // handle is expected back; when the constant wins, the result is
          // compared by value through DebugString.
   1554   EXPECT_TRUE(c.Min(d_1, 1, &out).ok());
   1555   EXPECT_TRUE(SameHandle(d_1, out));
   1556   EXPECT_TRUE(c.Min(d_1, 3, &out).ok());
   1557   EXPECT_TRUE(SameHandle(d_1, out));
   1558   EXPECT_TRUE(c.Min(d_2, 1, &out).ok());
   1559   EXPECT_EQ("1", c.DebugString(out));
   1560 
   1561   // Minimum with two dimensions.
          // All four orderings of d_1/d_2; the smaller operand's handle is
          // expected back each time.
   1562   EXPECT_TRUE(c.Min(d_1, d_1, &out).ok());
   1563   EXPECT_TRUE(SameHandle(d_1, out));
   1564   EXPECT_TRUE(c.Min(d_1, d_2, &out).ok());
   1565   EXPECT_TRUE(SameHandle(d_1, out));
   1566   EXPECT_TRUE(c.Min(d_2, d_1, &out).ok());
   1567   EXPECT_TRUE(SameHandle(d_1, out));
   1568   EXPECT_TRUE(c.Min(d_2, d_2, &out).ok());
   1569   EXPECT_TRUE(SameHandle(d_2, out));
   1570 }
   1571 
   1572 TEST_F(ShapeInferenceTest, Max) {
          // Exercises InferenceContext::Max on the dims of a single input shape
          // built from {1, 2, -1}, with the second operand given either as a
          // DimensionHandle or as an integer constant.
   1573   NodeDef def;
   1574   InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, -1})}, {}, {},
   1575                      {});
   1576 
          // One handle per dim of input 0, named after the value it was built
          // from (index 2, built from -1, is the one used as "unknown").
   1577   auto s = c.input(0);
   1578   auto d_1 = c.Dim(s, 0);
   1579   auto d_2 = c.Dim(s, 1);
   1580   auto d_unknown = c.Dim(s, 2);
   1581 
   1582   // Maximum involving unknowns gives new unknown.
   1583   DimensionHandle out;
   1584   EXPECT_TRUE(c.Max(d_unknown, d_unknown, &out).ok());
   1585   EXPECT_EQ("?", c.DebugString(out));
   1586   EXPECT_TRUE(c.Max(d_unknown, 1, &out).ok());
   1587   EXPECT_EQ("?", c.DebugString(out));
   1588   EXPECT_TRUE(c.Max(d_1, d_unknown, &out).ok());
   1589   EXPECT_EQ("?", c.DebugString(out));
   1590 
   1591   // Maximum with constant second arg.
          // When the first operand is already the maximum (or equal), the same
          // handle is expected back; when the constant wins, the result is
          // compared by value through DebugString.
   1592   EXPECT_TRUE(c.Max(d_1, 1, &out).ok());
   1593   EXPECT_TRUE(SameHandle(d_1, out));
   1594   EXPECT_TRUE(c.Max(d_2, 1, &out).ok());
   1595   EXPECT_TRUE(SameHandle(d_2, out));
   1596   EXPECT_TRUE(c.Max(d_2, 3, &out).ok());
   1597   EXPECT_EQ("3", c.DebugString(out));
   1598 
   1599   // Maximum with two dimensions.
          // All four orderings of d_1/d_2; the larger operand's handle is
          // expected back each time.
   1600   EXPECT_TRUE(c.Max(d_1, d_1, &out).ok());
   1601   EXPECT_TRUE(SameHandle(d_1, out));
   1602   EXPECT_TRUE(c.Max(d_1, d_2, &out).ok());
   1603   EXPECT_TRUE(SameHandle(d_2, out));
   1604   EXPECT_TRUE(c.Max(d_2, d_1, &out).ok());
   1605   EXPECT_TRUE(SameHandle(d_2, out));
   1606   EXPECT_TRUE(c.Max(d_2, d_2, &out).ok());
   1607   EXPECT_TRUE(SameHandle(d_2, out));
   1608 }
   1609 
   1610 void ShapeInferenceTest::TestMergeHandles(bool input_not_output) {
          // Shared scenario for the handle shape/type merge tests.
          //
          // input_not_output: when true the scenario drives the context's
          // input-handle accessor and MergeInputHandleShapesAndTypes; when false
          // it drives the output-handle counterparts. Everything else is
          // identical.
          //
          // The steps are stateful: each merge acts on what the previous steps
          // left stored at index 0, so their order matters.
   1611   NodeDef def;
   1612   InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({}), S({})}, {}, {},
   1613                      {});
          // Converts a list of dim sizes into a ShapeHandle through `c`; the
          // conversion status is guarded by TF_CHECK_OK.
   1614   auto make_shape = [&c](std::initializer_list<int64> dim_sizes) {
   1615     ShapeHandle s;
   1616     TF_CHECK_OK(c.MakeShapeFromPartialTensorShape(S(dim_sizes), &s));
   1617     return s;
   1618   };
          // Reads the stored handle shapes/types for `idx` from the side selected
          // by input_not_output. The result is compared with nullptr and
          // dereferenced below.
   1619   auto get_shapes_and_types_from_context = [&](int idx) {
   1620     if (input_not_output) {
   1621       return c.input_handle_shapes_and_types(idx);
   1622     } else {
   1623       return c.output_handle_shapes_and_types(idx);
   1624     }
   1625   };
          // Merges `shapes_and_types` into index `idx` on the selected side and
          // forwards the merge call's result, which the assertions below treat
          // as a boolean.
   1626   auto merge_shapes_and_types_to_context =
   1627       [&](int idx, const std::vector<ShapeAndType>& shapes_and_types) {
   1628         if (input_not_output) {
   1629           return c.MergeInputHandleShapesAndTypes(idx, shapes_and_types);
   1630         } else {
   1631           return c.MergeOutputHandleShapesAndTypes(idx, shapes_and_types);
   1632         }
   1633       };
   1634 
          // Nothing has been merged yet, so neither index has handle data.
   1635   EXPECT_TRUE(get_shapes_and_types_from_context(0) == nullptr);
   1636   EXPECT_TRUE(get_shapes_and_types_from_context(1) == nullptr);
   1637 
   1638   // First merge will take the input completely.
          // `t` is the reference the stored values are compared against for the
          // rest of the scenario; entry 1 starts as unknown shape / DT_INVALID so
          // it can be refined later.
   1639   std::vector<ShapeAndType> t{{make_shape({1, 2, 3}), DT_FLOAT},
   1640                               {c.UnknownShape(), DT_INVALID},
   1641                               {make_shape({4, 3, 2, 1}), DT_INT32}};
   1642   ASSERT_TRUE(merge_shapes_and_types_to_context(0, t));
   1643   ASSERT_TRUE(get_shapes_and_types_from_context(0) != nullptr);
   1644   std::vector<ShapeAndType> v = *get_shapes_and_types_from_context(0);
   1645   ASSERT_EQ(3, v.size());
          // NOTE(review): `i` is int while v.size() is unsigned, so this
          // comparison mixes signedness (same in every loop below).
   1646   for (int i = 0; i < v.size(); ++i) {
   1647     EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
   1648     EXPECT_EQ(t[i].dtype, v[i].dtype);
   1649   }
   1650 
   1651   // Merge that fails because wrong number of values passed.
   1652   // Fails, and no changes made.
          // One entry is passed where three are stored.
   1653   ASSERT_FALSE(merge_shapes_and_types_to_context(
   1654       0, std::vector<ShapeAndType>{{make_shape({1, 2, 3}), DT_FLOAT}}));
   1655   v = *get_shapes_and_types_from_context(0);
   1656   ASSERT_EQ(3, v.size());
   1657   for (int i = 0; i < v.size(); ++i) {
   1658     EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
   1659     EXPECT_EQ(t[i].dtype, v[i].dtype);
   1660   }
   1661 
   1662   // Only difference is in a mismatched shape. That is ignored,
   1663   // and there are no other changes, so nothing is done.
   1664   //
   1665   // TODO(cwhipkey): in mismatch cases, change Merge*HandleShapesAndTypes to
   1666   // return an error (separate error from 'refined' output)?
          // t2 differs from t only in entry 2's third dim (4 instead of 2).
   1667   auto t2 = t;
   1668   t2[2].shape = make_shape({4, 3, 4, 1});
   1669   ASSERT_FALSE(merge_shapes_and_types_to_context(0, t2));
   1670   v = *get_shapes_and_types_from_context(0);
   1671   ASSERT_EQ(3, v.size());
   1672   for (int i = 0; i < v.size(); ++i) {
   1673     EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
   1674     EXPECT_EQ(t[i].dtype, v[i].dtype);
   1675   }
   1676 
   1677   // Only difference is in a mismatched dtype, but that cannot be
   1678   // updated unless original dtype is DT_INVALID.
          // Entry 2 is stored as DT_INT32, so offering DT_FLOAT must not change it.
   1679   t2 = t;
   1680   t2[2].dtype = DT_FLOAT;
   1681   ASSERT_FALSE(merge_shapes_and_types_to_context(0, t2));
   1682   v = *get_shapes_and_types_from_context(0);
   1683   ASSERT_EQ(3, v.size());
   1684   for (int i = 0; i < v.size(); ++i) {
   1685     EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
   1686     EXPECT_EQ(t[i].dtype, v[i].dtype);
   1687   }
   1688 
   1689   // Difference is mergeable (new shape).
          // From here on `t` itself is modified, so it stays the expected value:
          // entry 1 goes from unknown shape to [1,10].
   1690   t[1].shape = make_shape({1, 10});
   1691   ASSERT_TRUE(merge_shapes_and_types_to_context(0, t));
   1692   v = *get_shapes_and_types_from_context(0);
   1693   ASSERT_EQ(3, v.size());
   1694   for (int i = 0; i < v.size(); ++i) {
   1695     EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
   1696     EXPECT_EQ(t[i].dtype, v[i].dtype);
   1697   }
   1698 
   1699   // Difference is mergeable (new type).
          // Entry 1 goes from DT_INVALID to DT_DOUBLE.
   1700   t[1].dtype = DT_DOUBLE;
   1701   ASSERT_TRUE(merge_shapes_and_types_to_context(0, t));
   1702   v = *get_shapes_and_types_from_context(0);
   1703   ASSERT_EQ(3, v.size());
   1704   for (int i = 0; i < v.size(); ++i) {
   1705     EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
   1706     EXPECT_EQ(t[i].dtype, v[i].dtype);
   1707   }
   1708 
   1709   // No difference.
          // Merging the same values again is expected to return false.
   1710   ASSERT_FALSE(merge_shapes_and_types_to_context(0, t));
   1711 }
   1712 
   1713 TEST_F(ShapeInferenceTest, MergeInputHandleShapesAndTypes) {
          // Runs the shared merge scenario with input_not_output = true.
   1714   TestMergeHandles(true /* input_not_output */);
   1715 }
   1716 
   1717 TEST_F(ShapeInferenceTest, MergeOutputHandleShapesAndTypes) {
          // Runs the shared merge scenario with input_not_output = false.
   1718   TestMergeHandles(false /* input_not_output */);
   1719 }
   1720 
   1721 void ShapeInferenceTest::TestRelaxHandles(bool input_not_output) {
          // Shared scenario for the handle shape-relax / type-merge tests.
          //
          // input_not_output: when true the scenario drives the context's
          // input-handle accessor and RelaxInputHandleShapesAndMergeTypes; when
          // false it drives the output-handle counterparts. Everything else is
          // identical.
          //
          // The steps are stateful: each relax acts on what the previous steps
          // left stored at index 0, so their order matters.
   1722   NodeDef def;
   1723   InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({}), S({})}, {}, {},
   1724                      {});
          // Converts a list of dim sizes into a ShapeHandle through `c`; the
          // conversion status is guarded by TF_CHECK_OK.
   1725   auto make_shape = [&c](std::initializer_list<int64> dim_sizes) {
   1726     ShapeHandle s;
   1727     TF_CHECK_OK(c.MakeShapeFromPartialTensorShape(S(dim_sizes), &s));
   1728     return s;
   1729   };
          // Reads the stored handle shapes/types for `idx` from the side selected
          // by input_not_output. The result is compared with nullptr and
          // dereferenced below.
   1730   auto get_shapes_and_types_from_context = [&](int idx) {
   1731     if (input_not_output) {
   1732       return c.input_handle_shapes_and_types(idx);
   1733     } else {
   1734       return c.output_handle_shapes_and_types(idx);
   1735     }
   1736   };
          // Relaxes `shapes_and_types` into index `idx` on the selected side and
          // forwards the call's result, which the assertions below treat as a
          // boolean.
   1737   auto relax_shapes_and_types_to_context =
   1738       [&](int idx, const std::vector<ShapeAndType>& shapes_and_types) {
   1739         if (input_not_output) {
   1740           return c.RelaxInputHandleShapesAndMergeTypes(idx, shapes_and_types);
   1741         } else {
   1742           return c.RelaxOutputHandleShapesAndMergeTypes(idx, shapes_and_types);
   1743         }
   1744       };
   1745 
          // Nothing has been relaxed in yet, so neither index has handle data.
   1746   EXPECT_TRUE(get_shapes_and_types_from_context(0) == nullptr);
   1747   EXPECT_TRUE(get_shapes_and_types_from_context(1) == nullptr);
   1748 
   1749   // First relax will take the input completely.
          // `t` is the reference for the dtype comparisons in the rest of the
          // scenario; entry 1 starts as unknown shape / DT_INVALID.
   1750   std::vector<ShapeAndType> t{{make_shape({1, 2, 3}), DT_FLOAT},
   1751                               {c.UnknownShape(), DT_INVALID},
   1752                               {make_shape({4, 3, 2, 1}), DT_INT32}};
   1753   ASSERT_TRUE(relax_shapes_and_types_to_context(0, t));
   1754   ASSERT_TRUE(get_shapes_and_types_from_context(0) != nullptr);
   1755   std::vector<ShapeAndType> v = *get_shapes_and_types_from_context(0);
   1756   ASSERT_EQ(3, v.size());
          // NOTE(review): `i` is int while v.size() is unsigned, so this
          // comparison mixes signedness (same in every loop below).
   1757   for (int i = 0; i < v.size(); ++i) {
   1758     EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
   1759     EXPECT_EQ(t[i].dtype, v[i].dtype);
   1760   }
   1761 
   1762   // Relax that fails because wrong number of values passed.
   1763   // Fails, and no changes made.
          // One entry is passed where three are stored.
   1764   ASSERT_FALSE(relax_shapes_and_types_to_context(
   1765       0, std::vector<ShapeAndType>{{make_shape({1, 2, 3}), DT_FLOAT}}));
   1766   v = *get_shapes_and_types_from_context(0);
   1767   ASSERT_EQ(3, v.size());
   1768   for (int i = 0; i < v.size(); ++i) {
   1769     EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
   1770     EXPECT_EQ(t[i].dtype, v[i].dtype);
   1771   }
   1772 
   1773   // Only difference is in a mismatched shape. This should replace
   1774   // the mismatched dimension with an UnknownDim.
          // t2 differs from t only in entry 2's third dim (4 instead of 2).
   1775   auto t2 = t;
   1776   t2[2].shape = make_shape({4, 3, 4, 1});
   1777   ASSERT_TRUE(relax_shapes_and_types_to_context(0, t2));
   1778   v = *get_shapes_and_types_from_context(0);
          // NOTE(review): unlike the other sections, v[2] is indexed here without
          // a preceding ASSERT_EQ(3, v.size()) on the freshly read vector.
   1779   EXPECT_EQ("[4,3,?,1]", c.DebugString(v[2].shape));
          // Only dtypes are compared from here on; shapes are checked separately
          // where a section cares about them.
   1780   for (int i = 0; i < v.size(); ++i) {
   1781     EXPECT_EQ(t[i].dtype, v[i].dtype);
   1782   }
   1783 
   1784   // Only difference is in a mismatched dtype, but that cannot be
   1785   // updated unless original dtype is DT_INVALID.
          // Entry 2 is stored as DT_INT32, so offering DT_FLOAT must not change it.
   1786   t2 = t;
   1787   t2[2].dtype = DT_FLOAT;
   1788   ASSERT_FALSE(relax_shapes_and_types_to_context(0, t2));
   1789   v = *get_shapes_and_types_from_context(0);
   1790   ASSERT_EQ(3, v.size());
   1791   for (int i = 0; i < v.size(); ++i) {
   1792     EXPECT_EQ(t[i].dtype, v[i].dtype);
   1793   }
   1794 
   1795   // Difference is a new shape, which will result in a new UnknownShape.
          // Entry 1 is stored with an unknown shape and [1,10] is offered; the
          // stored shape is expected to print as "?" and to be a different handle
          // from the one offered.
   1796   t[1].shape = make_shape({1, 10});
   1797   ASSERT_TRUE(relax_shapes_and_types_to_context(0, t));
   1798   v = *get_shapes_and_types_from_context(0);
   1799   ASSERT_EQ(3, v.size());
   1800   EXPECT_FALSE(SameHandle(t[1].shape, v[1].shape));
   1801   EXPECT_EQ("?", c.DebugString(v[1].shape));
   1802   for (int i = 0; i < v.size(); ++i) {
   1803     EXPECT_EQ(t[i].dtype, v[i].dtype);
   1804   }
   1805 
   1806   // Difference is relaxable (new type).
          // Entry 1 goes from DT_INVALID to DT_DOUBLE.
   1807   t[1].dtype = DT_DOUBLE;
   1808   ASSERT_TRUE(relax_shapes_and_types_to_context(0, t));
   1809   v = *get_shapes_and_types_from_context(0);
          // NOTE(review): v[1] is indexed without a preceding size assertion on
          // the freshly read vector.
   1810   EXPECT_EQ(t[1].dtype, v[1].dtype);
   1811 }
   1812 
   1813 TEST_F(ShapeInferenceTest, RelaxInputHandleShapesAndTypes) {
          // Runs the shared relax scenario with input_not_output = true.
   1814   TestRelaxHandles(true /* input_not_output */);
   1815 }
   1816 
   1817 TEST_F(ShapeInferenceTest, RelaxOutputHandleShapesAndTypes) {
          // Runs the shared relax scenario with input_not_output = false.
   1818   TestRelaxHandles(false /* input_not_output */);
   1819 }
   1820 
   1821 }  // namespace shape_inference
   1822 }  // namespace tensorflow
   1823