Home | History | Annotate | Download | only in ops
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
     15 
     16 #include "tensorflow/core/framework/node_def_builder.h"
     17 #include "tensorflow/core/framework/op.h"
     18 #include "tensorflow/core/framework/shape_inference_testutil.h"
     19 #include "tensorflow/core/framework/tensor.h"
     20 #include "tensorflow/core/framework/tensor_shape.pb.h"
     21 #include "tensorflow/core/framework/tensor_testutil.h"
     22 #include "tensorflow/core/lib/core/status_test_util.h"
     23 #include "tensorflow/core/lib/strings/str_util.h"
     24 #include "tensorflow/core/platform/test.h"
     25 
     26 namespace tensorflow {
     27 
     28 TEST(MathOpsTest, AddN_ShapeFn) {
     29   ShapeInferenceTestOp op("AddN");
     30   auto set_n = [&op](int n) {
     31     std::vector<NodeDefBuilder::NodeOut> src_list;
     32     src_list.reserve(n);
     33     for (int i = 0; i < n; ++i) src_list.emplace_back("a", 0, DT_FLOAT);
     34     TF_ASSERT_OK(NodeDefBuilder("test", "AddN")
     35                      .Input(src_list)
     36                      .Attr("N", n)
     37                      .Finalize(&op.node_def));
     38   };
     39 
     40   set_n(2);
     41   // Adding two unknowns returns either input.
     42   INFER_OK(op, "?;?", "in0|in1");
     43 
     44   // known+unknown returns the known input.
     45   INFER_OK(op, "[1];[?]", "in0");
     46   INFER_OK(op, "[1];?", "in0");
     47   INFER_OK(op, "[?];[1]", "in1");
     48   INFER_OK(op, "?;[1]", "in1");
     49 
     50   set_n(2);
     51   INFER_OK(op, "[1,2];[?,2]", "in0");
     52   INFER_OK(op, "[1,2];[1,2]", "in0|in1");
     53   INFER_OK(op, "[?,2];[1,2]", "in1");
     54 
     55   set_n(3);
     56   INFER_OK(op, "[1,?];[?,2];[1,2]", "in2");
     57   INFER_OK(op, "[1,2];[?,2];[1,?]", "in0");
     58   INFER_OK(op, "?;?;[1,2]", "in2");
     59 
     60   set_n(2);
     61   INFER_OK(op, "?;[1,2]", "in1");
     62   INFER_OK(op, "[1,?];[?,2]", "[d0_0,d1_1]");
     63   INFER_OK(op, "[?,2,?];[?,?,3]", "[d0_0|d1_0,d0_1,d1_2]");
     64   INFER_OK(op, "[?,2];[1,?]", "[d1_0,d0_1]");
     65 
     66   set_n(3);
     67   INFER_ERROR("Dimension 1 in both shapes must be equal, but are 2 and 4", op,
     68               "[1,2];?;[1,4]");
     69   INFER_ERROR("From merging shape 0 with other shapes.", op, "[1,2];?;[1,4]");
     70   set_n(4);
     71   INFER_ERROR("Shapes must be equal rank, but are 2 and 3", op,
     72               "?;[1,2];?;[1,2,3]");
     73   INFER_ERROR("From merging shape 1 with other shapes.", op,
     74               "?;[1,2];?;[1,2,3]");
     75 }
     76 
     77 TEST(MathOpsTest, UnchangedShape_ShapeFn) {
     78   ShapeInferenceTestOp op("Cast");
     79   INFER_OK(op, "?", "in0");
     80   INFER_OK(op, "[?]", "in0");
     81   INFER_OK(op, "[1,?,3,4]", "in0");
     82 }
     83 
     84 TEST(MathOpsTest, Segment_ShapeFn) {
     85   // Tests SegmentReductionShapeFn.
     86   for (const auto* op_name : {"SegmentMax", "SegmentMean", "SegmentMin",
     87                               "SegmentProd", "SegmentSum"}) {
     88     ShapeInferenceTestOp op(op_name);
     89     INFER_OK(op, "?;?", "?");
     90     INFER_OK(op, "?;[100]", "?");
     91 
     92     // Data shape with single dimension.
     93     INFER_OK(op, "[?];?", "[?]");
     94     INFER_OK(op, "[?];[100]", "[?]");
     95     INFER_OK(op, "[1];?", "[?]");
     96     INFER_OK(op, "[1];[100]", "[?]");
     97 
     98     // Data shape with multiple dimensions.
     99     INFER_OK(op, "[?,?];?", "[?,d0_1]");
    100     INFER_OK(op, "[?,2];[100]", "[?,d0_1]");
    101     INFER_OK(op, "[?,2,?,4];[100]", "[?,d0_1,d0_2,d0_3]");
    102     INFER_OK(op, "[1,?];?", "[?,d0_1]");
    103     INFER_OK(op, "[1,2];[100]", "[?,d0_1]");
    104     INFER_OK(op, "[1,2,?,4];[100]", "[?,d0_1,d0_2,d0_3]");
    105 
    106     // Error cases.
    107     INFER_ERROR("Shape must be rank 1 but is rank 2", op, "?;[1,2]");
    108     INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "[];[1]");
    109   }
    110 }
    111 
    112 TEST(MathOpsTest, BroadcastBinaryOps_ShapeFn) {
    113   for (const auto* op_name : {"Add",        "Complex",
    114                               "Div",        "Equal",
    115                               "Greater",    "GreaterEqual",
    116                               "Igamma",     "Igammac",
    117                               "Zeta",       "Polygamma",
    118                               "Less",       "LessEqual",
    119                               "LogicalAnd", "LogicalOr",
    120                               "Maximum",    "Minimum",
    121                               "Mod",        "Mul",
    122                               "NotEqual",   "Pow",
    123                               "Sub",        "SquaredDifference",
    124                               "DivNoNan"}) {
    125     ShapeInferenceTestOp op(op_name);
    126     INFER_OK(op, "?;?", "?");
    127     INFER_OK(op, "[1,2];?", "?");
    128     INFER_OK(op, "?;[1,2]", "?");
    129 
    130     INFER_OK(op, "[?];[1]", "[d0_0]");
    131     INFER_OK(op, "[1];[?]", "[d1_0]");
    132     INFER_OK(op, "[?];[2]", "[d1_0]");
    133     INFER_OK(op, "[2];[?]", "[d0_0]");
    134     INFER_OK(op, "[?];[?]", "[?]");
    135     INFER_OK(op, "[];[?]", "[d1_0]");
    136     INFER_OK(op, "[?];[]", "[d0_0]");
    137 
    138     INFER_OK(op, "[1];[1]", "[d0_0|d1_0]");
    139     INFER_OK(op, "[];[1]", "[d1_0]");
    140     INFER_OK(op, "[1];[]", "[d0_0]");
    141 
    142     INFER_OK(op, "[2];[2]", "[d0_0|d1_0]");
    143     INFER_OK(op, "[];[2]", "[d1_0]");
    144     INFER_OK(op, "[1];[2]", "[d1_0]");
    145     INFER_OK(op, "[2];[1]", "[d0_0]");
    146     INFER_OK(op, "[2];[]", "[d0_0]");
    147     INFER_OK(op, "[2];[?]", "[d0_0]");
    148 
    149     INFER_OK(op, "[0];[0]", "[d0_0|d1_0]");
    150     INFER_OK(op, "[];[0]", "[d1_0]");
    151     INFER_OK(op, "[1];[0]", "[d1_0]");
    152     INFER_OK(op, "[0];[1]", "[d0_0]");
    153     INFER_OK(op, "[0];[]", "[d0_0]");
    154 
    155     INFER_OK(op, "[2];[?,?]", "[d1_0,d0_0]");
    156     INFER_OK(op, "[2,2];[?,?,?]", "[d1_0,d0_0,d0_1]");
    157 
    158     // Multiple dimension cases (same test cases, switching x and y).
    159     INFER_OK(op, "[?,1,2,3,4,5];[3,1,?]",
    160              "[d0_0,d0_1,d0_2,d0_3|d1_0,d0_4,d0_5]");
    161     INFER_OK(op, "[3,1,?];[?,1,2,3,4,5]",
    162              "[d1_0,d1_1,d1_2,d1_3|d0_0,d1_4,d1_5]");
    163   }
    164 }
    165 
    166 TEST(MathOpsTest, Select_ShapeFn) {
    167   ShapeInferenceTestOp op("Select");
    168   INFER_OK(op, "?;?;?", "in1|in2");
    169 
    170   // scalar case
    171   INFER_OK(op, "[];[1];?", "in1");
    172   INFER_OK(op, "[];?;?", "in1|in2");
    173 
    174   INFER_OK(op, "[1];?;?",
    175            "in1|in2");  // When cond is vector, t/e may not match it.
    176   INFER_OK(op, "[1,2];?;?", "in1|in2?");
    177 
    178   INFER_OK(op, "?;[];?", "in1");
    179   INFER_OK(op, "?;?;[]", "in2");
    180   INFER_OK(op, "?;[1];?", "in1");
    181   INFER_OK(op, "?;?;[1]", "in2");
    182   INFER_OK(op, "?;[1,2];?", "in1");
    183   INFER_OK(op, "?;?;[1,2]", "in2");
    184 
    185   INFER_ERROR("Shapes must be equal rank, but are 0 and 1", op, "[1];[];?");
    186   INFER_ERROR("Shapes must be equal rank, but are 1 and 2", op, "[];[1];[1,2]");
    187   INFER_ERROR("Shapes must be equal rank, but are 1 and 2", op, "[1,2];[1];?");
    188   INFER_OK(op, "[2];[?];[?]", "in1|in2");
    189 
    190   INFER_OK(op, "[?];[?,?,3];[1,2,?]", "[d2_0,d2_1,d1_2]");
    191   INFER_OK(op, "[2];[?,?,3];[?,2,?]", "[d1_0|d2_0,d2_1,d1_2]");
    192   INFER_ERROR("must be equal", op, "[1];[2,?,3];[?,2,?]");
    193   INFER_ERROR("Shapes must be equal rank, but are 3 and 2", op,
    194               "[2,?];[?,?,3];[?,2,?]");
    195   INFER_OK(op, "[2,?,?];[?,?,3];[?,2,?]", "[d0_0,d2_1,d1_2]");
    196   INFER_ERROR("Dimension 2 in both shapes must be equal, but are 3 and 5", op,
    197               "[2,?,5];[?,?,3];[?,2,?]");
    198 
    199   // Test that handles were merged.
    200   //
    201   // Tests below will modify handle_data and call run_inference_for_handles to
    202   // rerun shape inference, updating the context <c>.
    203   const OpRegistrationData* op_reg_data;
    204   TF_ASSERT_OK(OpRegistry::Global()->LookUp(op.name, &op_reg_data));
    205   typedef std::vector<std::pair<TensorShapeProto, DataType>> ShapeDtypeV;
    206   std::vector<std::unique_ptr<ShapeDtypeV>> handle_data;
    207   std::unique_ptr<shape_inference::InferenceContext> c;
    208   auto run_inference_for_handles = [&]() -> Status {
    209     CHECK(op_reg_data->shape_inference_fn != nullptr);
    210     c.reset(new shape_inference::InferenceContext(
    211         TF_GRAPH_DEF_VERSION, &op.node_def, op_reg_data->op_def,
    212         {TensorShapeProto(), TensorShapeProto(), TensorShapeProto()}, {}, {},
    213         handle_data));
    214     TF_CHECK_OK(c->construction_status());
    215     Status s = c->Run(op_reg_data->shape_inference_fn);
    216     LOG(INFO) << "Inference got " << s;
    217     return s;
    218   };
    219   auto shape_proto = [](std::initializer_list<int64> dim_sizes) {
    220     TensorShapeProto p;
    221     for (auto i : dim_sizes) p.add_dim()->set_size(i);
    222     return p;
    223   };
    224 
    225   TensorShapeProto i0 = shape_proto({1, -1});
    226   TensorShapeProto i1 = shape_proto({-1, 2});
    227   TensorShapeProto unknown_shape;
    228   unknown_shape.set_unknown_rank(true);
    229   TensorShapeProto scalar;
    230 
    231   handle_data.emplace_back(
    232       new ShapeDtypeV{{scalar, DT_FLOAT}, {unknown_shape, DT_INT32}});
    233   handle_data.emplace_back(new ShapeDtypeV{{i0, DT_FLOAT}, {i1, DT_INT32}});
    234   handle_data.emplace_back(
    235       new ShapeDtypeV{{i1, DT_FLOAT}, {unknown_shape, DT_INT32}});
    236 
    237   TF_ASSERT_OK(run_inference_for_handles());
    238   auto* out = c->output_handle_shapes_and_types(0);
    239   ASSERT_EQ(2, out->size());
    240   EXPECT_EQ("[1,2]", c->DebugString(out->at(0).shape));
    241   EXPECT_EQ(DT_FLOAT, out->at(0).dtype);
    242   EXPECT_EQ("[?,2]", c->DebugString(out->at(1).shape));
    243   EXPECT_EQ(DT_INT32, out->at(1).dtype);
    244 
    245   // Expect an error when the shapes can't be merged.
    246   handle_data[2]->at(0).first = shape_proto({2, 2});
    247   EXPECT_TRUE(str_util::StrContains(run_inference_for_handles().error_message(),
    248                                     "must be equal, but are 1 and 2"));
    249   handle_data[2]->at(0).first = i1;  // restore to valid
    250 
    251   // Expect an error when the types can't be merged.
    252   handle_data[2]->at(1).second = DT_INT64;
    253   EXPECT_TRUE(str_util::StrContains(run_inference_for_handles().error_message(),
    254                                     "pointing to different dtypes"));
    255   handle_data[2]->at(1).second = DT_INT32;  // restore to valid
    256 
    257   // Expect an error when different numbers of tensors are merged.
    258   handle_data[2]->push_back({i1, DT_FLOAT});
    259   EXPECT_TRUE(
    260       str_util::StrContains(run_inference_for_handles().error_message(),
    261                             "pointing to different numbers of tensors"));
    262   handle_data[2]->pop_back();  // restore to valid.
    263 }
    264 
    265 TEST(MathOpsTest, Range_ShapeFn) {
    266   ShapeInferenceTestOp op("Range");
    267 
    268   TF_ASSERT_OK(NodeDefBuilder("test", "Range")
    269                    .Input({"start", {}, DT_INT32})
    270                    .Input({"limit", {}, DT_INT32})
    271                    .Input({"delta", {}, DT_INT32})
    272                    .Attr("Tidx", DT_INT32)
    273                    .Finalize(&op.node_def));
    274 
    275   op.input_tensors.resize(3);
    276   INFER_OK(op, "?;?;?", "[?]");
    277   INFER_ERROR("Shape must be rank 0 but is rank 2", op, "[1,2];?;?");
    278   INFER_ERROR("for 'start'", op, "[1,2];?;?");
    279 
    280   INFER_ERROR("Shape must be rank 0 but is rank 2", op, "?;[1,2];?");
    281   INFER_ERROR("for 'limit'", op, "?;[1,2];?");
    282 
    283   INFER_ERROR("Shape must be rank 0 but is rank 2", op, "?;?;[1,2]");
    284   INFER_ERROR("for 'delta'", op, "?;?;[1,2]");
    285 
    286   Tensor start_t = test::AsScalar(1);
    287   op.input_tensors[0] = &start_t;
    288   INFER_OK(op, "?;?;?", "[?]");
    289   Tensor limit_t = test::AsScalar(1);
    290   op.input_tensors[1] = &limit_t;
    291   INFER_OK(op, "?;?;?", "[?]");
    292 
    293   Tensor delta_t = test::AsScalar(1);
    294   op.input_tensors[2] = &delta_t;
    295   INFER_OK(op, "?;?;?", "[0]");
    296 
    297   delta_t = test::AsScalar(0);
    298   INFER_ERROR("Requires delta != 0", op, "?;?;?");
    299   delta_t = test::AsScalar(3);
    300 
    301   limit_t = test::AsScalar(-1);
    302   INFER_ERROR("Requires start <= limit when delta > 0: 1/-1", op, "?;?;?");
    303 
    304   delta_t = test::AsScalar(-1);
    305   INFER_OK(op, "?;?;?", "[2]");
    306 
    307   limit_t = test::AsScalar(4);
    308   INFER_ERROR("Requires start >= limit when delta < 0: 1/4", op, "?;?;?");
    309 
    310   limit_t = test::AsScalar(100);
    311   start_t = test::AsScalar(2);
    312   delta_t = test::AsScalar(3);
    313   INFER_OK(op, "?;?;?", "[33]");
    314 }
    315 
    316 TEST(MathOpsTest, LinSpace_ShapeFn) {
    317   ShapeInferenceTestOp op("LinSpace");
    318   op.input_tensors.resize(3);
    319   INFER_OK(op, "?;?;?", "[?]");
    320   INFER_ERROR("Shape must be rank 0 but is rank 2", op, "[1,2];?;?");
    321   INFER_ERROR("for 'start'", op, "[1,2];?;?");
    322   INFER_ERROR("Shape must be rank 0 but is rank 2", op, "?;[1,2];?");
    323   INFER_ERROR("for 'stop'", op, "?;[1,2];?");
    324   INFER_ERROR("Shape must be rank 0 but is rank 2", op, "?;?;[1,2]");
    325   INFER_ERROR("for 'num'", op, "?;?;[1,2]");
    326 
    327   Tensor num_t = test::AsScalar(1);
    328   op.input_tensors[2] = &num_t;
    329   INFER_OK(op, "?;?;?", "[1]");
    330   num_t = test::AsScalar(2);
    331   INFER_OK(op, "?;?;?", "[2]");
    332   num_t = test::AsScalar(-1);
    333   INFER_ERROR("Requires num > 0: -1", op, "?;?;?");
    334 }
    335 
    336 TEST(MathOpsTest, UnsortedSegmentSum_ShapeFn) {
    337   ShapeInferenceTestOp op("UnsortedSegmentSum");
    338   op.input_tensors.resize(3);
    339   INFER_OK(op, "?;?;?", "?");
    340   INFER_OK(op, "?;[?];?", "?");
    341   INFER_ERROR("Shape must be rank 0 but is rank 2", op, "?;?;[1,2]");
    342   INFER_ERROR("Dimensions must be equal, but are 2 and 3", op,
    343               "[1,?,2];[1,?,3];?");
    344   INFER_OK(op, "?;[3];?", "?");
    345   INFER_ERROR("Shape must be at least rank 3 but is rank 2", op,
    346               "[1,2];[1,2,3];?");
    347 
    348   Tensor num_segments_t = test::AsScalar(100);
    349   op.input_tensors[2] = &num_segments_t;
    350   INFER_OK(op, "[?,2,3,?,5];[1,2,?];[]", "[100,d0_3,d0_4]");
    351 
    352   num_segments_t = test::AsScalar(-1);
    353   INFER_ERROR(("Dimension size, given by scalar input 2, must be "
    354                "non-negative but is -1"),
    355               op, "[3];[3];?");
    356 }
    357 
    358 TEST(MathOpsTest, SparseSegment_ShapeFn) {
    359   ShapeInferenceTestOp op("SparseSegmentSum");
    360   op.input_tensors.resize(3);
    361   INFER_OK(op, "?;?;?", "?");
    362   INFER_OK(op, "[2,4,3];[3];[3]", "[?,d0_1,d0_2]");
    363 
    364   INFER_ERROR("Shape must be rank 1 but is rank 0", op, "[2,4,3];[];[3]");
    365   INFER_ERROR("Shape must be rank 1 but is rank 2", op, "[2,4,3];[3];[3,4]");
    366 
    367   INFER_ERROR("Dimension 0 in both shapes must be equal, but are 3 and 4", op,
    368               "[2,4,3];[3];[4]");
    369 }
    370 
    371 TEST(MathOpsTest, SparseSegmentGrad_ShapeFn) {
    372   ShapeInferenceTestOp op("SparseSegmentMeanGrad");
    373   op.input_tensors.resize(4);
    374   INFER_OK(op, "?;?;?;?", "?");
    375   INFER_OK(op, "[2,4,3];[3];[3];[]", "[?,d0_1,d0_2]");
    376 
    377   Tensor num_segments_t = test::AsScalar(100);
    378   op.input_tensors[3] = &num_segments_t;
    379   INFER_OK(op, "[2,4,3];[3];[3];[]", "[100,d0_1,d0_2]");
    380 
    381   INFER_ERROR("Shape must be rank 0 but is rank 2", op,
    382               "[2,4,3];[3];[3];[1,1]");
    383 
    384   // Negative value is not allowed
    385   num_segments_t = test::AsScalar(-100);
    386   op.input_tensors[3] = &num_segments_t;
    387   INFER_ERROR("Cannot specify a negative value", op, "[2,4,3];[3];[3];[]");
    388 }
    389 
    390 TEST(MathOpsTest, BatchMatMul_ShapeFn) {
    391   ShapeInferenceTestOp op("BatchMatMul");
    392   auto set_adj = [&op](bool adj_x, bool adj_y) {
    393     TF_ASSERT_OK(NodeDefBuilder("test", "BatchMatMul")
    394                      .Input({"a", 0, DT_FLOAT})
    395                      .Input({"b", 0, DT_FLOAT})
    396                      .Attr("adj_x", adj_x)
    397                      .Attr("adj_y", adj_y)
    398                      .Finalize(&op.node_def));
    399   };
    400 
    401   set_adj(false, false);
    402 
    403   // Rank checks.
    404   INFER_ERROR("at least rank 2", op, "[1];?");
    405   INFER_ERROR("at least rank 2", op, "?;[2]");
    406 
    407   INFER_OK(op, "?;?", "?");
    408 
    409   // 0 batch dims.
    410   INFER_OK(op, "[?,?];[?,?]", "[d0_0,d1_1]");
    411 
    412   // 2 batch dims.
    413   INFER_OK(op, "[?,?,?,?];?", "[d0_0,d0_1,d0_2,?]");
    414 
    415   // Test adj_a, testing output and that inner dims are compared.
    416   set_adj(false, false);
    417   INFER_OK(op, "[1,2,3,4];[1,2,?,?]", "[d0_0,d0_1,d0_2,d1_3]");
    418   INFER_ERROR("are 2 and 3", op, "[?,1,2];[?,3,1]");  // inner dim mismatch
    419   set_adj(true, false);
    420   INFER_OK(op, "[1,2,3,4];[1,2,?,?]", "[d0_0,d0_1,d0_3,d1_3]");
    421   INFER_ERROR("are 2 and 3", op, "[?,2,1];[?,3,1]");  // inner dim mismatch
    422 
    423   // Test adj_b=true.
    424   set_adj(false, true);
    425   INFER_OK(op, "[1,2,?,?];[1,2,3,4]", "[d0_0,d0_1,d0_2,d1_2]");
    426   INFER_ERROR("are 2 and 3", op, "[?,1,2];[?,1,3]");  // inner dim mismatch
    427   set_adj(true, true);
    428   INFER_OK(op, "[1,2,?,?];[1,2,3,4]", "[d0_0,d0_1,d0_3,d1_2]");
    429   INFER_ERROR("are 2 and 3", op, "[?,2,1];[?,1,3]");  // inner dim mismatch
    430 }
    431 
    432 TEST(MathOpsTest, ArgOps_ShapeFn) {
    433   ShapeInferenceTestOp op("ArgMax");
    434   op.input_tensors.resize(2);
    435 
    436   INFER_OK(op, "?;?", "?");
    437 
    438   // input rank <= 1 produces scalar
    439   INFER_OK(op, "[2];?", "[]");
    440   INFER_OK(op, "[];?", "[]");
    441 
    442   // Incorrect rank for dimension
    443   INFER_ERROR("must be rank 0", op, "[2];[1]");
    444 
    445   // dimension not available, but input rank is.  Output is unknown
    446   // shape with rank one less than input rank.
    447   INFER_OK(op, "[2,3,4];?", "[?,?]");
    448   INFER_OK(op, "[2,3,4,5,6];?", "[?,?,?,?]");
    449 
    450   // Dimension values known
    451   Tensor dimension = test::AsScalar(0);
    452   op.input_tensors[1] = &dimension;
    453   INFER_OK(op, "[2,3,4];[]", "[d0_1,d0_2]");
    454 
    455   dimension = test::AsScalar(1);
    456   op.input_tensors[1] = &dimension;
    457   INFER_OK(op, "[2,3,4];[]", "[d0_0,d0_2]");
    458 
    459   dimension = test::AsScalar(2);
    460   op.input_tensors[1] = &dimension;
    461   INFER_OK(op, "[2,3,4];[]", "[d0_0,d0_1]");
    462 
    463   // Dimension value out of bounds
    464   dimension = test::AsScalar(10);
    465   op.input_tensors[1] = &dimension;
    466   INFER_ERROR("must be in the range [-3, 3)", op, "[2,3,4];[]");
    467 
    468   dimension = test::AsScalar(-10);
    469   op.input_tensors[1] = &dimension;
    470   INFER_ERROR("must be in the range [-3, 3)", op, "[2,3,4];[]");
    471 
    472   dimension = test::AsScalar(-1);
    473   op.input_tensors[1] = &dimension;
    474   INFER_OK(op, "[2,3,4];[]", "[d0_0,d0_1]");
    475 }
    476 
    477 TEST(MathOpsTest, Betainc_ShapeFn) {
    478   ShapeInferenceTestOp op("Betainc");
    479 
    480   INFER_OK(op, "?;?;?", "?");
    481   INFER_OK(op, "[?,?];?;?", "in0");
    482   INFER_OK(op, "[?,2];?;[1,?]", "[d2_0,d0_1]");
    483   INFER_OK(op, "[?,2,?];[1,?,?];[?,?,3]", "[d1_0,d0_1,d2_2]");
    484 
    485   INFER_OK(op, "[?,2,?];[];[?,?,3]", "[d0_0|d2_0,d0_1,d2_2]");
    486   INFER_OK(op, "[];[];[?,?,3]", "in2");
    487 
    488   // All but one is a scalar, so use it.
    489   INFER_OK(op, "[];[];?", "in2");
    490   INFER_OK(op, "[];[];[1,2,3,4]", "in2");
    491 
    492   // All scalar input; implementation picks in0.
    493   INFER_OK(op, "[];[];[]", "in0");
    494 
    495   // Non-scalars must match shape.
    496   INFER_ERROR("must be equal", op, "[1,2];[];[1,4]");
    497   INFER_ERROR("must be equal", op, "[1,2];[];[1,2,3]");
    498 }
    499 
    500 TEST(MathOpsTest, Requantize_ShapeFn) {
    501   ShapeInferenceTestOp op("Requantize");
    502 
    503   INFER_OK(op, "?;?;?;?;?", "in0;[];[]");
    504   INFER_OK(op, "?;[];[];[];[]", "in0;[];[]");
    505 
    506   // Rank checks on input scalars.
    507   INFER_ERROR("must be rank 0", op, "?;[1];?;?;?");
    508   INFER_ERROR("must be rank 0", op, "?;?;[2];?;?");
    509   INFER_ERROR("must be rank 0", op, "?;?;?;[3];?");
    510   INFER_ERROR("must be rank 0", op, "?;?;?;?;[4]");
    511 }
    512 
    513 TEST(MathOpstest, RequantizationRange_ShapeFn) {
    514   ShapeInferenceTestOp op("RequantizationRange");
    515 
    516   INFER_OK(op, "?;?;?", "[];[]");
    517   INFER_OK(op, "?;[];[]", "[];[]");
    518 
    519   // Rank checks on input scalars.
    520   INFER_ERROR("must be rank 0", op, "?;[1];?");
    521   INFER_ERROR("must be rank 0", op, "?;?;[2]");
    522 }
    523 
    524 TEST(MathOpsTest, Cross_ShapeFn) {
    525   ShapeInferenceTestOp op("Cross");
    526 
    527   INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "[];[]");
    528   INFER_ERROR("Dimension 0 in both shapes must be equal, but", op, "[3];[5]");
    529   INFER_ERROR("Dimension must be 3 but", op, "[3,5];[3,5]");
    530 
    531   INFER_OK(op, "?;?", "in0");
    532   INFER_OK(op, "[?];[?]", "in0");
    533   INFER_OK(op, "[1,?,3];[?,?,?]", "in0");
    534 }
    535 
    536 TEST(MathOpsTest, HistogramFixedWidth_ShapeFn) {
    537   ShapeInferenceTestOp op("HistogramFixedWidth");
    538 
    539   // value_range should be vector.
    540   INFER_ERROR("Shape must be rank 1 but is rank 0", op, "[];[];[]");
    541   // value_range should have 2 elements.
    542   INFER_ERROR("Dimension must be 2 but is 3", op, "[];[3];[]");
    543   // nbins should be scalar.
    544   INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[];[2];[2]");
    545 
    546   INFER_OK(op, "?;?;?", "[?]");
    547   INFER_OK(op, "[?];[2];[]", "[?]");
    548   INFER_OK(op, "[?];[2];?", "[?]");
    549 }
    550 
    551 TEST(MathOpsTest, QuantizedAdd_ShapeFn) {
    552   ShapeInferenceTestOp op("QuantizedAdd");
    553 
    554   INFER_OK(op, "?;?;?;?;?;?", "?;[];[]");
    555   INFER_OK(op, "?;?;[];[];[];[]", "?;[];[]");
    556   INFER_OK(op, "[1,2];?;[];[];[];[]", "?;[];[]");
    557   INFER_OK(op, "[];[2];[];[];[];[]", "[d1_0];[];[]");
    558 
    559   // Rank checks on input scalars.
    560   INFER_ERROR("must be rank 0", op, "?;?;[1];?;?;?");
    561   INFER_ERROR("must be rank 0", op, "?;?;?;[2];?;?");
    562   INFER_ERROR("must be rank 0", op, "?;?;?;?;[3];?");
    563   INFER_ERROR("must be rank 0", op, "?;?;?;?;?;[4]");
    564 }
    565 
    566 TEST(MathOpsTest, Bincount_ShapeFn) {
    567   ShapeInferenceTestOp op("Bincount");
    568 
    569   // size should be scalar.
    570   INFER_ERROR("Shape must be rank 0 but is rank 1", op, "?;[1];?");
    571 
    572   INFER_OK(op, "?;?;?", "[?]");
    573   INFER_OK(op, "?;[];?", "[?]");
    574   INFER_OK(op, "[?];[];?", "[?]");
    575   INFER_OK(op, "[?];[];[?]", "[?]");
    576 }
    577 }  // end namespace tensorflow
    578