1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/framework/node_def_builder.h" 17 #include "tensorflow/core/framework/op.h" 18 #include "tensorflow/core/framework/shape_inference_testutil.h" 19 #include "tensorflow/core/framework/tensor.h" 20 #include "tensorflow/core/framework/tensor_shape.pb.h" 21 #include "tensorflow/core/framework/tensor_testutil.h" 22 #include "tensorflow/core/lib/core/status_test_util.h" 23 #include "tensorflow/core/lib/strings/str_util.h" 24 #include "tensorflow/core/platform/test.h" 25 26 namespace tensorflow { 27 28 TEST(MathOpsTest, AddN_ShapeFn) { 29 ShapeInferenceTestOp op("AddN"); 30 auto set_n = [&op](int n) { 31 std::vector<NodeDefBuilder::NodeOut> src_list; 32 src_list.reserve(n); 33 for (int i = 0; i < n; ++i) src_list.emplace_back("a", 0, DT_FLOAT); 34 TF_ASSERT_OK(NodeDefBuilder("test", "AddN") 35 .Input(src_list) 36 .Attr("N", n) 37 .Finalize(&op.node_def)); 38 }; 39 40 set_n(2); 41 // Adding two unknowns returns either input. 42 INFER_OK(op, "?;?", "in0|in1"); 43 44 // known+unknown returns the known input. 45 INFER_OK(op, "[1];[?]", "in0"); 46 INFER_OK(op, "[1];?", "in0"); 47 INFER_OK(op, "[?];[1]", "in1"); 48 INFER_OK(op, "?;[1]", "in1"); 49 50 set_n(2); 51 INFER_OK(op, "[1,2];[?,2]", "in0"); 52 INFER_OK(op, "[1,2];[1,2]", "in0|in1"); 53 INFER_OK(op, "[?,2];[1,2]", "in1"); 54 55 set_n(3); 56 INFER_OK(op, "[1,?];[?,2];[1,2]", "in2"); 57 INFER_OK(op, "[1,2];[?,2];[1,?]", "in0"); 58 INFER_OK(op, "?;?;[1,2]", "in2"); 59 60 set_n(2); 61 INFER_OK(op, "?;[1,2]", "in1"); 62 INFER_OK(op, "[1,?];[?,2]", "[d0_0,d1_1]"); 63 INFER_OK(op, "[?,2,?];[?,?,3]", "[d0_0|d1_0,d0_1,d1_2]"); 64 INFER_OK(op, "[?,2];[1,?]", "[d1_0,d0_1]"); 65 66 set_n(3); 67 INFER_ERROR("Dimension 1 in both shapes must be equal, but are 2 and 4", op, 68 "[1,2];?;[1,4]"); 69 INFER_ERROR("From merging shape 0 with other shapes.", op, "[1,2];?;[1,4]"); 70 set_n(4); 71 INFER_ERROR("Shapes must be equal rank, but are 2 and 3", op, 72 "?;[1,2];?;[1,2,3]"); 73 INFER_ERROR("From merging shape 1 with other shapes.", op, 74 "?;[1,2];?;[1,2,3]"); 75 } 76 77 TEST(MathOpsTest, UnchangedShape_ShapeFn) { 78 ShapeInferenceTestOp op("Cast"); 79 INFER_OK(op, "?", "in0"); 80 INFER_OK(op, "[?]", "in0"); 81 INFER_OK(op, "[1,?,3,4]", "in0"); 82 } 83 84 TEST(MathOpsTest, Segment_ShapeFn) { 85 // Tests SegmentReductionShapeFn. 86 for (const auto* op_name : {"SegmentMax", "SegmentMean", "SegmentMin", 87 "SegmentProd", "SegmentSum"}) { 88 ShapeInferenceTestOp op(op_name); 89 INFER_OK(op, "?;?", "?"); 90 INFER_OK(op, "?;[100]", "?"); 91 92 // Data shape with single dimension. 93 INFER_OK(op, "[?];?", "[?]"); 94 INFER_OK(op, "[?];[100]", "[?]"); 95 INFER_OK(op, "[1];?", "[?]"); 96 INFER_OK(op, "[1];[100]", "[?]"); 97 98 // Data shape with multiple dimensions. 99 INFER_OK(op, "[?,?];?", "[?,d0_1]"); 100 INFER_OK(op, "[?,2];[100]", "[?,d0_1]"); 101 INFER_OK(op, "[?,2,?,4];[100]", "[?,d0_1,d0_2,d0_3]"); 102 INFER_OK(op, "[1,?];?", "[?,d0_1]"); 103 INFER_OK(op, "[1,2];[100]", "[?,d0_1]"); 104 INFER_OK(op, "[1,2,?,4];[100]", "[?,d0_1,d0_2,d0_3]"); 105 106 // Error cases. 107 INFER_ERROR("Shape must be rank 1 but is rank 2", op, "?;[1,2]"); 108 INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "[];[1]"); 109 } 110 } 111 112 TEST(MathOpsTest, BroadcastBinaryOps_ShapeFn) { 113 for (const auto* op_name : {"Add", "Complex", 114 "Div", "Equal", 115 "Greater", "GreaterEqual", 116 "Igamma", "Igammac", 117 "Zeta", "Polygamma", 118 "Less", "LessEqual", 119 "LogicalAnd", "LogicalOr", 120 "Maximum", "Minimum", 121 "Mod", "Mul", 122 "NotEqual", "Pow", 123 "Sub", "SquaredDifference", 124 "DivNoNan"}) { 125 ShapeInferenceTestOp op(op_name); 126 INFER_OK(op, "?;?", "?"); 127 INFER_OK(op, "[1,2];?", "?"); 128 INFER_OK(op, "?;[1,2]", "?"); 129 130 INFER_OK(op, "[?];[1]", "[d0_0]"); 131 INFER_OK(op, "[1];[?]", "[d1_0]"); 132 INFER_OK(op, "[?];[2]", "[d1_0]"); 133 INFER_OK(op, "[2];[?]", "[d0_0]"); 134 INFER_OK(op, "[?];[?]", "[?]"); 135 INFER_OK(op, "[];[?]", "[d1_0]"); 136 INFER_OK(op, "[?];[]", "[d0_0]"); 137 138 INFER_OK(op, "[1];[1]", "[d0_0|d1_0]"); 139 INFER_OK(op, "[];[1]", "[d1_0]"); 140 INFER_OK(op, "[1];[]", "[d0_0]"); 141 142 INFER_OK(op, "[2];[2]", "[d0_0|d1_0]"); 143 INFER_OK(op, "[];[2]", "[d1_0]"); 144 INFER_OK(op, "[1];[2]", "[d1_0]"); 145 INFER_OK(op, "[2];[1]", "[d0_0]"); 146 INFER_OK(op, "[2];[]", "[d0_0]"); 147 INFER_OK(op, "[2];[?]", "[d0_0]"); 148 149 INFER_OK(op, "[0];[0]", "[d0_0|d1_0]"); 150 INFER_OK(op, "[];[0]", "[d1_0]"); 151 INFER_OK(op, "[1];[0]", "[d1_0]"); 152 INFER_OK(op, "[0];[1]", "[d0_0]"); 153 INFER_OK(op, "[0];[]", "[d0_0]"); 154 155 INFER_OK(op, "[2];[?,?]", "[d1_0,d0_0]"); 156 INFER_OK(op, "[2,2];[?,?,?]", "[d1_0,d0_0,d0_1]"); 157 158 // Multiple dimension cases (same test cases, switching x and y). 159 INFER_OK(op, "[?,1,2,3,4,5];[3,1,?]", 160 "[d0_0,d0_1,d0_2,d0_3|d1_0,d0_4,d0_5]"); 161 INFER_OK(op, "[3,1,?];[?,1,2,3,4,5]", 162 "[d1_0,d1_1,d1_2,d1_3|d0_0,d1_4,d1_5]"); 163 } 164 } 165 166 TEST(MathOpsTest, Select_ShapeFn) { 167 ShapeInferenceTestOp op("Select"); 168 INFER_OK(op, "?;?;?", "in1|in2"); 169 170 // scalar case 171 INFER_OK(op, "[];[1];?", "in1"); 172 INFER_OK(op, "[];?;?", "in1|in2"); 173 174 INFER_OK(op, "[1];?;?", 175 "in1|in2"); // When cond is vector, t/e may not match it. 176 INFER_OK(op, "[1,2];?;?", "in1|in2?"); 177 178 INFER_OK(op, "?;[];?", "in1"); 179 INFER_OK(op, "?;?;[]", "in2"); 180 INFER_OK(op, "?;[1];?", "in1"); 181 INFER_OK(op, "?;?;[1]", "in2"); 182 INFER_OK(op, "?;[1,2];?", "in1"); 183 INFER_OK(op, "?;?;[1,2]", "in2"); 184 185 INFER_ERROR("Shapes must be equal rank, but are 0 and 1", op, "[1];[];?"); 186 INFER_ERROR("Shapes must be equal rank, but are 1 and 2", op, "[];[1];[1,2]"); 187 INFER_ERROR("Shapes must be equal rank, but are 1 and 2", op, "[1,2];[1];?"); 188 INFER_OK(op, "[2];[?];[?]", "in1|in2"); 189 190 INFER_OK(op, "[?];[?,?,3];[1,2,?]", "[d2_0,d2_1,d1_2]"); 191 INFER_OK(op, "[2];[?,?,3];[?,2,?]", "[d1_0|d2_0,d2_1,d1_2]"); 192 INFER_ERROR("must be equal", op, "[1];[2,?,3];[?,2,?]"); 193 INFER_ERROR("Shapes must be equal rank, but are 3 and 2", op, 194 "[2,?];[?,?,3];[?,2,?]"); 195 INFER_OK(op, "[2,?,?];[?,?,3];[?,2,?]", "[d0_0,d2_1,d1_2]"); 196 INFER_ERROR("Dimension 2 in both shapes must be equal, but are 3 and 5", op, 197 "[2,?,5];[?,?,3];[?,2,?]"); 198 199 // Test that handles were merged. 200 // 201 // Tests below will modify handle_data and call run_inference_for_handles to 202 // rerun shape inference, updating the context <c>. 203 const OpRegistrationData* op_reg_data; 204 TF_ASSERT_OK(OpRegistry::Global()->LookUp(op.name, &op_reg_data)); 205 typedef std::vector<std::pair<TensorShapeProto, DataType>> ShapeDtypeV; 206 std::vector<std::unique_ptr<ShapeDtypeV>> handle_data; 207 std::unique_ptr<shape_inference::InferenceContext> c; 208 auto run_inference_for_handles = [&]() -> Status { 209 CHECK(op_reg_data->shape_inference_fn != nullptr); 210 c.reset(new shape_inference::InferenceContext( 211 TF_GRAPH_DEF_VERSION, &op.node_def, op_reg_data->op_def, 212 {TensorShapeProto(), TensorShapeProto(), TensorShapeProto()}, {}, {}, 213 handle_data)); 214 TF_CHECK_OK(c->construction_status()); 215 Status s = c->Run(op_reg_data->shape_inference_fn); 216 LOG(INFO) << "Inference got " << s; 217 return s; 218 }; 219 auto shape_proto = [](std::initializer_list<int64> dim_sizes) { 220 TensorShapeProto p; 221 for (auto i : dim_sizes) p.add_dim()->set_size(i); 222 return p; 223 }; 224 225 TensorShapeProto i0 = shape_proto({1, -1}); 226 TensorShapeProto i1 = shape_proto({-1, 2}); 227 TensorShapeProto unknown_shape; 228 unknown_shape.set_unknown_rank(true); 229 TensorShapeProto scalar; 230 231 handle_data.emplace_back( 232 new ShapeDtypeV{{scalar, DT_FLOAT}, {unknown_shape, DT_INT32}}); 233 handle_data.emplace_back(new ShapeDtypeV{{i0, DT_FLOAT}, {i1, DT_INT32}}); 234 handle_data.emplace_back( 235 new ShapeDtypeV{{i1, DT_FLOAT}, {unknown_shape, DT_INT32}}); 236 237 TF_ASSERT_OK(run_inference_for_handles()); 238 auto* out = c->output_handle_shapes_and_types(0); 239 ASSERT_EQ(2, out->size()); 240 EXPECT_EQ("[1,2]", c->DebugString(out->at(0).shape)); 241 EXPECT_EQ(DT_FLOAT, out->at(0).dtype); 242 EXPECT_EQ("[?,2]", c->DebugString(out->at(1).shape)); 243 EXPECT_EQ(DT_INT32, out->at(1).dtype); 244 245 // Expect an error when the shapes can't be merged. 246 handle_data[2]->at(0).first = shape_proto({2, 2}); 247 EXPECT_TRUE(str_util::StrContains(run_inference_for_handles().error_message(), 248 "must be equal, but are 1 and 2")); 249 handle_data[2]->at(0).first = i1; // restore to valid 250 251 // Expect an error when the types can't be merged. 252 handle_data[2]->at(1).second = DT_INT64; 253 EXPECT_TRUE(str_util::StrContains(run_inference_for_handles().error_message(), 254 "pointing to different dtypes")); 255 handle_data[2]->at(1).second = DT_INT32; // restore to valid 256 257 // Expect an error when different numbers of tensors are merged. 258 handle_data[2]->push_back({i1, DT_FLOAT}); 259 EXPECT_TRUE( 260 str_util::StrContains(run_inference_for_handles().error_message(), 261 "pointing to different numbers of tensors")); 262 handle_data[2]->pop_back(); // restore to valid. 263 } 264 265 TEST(MathOpsTest, Range_ShapeFn) { 266 ShapeInferenceTestOp op("Range"); 267 268 TF_ASSERT_OK(NodeDefBuilder("test", "Range") 269 .Input({"start", {}, DT_INT32}) 270 .Input({"limit", {}, DT_INT32}) 271 .Input({"delta", {}, DT_INT32}) 272 .Attr("Tidx", DT_INT32) 273 .Finalize(&op.node_def)); 274 275 op.input_tensors.resize(3); 276 INFER_OK(op, "?;?;?", "[?]"); 277 INFER_ERROR("Shape must be rank 0 but is rank 2", op, "[1,2];?;?"); 278 INFER_ERROR("for 'start'", op, "[1,2];?;?"); 279 280 INFER_ERROR("Shape must be rank 0 but is rank 2", op, "?;[1,2];?"); 281 INFER_ERROR("for 'limit'", op, "?;[1,2];?"); 282 283 INFER_ERROR("Shape must be rank 0 but is rank 2", op, "?;?;[1,2]"); 284 INFER_ERROR("for 'delta'", op, "?;?;[1,2]"); 285 286 Tensor start_t = test::AsScalar(1); 287 op.input_tensors[0] = &start_t; 288 INFER_OK(op, "?;?;?", "[?]"); 289 Tensor limit_t = test::AsScalar(1); 290 op.input_tensors[1] = &limit_t; 291 INFER_OK(op, "?;?;?", "[?]"); 292 293 Tensor delta_t = test::AsScalar(1); 294 op.input_tensors[2] = &delta_t; 295 INFER_OK(op, "?;?;?", "[0]"); 296 297 delta_t = test::AsScalar(0); 298 INFER_ERROR("Requires delta != 0", op, "?;?;?"); 299 delta_t = test::AsScalar(3); 300 301 limit_t = test::AsScalar(-1); 302 INFER_ERROR("Requires start <= limit when delta > 0: 1/-1", op, "?;?;?"); 303 304 delta_t = test::AsScalar(-1); 305 INFER_OK(op, "?;?;?", "[2]"); 306 307 limit_t = test::AsScalar(4); 308 INFER_ERROR("Requires start >= limit when delta < 0: 1/4", op, "?;?;?"); 309 310 limit_t = test::AsScalar(100); 311 start_t = test::AsScalar(2); 312 delta_t = test::AsScalar(3); 313 INFER_OK(op, "?;?;?", "[33]"); 314 } 315 316 TEST(MathOpsTest, LinSpace_ShapeFn) { 317 ShapeInferenceTestOp op("LinSpace"); 318 op.input_tensors.resize(3); 319 INFER_OK(op, "?;?;?", "[?]"); 320 INFER_ERROR("Shape must be rank 0 but is rank 2", op, "[1,2];?;?"); 321 INFER_ERROR("for 'start'", op, "[1,2];?;?"); 322 INFER_ERROR("Shape must be rank 0 but is rank 2", op, "?;[1,2];?"); 323 INFER_ERROR("for 'stop'", op, "?;[1,2];?"); 324 INFER_ERROR("Shape must be rank 0 but is rank 2", op, "?;?;[1,2]"); 325 INFER_ERROR("for 'num'", op, "?;?;[1,2]"); 326 327 Tensor num_t = test::AsScalar(1); 328 op.input_tensors[2] = &num_t; 329 INFER_OK(op, "?;?;?", "[1]"); 330 num_t = test::AsScalar(2); 331 INFER_OK(op, "?;?;?", "[2]"); 332 num_t = test::AsScalar(-1); 333 INFER_ERROR("Requires num > 0: -1", op, "?;?;?"); 334 } 335 336 TEST(MathOpsTest, UnsortedSegmentSum_ShapeFn) { 337 ShapeInferenceTestOp op("UnsortedSegmentSum"); 338 op.input_tensors.resize(3); 339 INFER_OK(op, "?;?;?", "?"); 340 INFER_OK(op, "?;[?];?", "?"); 341 INFER_ERROR("Shape must be rank 0 but is rank 2", op, "?;?;[1,2]"); 342 INFER_ERROR("Dimensions must be equal, but are 2 and 3", op, 343 "[1,?,2];[1,?,3];?"); 344 INFER_OK(op, "?;[3];?", "?"); 345 INFER_ERROR("Shape must be at least rank 3 but is rank 2", op, 346 "[1,2];[1,2,3];?"); 347 348 Tensor num_segments_t = test::AsScalar(100); 349 op.input_tensors[2] = &num_segments_t; 350 INFER_OK(op, "[?,2,3,?,5];[1,2,?];[]", "[100,d0_3,d0_4]"); 351 352 num_segments_t = test::AsScalar(-1); 353 INFER_ERROR(("Dimension size, given by scalar input 2, must be " 354 "non-negative but is -1"), 355 op, "[3];[3];?"); 356 } 357 358 TEST(MathOpsTest, SparseSegment_ShapeFn) { 359 ShapeInferenceTestOp op("SparseSegmentSum"); 360 op.input_tensors.resize(3); 361 INFER_OK(op, "?;?;?", "?"); 362 INFER_OK(op, "[2,4,3];[3];[3]", "[?,d0_1,d0_2]"); 363 364 INFER_ERROR("Shape must be rank 1 but is rank 0", op, "[2,4,3];[];[3]"); 365 INFER_ERROR("Shape must be rank 1 but is rank 2", op, "[2,4,3];[3];[3,4]"); 366 367 INFER_ERROR("Dimension 0 in both shapes must be equal, but are 3 and 4", op, 368 "[2,4,3];[3];[4]"); 369 } 370 371 TEST(MathOpsTest, SparseSegmentGrad_ShapeFn) { 372 ShapeInferenceTestOp op("SparseSegmentMeanGrad"); 373 op.input_tensors.resize(4); 374 INFER_OK(op, "?;?;?;?", "?"); 375 INFER_OK(op, "[2,4,3];[3];[3];[]", "[?,d0_1,d0_2]"); 376 377 Tensor num_segments_t = test::AsScalar(100); 378 op.input_tensors[3] = &num_segments_t; 379 INFER_OK(op, "[2,4,3];[3];[3];[]", "[100,d0_1,d0_2]"); 380 381 INFER_ERROR("Shape must be rank 0 but is rank 2", op, 382 "[2,4,3];[3];[3];[1,1]"); 383 384 // Negative value is not allowed 385 num_segments_t = test::AsScalar(-100); 386 op.input_tensors[3] = &num_segments_t; 387 INFER_ERROR("Cannot specify a negative value", op, "[2,4,3];[3];[3];[]"); 388 } 389 390 TEST(MathOpsTest, BatchMatMul_ShapeFn) { 391 ShapeInferenceTestOp op("BatchMatMul"); 392 auto set_adj = [&op](bool adj_x, bool adj_y) { 393 TF_ASSERT_OK(NodeDefBuilder("test", "BatchMatMul") 394 .Input({"a", 0, DT_FLOAT}) 395 .Input({"b", 0, DT_FLOAT}) 396 .Attr("adj_x", adj_x) 397 .Attr("adj_y", adj_y) 398 .Finalize(&op.node_def)); 399 }; 400 401 set_adj(false, false); 402 403 // Rank checks. 404 INFER_ERROR("at least rank 2", op, "[1];?"); 405 INFER_ERROR("at least rank 2", op, "?;[2]"); 406 407 INFER_OK(op, "?;?", "?"); 408 409 // 0 batch dims. 410 INFER_OK(op, "[?,?];[?,?]", "[d0_0,d1_1]"); 411 412 // 2 batch dims. 413 INFER_OK(op, "[?,?,?,?];?", "[d0_0,d0_1,d0_2,?]"); 414 415 // Test adj_a, testing output and that inner dims are compared. 416 set_adj(false, false); 417 INFER_OK(op, "[1,2,3,4];[1,2,?,?]", "[d0_0,d0_1,d0_2,d1_3]"); 418 INFER_ERROR("are 2 and 3", op, "[?,1,2];[?,3,1]"); // inner dim mismatch 419 set_adj(true, false); 420 INFER_OK(op, "[1,2,3,4];[1,2,?,?]", "[d0_0,d0_1,d0_3,d1_3]"); 421 INFER_ERROR("are 2 and 3", op, "[?,2,1];[?,3,1]"); // inner dim mismatch 422 423 // Test adj_b=true. 424 set_adj(false, true); 425 INFER_OK(op, "[1,2,?,?];[1,2,3,4]", "[d0_0,d0_1,d0_2,d1_2]"); 426 INFER_ERROR("are 2 and 3", op, "[?,1,2];[?,1,3]"); // inner dim mismatch 427 set_adj(true, true); 428 INFER_OK(op, "[1,2,?,?];[1,2,3,4]", "[d0_0,d0_1,d0_3,d1_2]"); 429 INFER_ERROR("are 2 and 3", op, "[?,2,1];[?,1,3]"); // inner dim mismatch 430 } 431 432 TEST(MathOpsTest, ArgOps_ShapeFn) { 433 ShapeInferenceTestOp op("ArgMax"); 434 op.input_tensors.resize(2); 435 436 INFER_OK(op, "?;?", "?"); 437 438 // input rank <= 1 produces scalar 439 INFER_OK(op, "[2];?", "[]"); 440 INFER_OK(op, "[];?", "[]"); 441 442 // Incorrect rank for dimension 443 INFER_ERROR("must be rank 0", op, "[2];[1]"); 444 445 // dimension not available, but input rank is. Output is unknown 446 // shape with rank one less than input rank. 447 INFER_OK(op, "[2,3,4];?", "[?,?]"); 448 INFER_OK(op, "[2,3,4,5,6];?", "[?,?,?,?]"); 449 450 // Dimension values known 451 Tensor dimension = test::AsScalar(0); 452 op.input_tensors[1] = &dimension; 453 INFER_OK(op, "[2,3,4];[]", "[d0_1,d0_2]"); 454 455 dimension = test::AsScalar(1); 456 op.input_tensors[1] = &dimension; 457 INFER_OK(op, "[2,3,4];[]", "[d0_0,d0_2]"); 458 459 dimension = test::AsScalar(2); 460 op.input_tensors[1] = &dimension; 461 INFER_OK(op, "[2,3,4];[]", "[d0_0,d0_1]"); 462 463 // Dimension value out of bounds 464 dimension = test::AsScalar(10); 465 op.input_tensors[1] = &dimension; 466 INFER_ERROR("must be in the range [-3, 3)", op, "[2,3,4];[]"); 467 468 dimension = test::AsScalar(-10); 469 op.input_tensors[1] = &dimension; 470 INFER_ERROR("must be in the range [-3, 3)", op, "[2,3,4];[]"); 471 472 dimension = test::AsScalar(-1); 473 op.input_tensors[1] = &dimension; 474 INFER_OK(op, "[2,3,4];[]", "[d0_0,d0_1]"); 475 } 476 477 TEST(MathOpsTest, Betainc_ShapeFn) { 478 ShapeInferenceTestOp op("Betainc"); 479 480 INFER_OK(op, "?;?;?", "?"); 481 INFER_OK(op, "[?,?];?;?", "in0"); 482 INFER_OK(op, "[?,2];?;[1,?]", "[d2_0,d0_1]"); 483 INFER_OK(op, "[?,2,?];[1,?,?];[?,?,3]", "[d1_0,d0_1,d2_2]"); 484 485 INFER_OK(op, "[?,2,?];[];[?,?,3]", "[d0_0|d2_0,d0_1,d2_2]"); 486 INFER_OK(op, "[];[];[?,?,3]", "in2"); 487 488 // All but one is a scalar, so use it. 489 INFER_OK(op, "[];[];?", "in2"); 490 INFER_OK(op, "[];[];[1,2,3,4]", "in2"); 491 492 // All scalar input; implementation picks in0. 493 INFER_OK(op, "[];[];[]", "in0"); 494 495 // Non-scalars must match shape. 496 INFER_ERROR("must be equal", op, "[1,2];[];[1,4]"); 497 INFER_ERROR("must be equal", op, "[1,2];[];[1,2,3]"); 498 } 499 500 TEST(MathOpsTest, Requantize_ShapeFn) { 501 ShapeInferenceTestOp op("Requantize"); 502 503 INFER_OK(op, "?;?;?;?;?", "in0;[];[]"); 504 INFER_OK(op, "?;[];[];[];[]", "in0;[];[]"); 505 506 // Rank checks on input scalars. 507 INFER_ERROR("must be rank 0", op, "?;[1];?;?;?"); 508 INFER_ERROR("must be rank 0", op, "?;?;[2];?;?"); 509 INFER_ERROR("must be rank 0", op, "?;?;?;[3];?"); 510 INFER_ERROR("must be rank 0", op, "?;?;?;?;[4]"); 511 } 512 513 TEST(MathOpstest, RequantizationRange_ShapeFn) { 514 ShapeInferenceTestOp op("RequantizationRange"); 515 516 INFER_OK(op, "?;?;?", "[];[]"); 517 INFER_OK(op, "?;[];[]", "[];[]"); 518 519 // Rank checks on input scalars. 520 INFER_ERROR("must be rank 0", op, "?;[1];?"); 521 INFER_ERROR("must be rank 0", op, "?;?;[2]"); 522 } 523 524 TEST(MathOpsTest, Cross_ShapeFn) { 525 ShapeInferenceTestOp op("Cross"); 526 527 INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "[];[]"); 528 INFER_ERROR("Dimension 0 in both shapes must be equal, but", op, "[3];[5]"); 529 INFER_ERROR("Dimension must be 3 but", op, "[3,5];[3,5]"); 530 531 INFER_OK(op, "?;?", "in0"); 532 INFER_OK(op, "[?];[?]", "in0"); 533 INFER_OK(op, "[1,?,3];[?,?,?]", "in0"); 534 } 535 536 TEST(MathOpsTest, HistogramFixedWidth_ShapeFn) { 537 ShapeInferenceTestOp op("HistogramFixedWidth"); 538 539 // value_range should be vector. 540 INFER_ERROR("Shape must be rank 1 but is rank 0", op, "[];[];[]"); 541 // value_range should have 2 elements. 542 INFER_ERROR("Dimension must be 2 but is 3", op, "[];[3];[]"); 543 // nbins should be scalar. 544 INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[];[2];[2]"); 545 546 INFER_OK(op, "?;?;?", "[?]"); 547 INFER_OK(op, "[?];[2];[]", "[?]"); 548 INFER_OK(op, "[?];[2];?", "[?]"); 549 } 550 551 TEST(MathOpsTest, QuantizedAdd_ShapeFn) { 552 ShapeInferenceTestOp op("QuantizedAdd"); 553 554 INFER_OK(op, "?;?;?;?;?;?", "?;[];[]"); 555 INFER_OK(op, "?;?;[];[];[];[]", "?;[];[]"); 556 INFER_OK(op, "[1,2];?;[];[];[];[]", "?;[];[]"); 557 INFER_OK(op, "[];[2];[];[];[];[]", "[d1_0];[];[]"); 558 559 // Rank checks on input scalars. 560 INFER_ERROR("must be rank 0", op, "?;?;[1];?;?;?"); 561 INFER_ERROR("must be rank 0", op, "?;?;?;[2];?;?"); 562 INFER_ERROR("must be rank 0", op, "?;?;?;?;[3];?"); 563 INFER_ERROR("must be rank 0", op, "?;?;?;?;?;[4]"); 564 } 565 566 TEST(MathOpsTest, Bincount_ShapeFn) { 567 ShapeInferenceTestOp op("Bincount"); 568 569 // size should be scalar. 570 INFER_ERROR("Shape must be rank 0 but is rank 1", op, "?;[1];?"); 571 572 INFER_OK(op, "?;?;?", "[?]"); 573 INFER_OK(op, "?;[];?", "[?]"); 574 INFER_OK(op, "[?];[];?", "[?]"); 575 INFER_OK(op, "[?];[];[?]", "[?]"); 576 } 577 } // end namespace tensorflow 578