1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/framework/common_shape_fns.h" 17 #include "tensorflow/core/framework/op.h" 18 #include "tensorflow/core/framework/shape_inference.h" 19 #include "tensorflow/core/framework/tensor.pb.h" 20 #include "tensorflow/core/util/mirror_pad_mode.h" 21 #include "tensorflow/core/util/padding.h" 22 #include "tensorflow/core/util/strided_slice_op.h" 23 #include "tensorflow/core/util/tensor_format.h" 24 25 namespace tensorflow { 26 27 using shape_inference::DimensionHandle; 28 using shape_inference::InferenceContext; 29 using shape_inference::ShapeHandle; 30 using shape_inference::UnchangedShape; 31 32 namespace { 33 34 Status GetAxisForPackAndUnpack(InferenceContext* c, int32 rank_after_pack, 35 int32* axis) { 36 TF_RETURN_IF_ERROR(c->GetAttr("axis", axis)); 37 if (*axis < -1 * rank_after_pack || *axis >= rank_after_pack) { 38 return errors::InvalidArgument("Invalid axis: ", *axis, "; must be in [", 39 -1 * rank_after_pack, ",", rank_after_pack, 40 ")"); 41 } 42 if (*axis < 0) *axis = (rank_after_pack + *axis); 43 return Status::OK(); 44 } 45 46 template <typename T> 47 std::vector<int64> AsInt64(const Tensor* tensor, int64 num_elements) { 48 std::vector<int64> ret(num_elements); 49 auto data = tensor->vec<T>(); 50 for (int64 i = 0; i < num_elements; ++i) { 51 ret[i] = data(i); 52 } 53 return ret; 54 } 55 56 template <typename T> 57 Status PadKnown(InferenceContext* c, ShapeHandle input, 58 const Tensor* paddings_t, int64 num_dims) { 59 // paddings_t is known. 60 std::vector<DimensionHandle> dims(num_dims); 61 auto paddings_data = paddings_t->matrix<T>(); 62 for (int64 i = 0; i < num_dims; ++i) { 63 const T pad0 = paddings_data(i, 0); 64 const T pad1 = paddings_data(i, 1); 65 if (pad0 < 0 || pad1 < 0) { 66 return errors::InvalidArgument("Paddings must be non-negative"); 67 } 68 TF_RETURN_IF_ERROR(c->Add(c->Dim(input, i), pad0 + pad1, &dims[i])); 69 } 70 c->set_output(0, c->MakeShape(dims)); 71 return Status::OK(); 72 } 73 74 Status PadShapeFn(InferenceContext* c) { 75 // Paddings is a matrix of [input_rank, 2]. 76 ShapeHandle paddings; 77 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &paddings)); 78 DimensionHandle unused; 79 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(paddings, 1), 2, &unused)); 80 81 // n_dim and input.rank are equivalent. 82 ShapeHandle input = c->input(0); 83 DimensionHandle n_dim = c->Dim(paddings, 0); 84 if (c->ValueKnown(n_dim)) { 85 TF_RETURN_IF_ERROR(c->WithRank(input, c->Value(n_dim), &input)); 86 } else if (c->RankKnown(input)) { 87 TF_RETURN_IF_ERROR(c->WithValue(n_dim, c->Rank(input), &n_dim)); 88 } 89 90 const Tensor* paddings_t = c->input_tensor(1); 91 92 // paddings_t is unknown 93 if (paddings_t == nullptr) { 94 if (c->ValueKnown(n_dim)) { 95 // Make output with n_dim unknown dims. 96 c->set_output(0, c->UnknownShapeOfRank(c->Value(n_dim))); 97 } else { 98 c->set_output(0, c->UnknownShape()); 99 } 100 return Status::OK(); 101 } 102 103 const int64 num_dims = paddings_t->shape().dim_size(0); 104 TF_RETURN_IF_ERROR(c->WithRank(input, num_dims, &input)); 105 TF_RETURN_IF_ERROR(c->WithValue(n_dim, num_dims, &n_dim)); 106 107 if (paddings_t->dtype() == DT_INT32) { 108 return PadKnown<int32>(c, input, paddings_t, num_dims); 109 } else { 110 return PadKnown<int64>(c, input, paddings_t, num_dims); 111 } 112 } 113 114 Status TransposeShapeFn(InferenceContext* c) { 115 ShapeHandle input = c->input(0); 116 ShapeHandle perm_shape = c->input(1); 117 const Tensor* perm = c->input_tensor(1); 118 DimensionHandle perm_elems = c->NumElements(perm_shape); 119 // If we don't have rank information on the input or value information on 120 // perm we can't return any shape information, otherwise we have enough 121 // information to at least find the rank of the output. 122 if (!c->RankKnown(input) && !c->ValueKnown(perm_elems) && perm == nullptr) { 123 c->set_output(0, c->UnknownShape()); 124 return Status::OK(); 125 } 126 127 // Find our value of the rank. 128 int64 rank; 129 if (c->RankKnown(input)) { 130 rank = c->Rank(input); 131 } else if (c->ValueKnown(perm_elems)) { 132 rank = c->Value(perm_elems); 133 } else { 134 rank = perm->NumElements(); 135 } 136 if (!c->RankKnown(input) && rank < 2) { 137 // A permutation array containing a single element is ambiguous. It could 138 // indicate either a scalar or a 1-dimensional array, both of which the 139 // transpose op returns unchanged. 140 c->set_output(0, input); 141 return Status::OK(); 142 } 143 144 std::vector<DimensionHandle> dims; 145 dims.resize(rank); 146 TF_RETURN_IF_ERROR(c->WithRank(input, rank, &input)); 147 // Ensure that perm is a vector and has rank elements. 148 TF_RETURN_IF_ERROR(c->WithRank(perm_shape, 1, &perm_shape)); 149 TF_RETURN_IF_ERROR(c->WithValue(perm_elems, rank, &perm_elems)); 150 151 // If we know the rank of the input and the value of perm, we can return 152 // all shape informantion, otherwise we can only return rank information, 153 // but no information for the dimensions. 154 if (perm != nullptr) { 155 std::vector<int64> data; 156 if (perm->dtype() == DT_INT32) { 157 data = AsInt64<int32>(perm, rank); 158 } else { 159 data = AsInt64<int64>(perm, rank); 160 } 161 162 for (int32 i = 0; i < rank; ++i) { 163 int64 in_idx = data[i]; 164 if (in_idx >= rank) { 165 return errors::InvalidArgument("perm dim ", in_idx, 166 " is out of range of input rank ", rank); 167 } 168 dims[i] = c->Dim(input, in_idx); 169 } 170 } else { 171 for (int i = 0; i < rank; ++i) { 172 dims[i] = c->UnknownDim(); 173 } 174 } 175 176 c->set_output(0, c->MakeShape(dims)); 177 return Status::OK(); 178 } 179 180 Status SetOutputShapeForReshape(InferenceContext* c) { 181 ShapeHandle in = c->input(0); 182 ShapeHandle out; 183 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &out)); 184 185 if (!c->RankKnown(out)) { 186 // We have no information about the shape of the output. 187 c->set_output(0, out); 188 return Status::OK(); 189 } 190 191 if (c->RankKnown(out) && c->RankKnown(in)) { 192 // We don't know the number of output elements, but we can try to infer 193 // the missing dimension. 194 bool too_many_unknown = false; 195 int32 out_unknown_idx = -1; 196 197 DimensionHandle known_out_elems = c->NumElements(out); 198 if (!c->ValueKnown(known_out_elems)) { 199 known_out_elems = c->MakeDim(1); 200 for (int32 i = 0; i < c->Rank(out); ++i) { 201 DimensionHandle dim = c->Dim(out, i); 202 if (!c->ValueKnown(dim)) { 203 if (out_unknown_idx >= 0) { 204 too_many_unknown = true; 205 break; 206 } 207 out_unknown_idx = i; 208 } else { 209 TF_RETURN_IF_ERROR( 210 c->Multiply(known_out_elems, dim, &known_out_elems)); 211 } 212 } 213 } 214 int32 in_unknown_idx = -1; 215 DimensionHandle known_in_elems = c->NumElements(in); 216 if (!c->ValueKnown(known_in_elems)) { 217 known_in_elems = c->MakeDim(1); 218 for (int32 i = 0; i < c->Rank(in); ++i) { 219 DimensionHandle dim = c->Dim(in, i); 220 if (!c->ValueKnown(dim)) { 221 if (in_unknown_idx >= 0) { 222 too_many_unknown = true; 223 break; 224 } 225 in_unknown_idx = i; 226 } else { 227 TF_RETURN_IF_ERROR(c->Multiply(known_in_elems, dim, &known_in_elems)); 228 } 229 } 230 } 231 232 if (!too_many_unknown) { 233 if (in_unknown_idx < 0 && out_unknown_idx < 0) { 234 // Just check that the dimensions match. 235 if (c->Value(known_in_elems) != c->Value(known_out_elems)) { 236 return errors::InvalidArgument( 237 "Cannot reshape a tensor with ", c->DebugString(known_in_elems), 238 " elements to shape ", c->DebugString(out), " (", 239 c->DebugString(known_out_elems), " elements)"); 240 } 241 } else if (in_unknown_idx < 0 && out_unknown_idx >= 0 && 242 c->Value(known_out_elems) > 0) { 243 // Input fully known, infer the one missing output dim 244 DimensionHandle inferred_dim; 245 TF_RETURN_IF_ERROR(c->Divide(known_in_elems, c->Value(known_out_elems), 246 true /* evenly_divisible */, 247 &inferred_dim)); 248 TF_RETURN_IF_ERROR( 249 c->ReplaceDim(out, out_unknown_idx, inferred_dim, &out)); 250 251 } else if (in_unknown_idx >= 0 && out_unknown_idx < 0 && 252 c->Value(known_in_elems) != 0) { 253 // Output fully known, infer the one missing input dim 254 DimensionHandle inferred_dim; 255 TF_RETURN_IF_ERROR(c->Divide(known_out_elems, c->Value(known_in_elems), 256 true /* evenly_divisible */, 257 &inferred_dim)); 258 DimensionHandle unknown_in_dim = c->Dim(in, in_unknown_idx); 259 TF_RETURN_IF_ERROR( 260 c->Merge(unknown_in_dim, inferred_dim, &unknown_in_dim)); 261 } else if (in_unknown_idx >= 0 && out_unknown_idx >= 0) { 262 // Exactly one unknown dimension in both input and output. These 2 are 263 // equal iff the known elements are equal. 264 if (c->Value(known_in_elems) == c->Value(known_out_elems)) { 265 DimensionHandle unknown_in_dim = c->Dim(in, in_unknown_idx); 266 TF_RETURN_IF_ERROR( 267 c->ReplaceDim(out, out_unknown_idx, unknown_in_dim, &out)); 268 } 269 } 270 } 271 } 272 c->set_output(0, out); 273 return Status::OK(); 274 } 275 276 } // namespace 277 278 REGISTER_OP("ParallelConcat") 279 .Input("values: N * T") 280 .Output("output: T") 281 .Attr("N: int >= 1") 282 .Attr("T: type") 283 .Attr("shape: shape") 284 .SetShapeFn([](InferenceContext* c) { 285 // Validate that the shape attr is correct. 286 PartialTensorShape shape; 287 TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape)); 288 ShapeHandle passed_shape; 289 TF_RETURN_IF_ERROR( 290 c->MakeShapeFromPartialTensorShape(shape, &passed_shape)); 291 if (!c->FullyDefined(passed_shape)) { 292 return errors::InvalidArgument("shape attr must be fully defined."); 293 } 294 ShapeHandle cur; 295 TF_RETURN_IF_ERROR(c->ReplaceDim( 296 passed_shape, 0, c->MakeDim(shape_inference::DimensionOrConstant(1)), 297 &cur)); 298 for (int i = 0; i < c->num_inputs(); ++i) { 299 if (!c->FullyDefined(c->input(i))) { 300 return errors::InvalidArgument( 301 "All input shapes must be fully defined."); 302 } 303 DimensionHandle unused; 304 if (!c->WithValue(c->Dim(c->input(i), 0), 1, &unused).ok()) { 305 return errors::InvalidArgument("Size of first dimension must be 1."); 306 } 307 TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur), 308 "From merging shape ", i, 309 " with other shapes."); 310 } 311 312 c->set_output(0, passed_shape); 313 314 return Status::OK(); 315 }); 316 317 REGISTER_OP("Pack") 318 .Input("values: N * T") 319 .Output("output: T") 320 .Attr("N: int >= 1") 321 .Attr("T: type") 322 .Attr("axis: int = 0") 323 .SetShapeFn([](InferenceContext* c) { 324 // Validate shapes of all inputs are compatible 325 ShapeHandle cur = c->input(c->num_inputs() - 1); 326 for (int i = c->num_inputs() - 2; i >= 0; --i) { 327 TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur), 328 "From merging shape ", i, 329 " with other shapes."); 330 } 331 if (!c->RankKnown(cur)) { 332 c->set_output(0, c->UnknownShape()); 333 return Status::OK(); 334 } 335 // Determine the axis that will be added, converting from negative 336 // axes to a positive point per negative indexing rules. 337 int32 rank = c->Rank(cur); 338 int32 axis; 339 TF_RETURN_IF_ERROR(GetAxisForPackAndUnpack(c, rank + 1, &axis)); 340 341 // Copy all dimensions over, inserting a dimension of value #inputs 342 // at <axis>. 343 std::vector<DimensionHandle> dims; 344 int index = 0; 345 while (index < axis) dims.push_back(c->Dim(cur, index++)); 346 dims.push_back(c->MakeDim(c->num_inputs())); 347 while (index < rank) dims.push_back(c->Dim(cur, index++)); 348 349 c->set_output(0, c->MakeShape(dims)); 350 for (int i = 0; i < c->num_inputs(); ++i) { 351 auto* shape_and_type = c->input_handle_shapes_and_types(i); 352 if (shape_and_type) { 353 if (!c->RelaxOutputHandleShapesAndMergeTypes(0, *shape_and_type)) { 354 c->set_output_handle_shapes_and_types( 355 0, std::vector<shape_inference::ShapeAndType>({})); 356 break; 357 } 358 } 359 } 360 return Status::OK(); 361 }); 362 363 REGISTER_OP("DeepCopy") 364 .Input("x: T") 365 .Output("y: T") 366 .Attr("T: type") 367 .SetIsStateful() 368 .SetShapeFn(UnchangedShape); 369 370 REGISTER_OP("InplaceUpdate") 371 .Input("x: T") 372 .Input("i: int32") 373 .Input("v: T") 374 .Output("y: T") 375 .Attr("T: type") 376 .SetShapeFn(UnchangedShape); 377 378 REGISTER_OP("InplaceAdd") 379 .Input("x: T") 380 .Input("i: int32") 381 .Input("v: T") 382 .Output("y: T") 383 .Attr("T: type") 384 .SetShapeFn(UnchangedShape); 385 386 REGISTER_OP("InplaceSub") 387 .Input("x: T") 388 .Input("i: int32") 389 .Input("v: T") 390 .Output("y: T") 391 .Attr("T: type") 392 .SetShapeFn(UnchangedShape); 393 394 REGISTER_OP("Empty") 395 .Input("shape: int32") 396 .Output("output: dtype") 397 .Attr("dtype: type") 398 .Attr("init: bool = false") 399 .SetIsStateful() 400 .SetShapeFn([](InferenceContext* c) { 401 ShapeHandle out; 402 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out)); 403 c->set_output(0, out); 404 return Status::OK(); 405 }); 406 407 // -------------------------------------------------------------------------- 408 REGISTER_OP("Unpack") 409 .Input("value: T") 410 .Output("output: num * T") 411 .Attr("num: int >= 0") 412 .Attr("T: type") 413 .Attr("axis: int = 0") 414 .SetShapeFn([](InferenceContext* c) { 415 ShapeHandle s = c->input(0); 416 ShapeHandle out; 417 if (c->RankKnown(s)) { 418 // Determine the axis that will be removed, converting from negative 419 // axes to a positive point per negative indexing rules. 420 int32 rank = c->Rank(s); 421 int32 axis; 422 TF_RETURN_IF_ERROR(GetAxisForPackAndUnpack(c, rank, &axis)); 423 424 // The axis dim matches the number of outputs. 425 DimensionHandle unused; 426 TF_RETURN_IF_ERROR( 427 c->WithValue(c->Dim(s, axis), c->num_outputs(), &unused)); 428 429 // Copy all dimensions, removing the <axis> dimension. 430 std::vector<DimensionHandle> dims; 431 for (int i = 0; i < rank; ++i) { 432 if (i != axis) dims.push_back(c->Dim(s, i)); 433 } 434 out = c->MakeShape(dims); 435 } else { 436 // All outputs are the same shape, but it's not known. 437 out = c->UnknownShape(); 438 } 439 for (int i = 0; i < c->num_outputs(); ++i) c->set_output(i, out); 440 return Status::OK(); 441 }); 442 443 REGISTER_OP("UnravelIndex") 444 .Input("indices: Tidx") 445 .Input("dims: Tidx") 446 .Output("output: Tidx") 447 .Attr("Tidx: {int32, int64} = DT_INT32") 448 .SetShapeFn([](InferenceContext* c) { 449 ShapeHandle indices = c->input(0); 450 ShapeHandle dims; 451 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &dims)); 452 if (c->RankKnown(indices) && c->Rank(indices) == 0) { 453 c->set_output(0, c->Vector(c->Dim(dims, 0))); 454 } else if (c->RankKnown(indices)) { 455 c->set_output(0, c->Matrix(c->Dim(dims, 0), c->NumElements(indices))); 456 } else { 457 c->set_output(0, c->UnknownShape()); 458 } 459 return Status::OK(); 460 }); 461 462 REGISTER_OP("BroadcastTo") 463 .Input("input: T") 464 .Input("shape: Tidx") 465 .Output("output: T") 466 .Attr("T: type") 467 .Attr("Tidx: {int32, int64} = DT_INT32") 468 .SetShapeFn([](InferenceContext* c) { 469 ShapeHandle shape_in = c->input(1); 470 TF_RETURN_IF_ERROR(c->WithRank(shape_in, 1, &shape_in)); 471 ShapeHandle out; 472 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &out)); 473 if (!c->RankKnown(out)) { 474 // We have no information about the shape of the output. 475 c->set_output(0, out); 476 return Status::OK(); 477 } 478 479 ShapeHandle in = c->input(0); 480 if (!c->RankKnown(in)) { 481 // We have no information about the shape of the input, 482 // nothing to do here. 483 c->set_output(0, out); 484 return Status::OK(); 485 } 486 int out_rank = c->Rank(out); 487 TF_RETURN_IF_ERROR(c->WithRankAtMost(in, out_rank, &in)); 488 int in_rank = c->Rank(in); 489 for (int i = 0; i < in_rank; ++i) { 490 auto in_dim = c->Dim(in, in_rank - i - 1); 491 if (c->Value(in_dim) > 1) { 492 // If the input dimension is greater than 1 then the output dimension 493 // must be equal to it, since we only broadcast "from left to right". 494 auto out_dim = c->Dim(out, out_rank - i - 1); 495 TF_RETURN_IF_ERROR(c->Merge(in_dim, out_dim, &out_dim)); 496 TF_RETURN_IF_ERROR( 497 c->ReplaceDim(out, out_rank - i - 1, out_dim, &out)); 498 } 499 } 500 c->set_output(0, out); 501 return Status::OK(); 502 }); 503 504 // -------------------------------------------------------------------------- 505 // TODO(josh11b): Remove the >= 2 constraint, once we can rewrite the graph 506 // in the N == 1 case to remove the node. 507 REGISTER_OP("Concat") 508 .Input("concat_dim: int32") 509 .Input("values: N * T") 510 .Output("output: T") 511 .Attr("N: int >= 2") 512 .Attr("T: type") 513 .SetShapeFn([](InferenceContext* c) { 514 return shape_inference::ConcatShape(c, c->num_inputs() - 1); 515 }); 516 517 REGISTER_OP("ConcatV2") 518 .Input("values: N * T") 519 .Input("axis: Tidx") 520 .Output("output: T") 521 .Attr("N: int >= 2") 522 .Attr("T: type") 523 .Attr("Tidx: {int32, int64} = DT_INT32") 524 .SetShapeFn(shape_inference::ConcatV2Shape); 525 526 // TODO(vivek.v.rane (at) intel.com): Prefix the op names with underscore if the ops 527 // are not to be made user-accessible. 528 #ifdef INTEL_MKL 529 REGISTER_OP("_MklConcatV2") 530 .Input("values: N * T") 531 .Input("axis: Tidx") 532 .Input("mkl_values: N * uint8") 533 .Input("mkl_axis: uint8") 534 .Output("output: T") 535 .Output("mkl_output: uint8") 536 .Attr("N: int >= 2") 537 .Attr("T: type") 538 .Attr("Tidx: {int32, int64} = DT_INT32") 539 .SetShapeFn(shape_inference::ConcatV2Shape) 540 .Doc(R"doc( 541 MKL version of ConcatV2 operator. Uses MKL DNN APIs to perform concatenation. 542 543 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is 544 expected to invoke these operators. 545 )doc"); 546 #endif 547 548 REGISTER_OP("ConcatOffset") 549 .Input("concat_dim: int32") 550 .Input("shape: N * int32") 551 .Output("offset: N * int32") 552 .Attr("N: int >= 2") 553 .SetShapeFn([](InferenceContext* c) { 554 for (int i = 1; i < c->num_inputs(); ++i) { 555 c->set_output(i - 1, c->input(i)); 556 } 557 return Status::OK(); 558 }); 559 560 // -------------------------------------------------------------------------- 561 REGISTER_OP("Split") 562 .Input("split_dim: int32") 563 .Input("value: T") 564 .Output("output: num_split * T") 565 .Attr("num_split: int >= 1") 566 .Attr("T: type") 567 .SetShapeFn([](InferenceContext* c) { 568 DimensionHandle split_dimension; 569 ShapeHandle input = c->input(1); 570 TF_RETURN_IF_ERROR(c->MakeDimForScalarInputWithNegativeIndexing( 571 0, c->Rank(input), &split_dimension)); 572 int num_split = c->num_outputs(); 573 ShapeHandle out; 574 if (!c->ValueKnown(split_dimension)) { 575 if (c->RankKnown(input)) { 576 out = c->UnknownShapeOfRank(c->Rank(input)); 577 } else { 578 out = c->UnknownShape(); 579 } 580 } else { 581 int64 split_dim = c->Value(split_dimension); 582 TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, split_dim + 1, &input)); 583 DimensionHandle split_dim_size; 584 TF_RETURN_WITH_CONTEXT_IF_ERROR( 585 c->Divide(c->Dim(input, split_dim), num_split, 586 true /* evenly_divisible */, &split_dim_size), 587 "Number of ways to split should evenly divide the split dimension"); 588 TF_RETURN_IF_ERROR( 589 c->ReplaceDim(input, split_dim, split_dim_size, &out)); 590 } 591 for (int i = 0; i < num_split; ++i) c->set_output(i, out); 592 return Status::OK(); 593 }); 594 595 REGISTER_OP("SplitV") 596 .Input("value: T") 597 .Input("size_splits: Tlen") 598 .Input("split_dim: int32") 599 .Output("output: num_split * T") 600 .Attr("num_split: int >= 1") 601 .Attr("T: type") 602 .Attr("Tlen: {int32, int64} = DT_INT64") 603 .SetShapeFn([](InferenceContext* c) { 604 DimensionHandle split_dimension; 605 ShapeHandle input = c->input(0); 606 TF_RETURN_IF_ERROR(c->MakeDimForScalarInputWithNegativeIndexing( 607 2, c->Rank(input), &split_dimension)); 608 int32 num_outputs = c->num_outputs(); 609 int32 rank = c->Rank(input); 610 ShapeHandle output_shape; 611 const Tensor* size_splits = c->input_tensor(1); 612 if (rank == InferenceContext::kUnknownRank) { 613 // If the rank of input tensor is unknown, then return unknown shapes. 614 // Note that the shape of each output can be different. 615 for (int i = 0; i < num_outputs; ++i) { 616 c->set_output(i, c->UnknownShape()); 617 } 618 } else if (rank == 0) { 619 // Throw error if input is a scalar. 620 return errors::InvalidArgument("Can't split scalars"); 621 } else if (size_splits == nullptr && c->ValueKnown(split_dimension)) { 622 // If split dimension is known, but the sizes are unknown, then 623 // only the split dimension is unknown 624 output_shape = input; 625 for (int i = 0; i < num_outputs; ++i) { 626 TF_RETURN_IF_ERROR(c->ReplaceDim(output_shape, 627 c->Value(split_dimension), 628 c->UnknownDim(), &output_shape)); 629 c->set_output(i, output_shape); 630 } 631 } else if (size_splits == nullptr && !c->ValueKnown(split_dimension)) { 632 // If split dimension or tensor containing the split sizes is unknown, 633 // then return unknown shapes of same rank as input. Note that each 634 // output shape can be different since splitv doesn't always split 635 // tensors evenly. 636 for (int i = 0; i < num_outputs; ++i) { 637 c->set_output(i, c->UnknownShapeOfRank(rank)); 638 } 639 } else { 640 // Determine the output shape if split dimension and split sizes are 641 // known. 642 int64 split_dim = c->Value(split_dimension); 643 TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, split_dim + 1, &input)); 644 std::vector<int64> data; 645 if (size_splits->dtype() == DT_INT32) { 646 data = AsInt64<int32>(size_splits, size_splits->shape().dim_size(0)); 647 } else { 648 data = AsInt64<int64>(size_splits, size_splits->shape().dim_size(0)); 649 } 650 if (num_outputs != data.size()) { 651 return errors::InvalidArgument( 652 "Length of size_splits should be equal to num_outputs"); 653 } 654 int64_t total_size = 0; 655 bool has_neg_one = false; 656 for (const auto size : data) { 657 if (size == -1) { 658 if (has_neg_one) { 659 return errors::InvalidArgument( 660 "size_splits can only have one -1"); 661 } 662 has_neg_one = true; 663 } else { 664 total_size += size; 665 } 666 } 667 auto split_dim_size = c->Value(c->Dim(input, split_dim)); 668 // If the sizes of the splits are known, then 669 // make sure that the sizes add up to the expected 670 // dimension size, with the possibility of a -1. 671 // Specify the full output shapes. 672 for (int i = 0; i < num_outputs; ++i) { 673 auto size = data[i]; 674 if (data[i] == -1 && c->ValueKnown(split_dim_size)) { 675 size = split_dim_size - total_size; 676 } 677 TF_RETURN_IF_ERROR( 678 c->ReplaceDim(input, split_dim, c->MakeDim(size), &output_shape)); 679 c->set_output(i, output_shape); 680 } 681 if (c->ValueKnown(split_dim_size)) { 682 if (has_neg_one ? total_size > split_dim_size 683 : total_size != split_dim_size) { 684 return errors::InvalidArgument( 685 "can't split axis of size ", split_dim_size, 686 " into pieces of size [", str_util::Join(data, ","), "]"); 687 } 688 } 689 } 690 691 return Status::OK(); 692 }); 693 694 // -------------------------------------------------------------------------- 695 REGISTER_OP("Const") 696 .Output("output: dtype") 697 .Attr("value: tensor") 698 .Attr("dtype: type") 699 .SetShapeFn([](InferenceContext* c) { 700 const TensorProto* proto = nullptr; 701 TF_RETURN_IF_ERROR(c->GetAttr("value", &proto)); 702 TF_RETURN_IF_ERROR(TensorShape::IsValidShape(proto->tensor_shape())); 703 TensorShape shape(proto->tensor_shape()); 704 std::vector<DimensionHandle> dims; 705 dims.reserve(shape.dims()); 706 for (int i = 0; i < shape.dims(); ++i) { 707 dims.push_back(c->MakeDim(shape.dim_size(i))); 708 } 709 c->set_output(0, c->MakeShape(dims)); 710 return Status::OK(); 711 }); 712 713 // Returns a constant tensor on the host. Useful for writing C++ tests 714 // and benchmarks which run on GPU but require arguments pinned to the host. 715 // Used by test::graph::HostConstant. 716 // value: Attr `value` is the tensor to return. 717 REGISTER_OP("HostConst") 718 .Output("output: dtype") 719 .Attr("value: tensor") 720 .Attr("dtype: type") 721 .SetShapeFn(shape_inference::UnknownShape); 722 723 // -------------------------------------------------------------------------- 724 // TODO(mgubin): Update the doc when the freeze_graph script supports converting 725 // into memmapped format. 726 REGISTER_OP("ImmutableConst") 727 .Attr("dtype: type") 728 .Attr("shape: shape") 729 .Attr("memory_region_name: string") 730 .Output("tensor: dtype") 731 .SetShapeFn(shape_inference::ExplicitShape); 732 733 REGISTER_OP("GuaranteeConst") 734 .Input("input: T") 735 .Output("output: T") 736 .Attr("T: type") 737 .SetShapeFn([](shape_inference::InferenceContext* c) { 738 return UnchangedShape(c); 739 }) 740 // We don't want this to be optimized away. 741 .SetIsStateful(); 742 743 // -------------------------------------------------------------------------- 744 REGISTER_OP("ZerosLike") 745 .Input("x: T") 746 .Output("y: T") 747 .Attr("T: type") 748 .SetShapeFn(shape_inference::UnchangedShape); 749 750 // -------------------------------------------------------------------------- 751 REGISTER_OP("OnesLike") 752 .Input("x: T") 753 .Output("y: T") 754 .Attr( 755 "T: {bfloat16, half, float, double, int8, uint8, int16, uint16, int32, " 756 "int64, complex64, complex128, bool}") 757 .SetShapeFn(shape_inference::UnchangedShape); 758 759 // -------------------------------------------------------------------------- 760 REGISTER_OP("Diag") 761 .Input("diagonal: T") 762 .Output("output: T") 763 .Attr( 764 "T: {bfloat16, half, float, double, int32, int64, complex64, " 765 "complex128}") 766 .SetShapeFn([](InferenceContext* c) { 767 ShapeHandle in = c->input(0); 768 TF_RETURN_IF_ERROR(c->WithRankAtLeast(in, 1, &in)); 769 // Output shape is original concatenated with itself. 770 ShapeHandle out; 771 TF_RETURN_IF_ERROR(c->Concatenate(in, in, &out)); 772 c->set_output(0, out); 773 return Status::OK(); 774 }); 775 776 // -------------------------------------------------------------------------- 777 REGISTER_OP("DiagPart") 778 .Input("input: T") 779 .Output("diagonal: T") 780 .Attr( 781 "T: {bfloat16, half, float, double, int32, int64, complex64, " 782 "complex128}") 783 .SetShapeFn([](InferenceContext* c) { 784 ShapeHandle in = c->input(0); 785 if (!c->RankKnown(in)) { 786 c->set_output(0, c->UnknownShape()); 787 return Status::OK(); 788 } 789 // Rank must be even, and result will have rank <rank/2>. 790 const int32 rank = c->Rank(in); 791 if ((rank % 2) != 0 || rank <= 0) { 792 return errors::InvalidArgument( 793 "Input must have even and non-zero rank, input rank is ", rank); 794 } 795 const int32 mid = rank / 2; 796 797 // output dim[i] is the merge of in.dim[i] and in.dim[i+mid]. 798 std::vector<DimensionHandle> dims(mid); 799 for (int i = 0; i < mid; ++i) { 800 TF_RETURN_IF_ERROR( 801 c->Merge(c->Dim(in, i), c->Dim(in, i + mid), &dims[i])); 802 } 803 c->set_output(0, c->MakeShape(dims)); 804 return Status::OK(); 805 }); 806 807 // -------------------------------------------------------------------------- 808 REGISTER_OP("MatrixDiag") 809 .Input("diagonal: T") 810 .Output("output: T") 811 .Attr("T: type") 812 .SetShapeFn([](InferenceContext* c) { 813 ShapeHandle in; 814 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &in)); 815 if (!c->RankKnown(in)) { 816 c->set_output(0, c->UnknownShape()); 817 return Status::OK(); 818 } 819 const int32 rank = c->Rank(in); 820 ShapeHandle out; 821 TF_RETURN_IF_ERROR( 822 c->Concatenate(in, c->Vector(c->Dim(in, rank - 1)), &out)); 823 c->set_output(0, out); 824 return Status::OK(); 825 }); 826 827 // -------------------------------------------------------------------------- 828 REGISTER_OP("MatrixSetDiag") 829 .Input("input: T") 830 .Input("diagonal: T") 831 .Output("output: T") 832 .Attr("T: type") 833 .SetShapeFn([](InferenceContext* c) { 834 ShapeHandle input; 835 ShapeHandle diag; 836 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input)); 837 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &diag)); 838 if (c->RankKnown(input)) { 839 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), c->Rank(input) - 1, &diag)); 840 } 841 DimensionHandle smallest_dim; 842 TF_RETURN_IF_ERROR( 843 c->Min(c->Dim(input, -2), c->Dim(input, -1), &smallest_dim)); 844 TF_RETURN_IF_ERROR( 845 c->Merge(smallest_dim, c->Dim(diag, -1), &smallest_dim)); 846 847 ShapeHandle output = input; 848 if (c->RankKnown(diag) && !c->FullyDefined(input)) { 849 // Try to infer parts of shape from diag. 850 ShapeHandle diag_prefix; 851 TF_RETURN_IF_ERROR(c->Subshape(diag, 0, -1, &diag_prefix)); 852 TF_RETURN_IF_ERROR( 853 c->Concatenate(diag_prefix, c->UnknownShapeOfRank(2), &diag)); 854 TF_RETURN_IF_ERROR(c->Merge(input, diag, &output)); 855 } 856 c->set_output(0, output); 857 return Status::OK(); 858 }); 859 860 // -------------------------------------------------------------------------- 861 REGISTER_OP("MatrixDiagPart") 862 .Input("input: T") 863 .Output("diagonal: T") 864 .Attr("T: type") 865 .SetShapeFn([](InferenceContext* c) { 866 ShapeHandle in; 867 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &in)); 868 if (!c->RankKnown(in)) { 869 c->set_output(0, c->UnknownShape()); 870 return Status::OK(); 871 } 872 const int32 rank = c->Rank(in); 873 std::vector<DimensionHandle> dims; 874 dims.reserve(rank - 2); 875 for (int i = 0; i < rank - 2; ++i) dims.push_back(c->Dim(in, i)); 876 877 DimensionHandle min_dim; 878 TF_RETURN_IF_ERROR( 879 c->Min(c->Dim(in, rank - 2), c->Dim(in, rank - 1), &min_dim)); 880 dims.push_back(min_dim); 881 c->set_output(0, c->MakeShape(dims)); 882 return Status::OK(); 883 }); 884 885 // -------------------------------------------------------------------------- 886 REGISTER_OP("MatrixBandPart") 887 .Input("input: T") 888 .Input("num_lower: Tindex") 889 .Input("num_upper: Tindex") 890 .Output("band: T") 891 .Attr("T: type") 892 .Attr("Tindex: {int32, int64} = DT_INT64") 893 .SetShapeFn(shape_inference::UnchangedShape); 894 895 // -------------------------------------------------------------------------- 896 REGISTER_OP("Reverse") 897 .Input("tensor: T") 898 .Input("dims: bool") 899 .Output("output: T") 900 .Attr( 901 "T: {uint8, int8, uint16, int16, int32, int64, bool, half, " 902 "float, double, complex64, complex128, string}") 903 .SetShapeFn([](InferenceContext* c) { 904 ShapeHandle input = c->input(0); 905 ShapeHandle dims; 906 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &dims)); 907 DimensionHandle dims_dim = c->Dim(dims, 0); 908 if (c->ValueKnown(dims_dim)) { 909 TF_RETURN_IF_ERROR(c->WithRank(input, c->Value(dims_dim), &input)); 910 } 911 if (c->Rank(input) > 8) { 912 return errors::InvalidArgument( 913 "reverse does not work on tensors with more than 8 dimensions"); 914 } 915 c->set_output(0, input); 916 return Status::OK(); 917 }); 918 919 // -------------------------------------------------------------------------- 920 REGISTER_OP("ReverseV2") 921 .Input("tensor: T") 922 .Input("axis: Tidx") 923 .Output("output: T") 924 .Attr("Tidx: {int32, int64} = DT_INT32") 925 .Attr( 926 "T: {uint8, int8, uint16, int16, int32, int64, bool, bfloat16, half, " 927 "float, double, complex64, complex128, string}") 928 .SetShapeFn([](InferenceContext* c) { 929 ShapeHandle input = c->input(0); 930 ShapeHandle axis; 931 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &axis)); 932 if (c->Rank(input) > 8) { 933 return errors::InvalidArgument( 934 "reverse does not work on tensors with more than 8 dimensions"); 935 } 936 const Tensor* axis_tensor = c->input_tensor(1); 937 if (axis_tensor != nullptr && c->RankKnown(input)) { 938 int32 rank = c->Rank(input); 939 std::vector<int64> axis_value; 940 if (axis_tensor->dtype() == DT_INT32) { 941 axis_value = AsInt64<int32>(axis_tensor, axis_tensor->NumElements()); 942 } else { 943 axis_value = AsInt64<int64>(axis_tensor, axis_tensor->NumElements()); 944 } 945 std::vector<bool> axes_dense(c->Rank(input), false); 946 for (int i = 0; i < axis_value.size(); i++) { 947 int64 canonical_axis = 948 axis_value[i] < 0 ? rank + axis_value[i] : axis_value[i]; 949 if (canonical_axis < 0 || canonical_axis >= rank) { 950 return errors::InvalidArgument("'axis'[", i, "] = ", axis_value[i], 951 " is out of valid range [", 0, ", ", 952 rank - 1); 953 } 954 if (axes_dense[canonical_axis]) { 955 return errors::InvalidArgument("axis ", canonical_axis, 956 " specified more than once."); 957 } 958 axes_dense[canonical_axis] = true; 959 } 960 } 961 c->set_output(0, input); 962 return Status::OK(); 963 }); 964 965 // -------------------------------------------------------------------------- 966 REGISTER_OP("EditDistance") 967 .Input("hypothesis_indices: int64") 968 .Input("hypothesis_values: T") 969 .Input("hypothesis_shape: int64") 970 .Input("truth_indices: int64") 971 .Input("truth_values: T") 972 .Input("truth_shape: int64") 973 .Attr("normalize: bool = true") 974 .Attr("T: type") 975 .Output("output: float") 976 .SetShapeFn([](InferenceContext* c) { 977 TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor( 978 c, c->input(0), c->input(1), c->input(2))); 979 TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor( 980 c, c->input(3), c->input(4), c->input(5))); 981 const Tensor* hypothesis_shape_t = c->input_tensor(2); 982 const Tensor* truth_shape_t = c->input_tensor(5); 983 if (hypothesis_shape_t == nullptr || truth_shape_t == nullptr) { 984 // We need to know the runtime shape of the two tensors, 985 // or else the output shape is unknown. 986 return shape_inference::UnknownShape(c); 987 } 988 989 if (hypothesis_shape_t->NumElements() != truth_shape_t->NumElements()) { 990 return errors::InvalidArgument( 991 "Num elements of hypothesis_shape does not match truth_shape: ", 992 hypothesis_shape_t->NumElements(), " vs. ", 993 truth_shape_t->NumElements()); 994 } 995 996 auto h_values = hypothesis_shape_t->flat<int64>(); 997 auto t_values = truth_shape_t->flat<int64>(); 998 std::vector<DimensionHandle> dims(hypothesis_shape_t->NumElements() - 1); 999 for (int i = 0; i < dims.size(); ++i) { 1000 dims[i] = c->MakeDim(std::max(h_values(i), t_values(i))); 1001 } 1002 1003 c->set_output(0, c->MakeShape(dims)); 1004 return Status::OK(); 1005 }); 1006 1007 // -------------------------------------------------------------------------- 1008 REGISTER_OP("Fill") 1009 .Input("dims: index_type") 1010 .Input("value: T") 1011 .Output("output: T") 1012 .Attr("T: type") 1013 .Attr("index_type: {int32, int64} = DT_INT32") 1014 .SetShapeFn([](InferenceContext* c) { 1015 DataType index_type = DT_INT32; 1016 Status s = c->GetAttr("index_type", &index_type); 1017 if (!s.ok() && s.code() != error::NOT_FOUND) { 1018 return s; 1019 } 1020 ShapeHandle unused; 1021 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused)); 1022 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 1023 1024 const Tensor* t = c->input_tensor(0); 1025 if (t != nullptr) { 1026 for (int i = 0; i < t->NumElements(); ++i) { 1027 if ((index_type == DT_INT32 && t->vec<int32>()(i) < 0) || 1028 (index_type == DT_INT64 && t->vec<int64>()(i) < 0)) { 1029 return errors::InvalidArgument("Fill dimensions must be >= 0"); 1030 } 1031 } 1032 } 1033 1034 ShapeHandle out; 1035 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out)); 1036 c->set_output(0, out); 1037 1038 auto* shape_and_type = c->input_handle_shapes_and_types(1); 1039 if (shape_and_type) { 1040 c->set_output_handle_shapes_and_types(0, *shape_and_type); 1041 } 1042 1043 return Status::OK(); 1044 }); 1045 1046 // -------------------------------------------------------------------------- 1047 REGISTER_OP("_ParallelConcatStart") 1048 .Output("output: dtype") 1049 .Attr("shape: shape") 1050 .Attr("dtype: type") 1051 .SetIsStateful() 1052 .SetShapeFn(shape_inference::ExplicitShape) 1053 .Doc(R"doc( 1054 Creates an empty Tensor with shape `shape` and type `dtype`. 1055 1056 The memory can optionally be initialized. This is usually useful in 1057 conjunction with inplace operations. 1058 1059 shape: 1-D `Tensor` indicating the shape of the output. 1060 dtype: The element type of the returned tensor. 1061 output: An empty Tensor of the specified type. 1062 )doc"); 1063 1064 // -------------------------------------------------------------------------- 1065 REGISTER_OP("_ParallelConcatUpdate") 1066 .Input("value: T") 1067 .Input("update: T") 1068 .Output("output: T") 1069 .Attr("T: type") 1070 .Attr("loc: int") 1071 .SetShapeFn(shape_inference::UnchangedShape) 1072 .Doc(R"doc( 1073 Updates input `value` at `loc` with `update`. 1074 1075 If you use this function you will almost certainly want to add 1076 a control dependency as done in the implementation of parallel_stack to 1077 avoid race conditions. 1078 1079 value: A `Tensor` object that will be updated in-place. 1080 loc: A scalar indicating the index of the first dimension such that 1081 value[loc, :] is updated. 1082 update: A `Tensor` of rank one less than `value` if `loc` is a scalar, 1083 otherwise of rank equal to `value` that contains the new values 1084 for `value`. 1085 output: `value` that has been updated accordingly. 1086 )doc"); 1087 1088 // -------------------------------------------------------------------------- 1089 REGISTER_OP("Gather") 1090 .Input("params: Tparams") 1091 .Input("indices: Tindices") 1092 .Attr("validate_indices: bool = true") 1093 .Output("output: Tparams") 1094 .Attr("Tparams: type") 1095 .Attr("Tindices: {int32,int64}") 1096 .SetShapeFn([](InferenceContext* c) { 1097 ShapeHandle unused; 1098 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &unused)); 1099 ShapeHandle params_subshape; 1100 TF_RETURN_IF_ERROR(c->Subshape(c->input(0), 1, ¶ms_subshape)); 1101 ShapeHandle indices_shape = c->input(1); 1102 ShapeHandle out; 1103 TF_RETURN_IF_ERROR(c->Concatenate(indices_shape, params_subshape, &out)); 1104 c->set_output(0, out); 1105 return Status::OK(); 1106 }); 1107 1108 // -------------------------------------------------------------------------- 1109 REGISTER_OP("GatherV2") 1110 .Input("params: Tparams") 1111 .Input("indices: Tindices") 1112 .Input("axis: Taxis") 1113 .Output("output: Tparams") 1114 .Attr("Tparams: type") 1115 .Attr("Tindices: {int32,int64}") 1116 .Attr("Taxis: {int32,int64}") 1117 .SetShapeFn([](InferenceContext* c) { 1118 ShapeHandle params_shape; 1119 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, ¶ms_shape)); 1120 1121 ShapeHandle indices_shape = c->input(1); 1122 ShapeHandle unused_axis_shape; 1123 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_axis_shape)); 1124 const Tensor* axis_t = c->input_tensor(2); 1125 1126 // If axis is unknown, we can only infer that the result is params_rank + 1127 // indices_rank - 1. 1128 if (axis_t == nullptr) { 1129 if (c->RankKnown(params_shape) && c->RankKnown(indices_shape)) { 1130 c->set_output(0, c->UnknownShapeOfRank(c->Rank(params_shape) + 1131 c->Rank(indices_shape) - 1)); 1132 } else { 1133 c->set_output(0, c->UnknownShape()); 1134 } 1135 return Status::OK(); 1136 } 1137 1138 // Note, axis can be negative. 1139 int64 axis = 0; 1140 if (axis_t->dtype() == DT_INT32) { 1141 axis = axis_t->scalar<int32>()(); 1142 } else { 1143 axis = axis_t->scalar<int64>()(); 1144 } 1145 1146 // Check that params has rank of at least axis + 1. 1147 ShapeHandle unused; 1148 TF_RETURN_IF_ERROR(c->WithRankAtLeast( 1149 params_shape, axis < 0 ? -axis : axis + 1, &unused)); 1150 1151 ShapeHandle params_outer_subshape; 1152 TF_RETURN_IF_ERROR( 1153 c->Subshape(params_shape, 0, axis, ¶ms_outer_subshape)); 1154 1155 ShapeHandle out; 1156 TF_RETURN_IF_ERROR( 1157 c->Concatenate(params_outer_subshape, indices_shape, &out)); 1158 1159 // Slice from axis + 1 to the end of params_shape to collect the inner 1160 // dimensions of the result. Special case -1 here since -1 + 1 wraps, and 1161 // we slice from 0 to the end of shape. Subshape() handles all other 1162 // out-of-bounds checking. 1163 if (axis != -1) { 1164 ShapeHandle params_inner_subshape; 1165 TF_RETURN_IF_ERROR( 1166 c->Subshape(params_shape, axis + 1, ¶ms_inner_subshape)); 1167 TF_RETURN_IF_ERROR(c->Concatenate(out, params_inner_subshape, &out)); 1168 } 1169 1170 c->set_output(0, out); 1171 return Status::OK(); 1172 }); 1173 1174 // -------------------------------------------------------------------------- 1175 REGISTER_OP("GatherNd") 1176 .Input("params: Tparams") 1177 .Input("indices: Tindices") 1178 .Output("output: Tparams") 1179 .Attr("Tparams: type") 1180 .Attr("Tindices: {int32,int64}") 1181 .SetShapeFn([](InferenceContext* c) { 1182 ShapeHandle params = c->input(0); 1183 ShapeHandle indices; 1184 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices)); 1185 DimensionHandle r_dim = c->Dim(indices, -1); 1186 1187 if (!c->RankKnown(params) || !c->ValueKnown(r_dim)) { 1188 c->set_output(0, c->UnknownShape()); 1189 return Status::OK(); 1190 } 1191 1192 if (c->Value(r_dim) > c->Rank(params)) { 1193 return errors::InvalidArgument( 1194 "indices.shape[-1] must be <= params.rank, but saw indices shape: ", 1195 c->DebugString(indices), 1196 " and params shape: ", c->DebugString(params)); 1197 } 1198 1199 // Remove r_dim from indices to get output. 1200 ShapeHandle indices_slice; 1201 ShapeHandle params_slice; 1202 TF_RETURN_IF_ERROR(c->Subshape(indices, 0, -1, &indices_slice)); 1203 TF_RETURN_IF_ERROR(c->Subshape(params, c->Value(r_dim), ¶ms_slice)); 1204 ShapeHandle out; 1205 TF_RETURN_IF_ERROR(c->Concatenate(indices_slice, params_slice, &out)); 1206 c->set_output(0, out); 1207 return Status::OK(); 1208 }); 1209 1210 // -------------------------------------------------------------------------- 1211 REGISTER_OP("Identity") 1212 .Input("input: T") 1213 .Output("output: T") 1214 .Attr("T: type") 1215 .SetShapeFn(shape_inference::UnchangedShape); 1216 1217 REGISTER_OP("Snapshot") 1218 .Input("input: T") 1219 .Output("output: T") 1220 .Attr("T: type") 1221 .SetShapeFn(shape_inference::UnchangedShape); 1222 1223 #ifdef INTEL_MKL 1224 REGISTER_OP("_MklIdentity") 1225 .Input("input: T") 1226 .Input("mkl_input: uint8") 1227 .Output("output: T") 1228 .Output("mkl_output: uint8") 1229 .Attr("T: type") 1230 .SetShapeFn(shape_inference::UnchangedShape) 1231 .Doc(R"Doc( Mkl implementation of IdentityOp 1232 )Doc"); 1233 #endif 1234 1235 REGISTER_OP("IdentityN") 1236 .Input("input: T") 1237 .Output("output: T") 1238 .Attr("T: list(type)") 1239 .SetShapeFn([](shape_inference::InferenceContext* c) { 1240 std::vector<ShapeHandle> input; 1241 TF_RETURN_IF_ERROR(c->input("input", &input)); 1242 TF_RETURN_IF_ERROR(c->set_output("output", input)); 1243 return Status::OK(); 1244 }); 1245 1246 // -------------------------------------------------------------------------- 1247 REGISTER_OP("RefIdentity") 1248 .Input("input: Ref(T)") 1249 .Output("output: Ref(T)") 1250 .Attr("T: type") 1251 .SetShapeFn(shape_inference::UnchangedShape) 1252 .SetAllowsUninitializedInput(); 1253 1254 // -------------------------------------------------------------------------- 1255 REGISTER_OP("DebugGradientIdentity") 1256 .Input("input: T") 1257 .Output("output: T") 1258 .Attr("T: type") 1259 .SetShapeFn(shape_inference::UnchangedShape) 1260 .SetAllowsUninitializedInput(); 1261 1262 REGISTER_OP("DebugGradientRefIdentity") 1263 .Input("input: Ref(T)") 1264 .Output("output: Ref(T)") 1265 .Attr("T: type") 1266 .SetShapeFn(shape_inference::UnchangedShape) 1267 .SetAllowsUninitializedInput(); 1268 1269 // -------------------------------------------------------------------------- 1270 REGISTER_OP("StopGradient") 1271 .Input("input: T") 1272 .Output("output: T") 1273 .Attr("T: type") 1274 .SetShapeFn(shape_inference::UnchangedShape); 1275 1276 REGISTER_OP("PreventGradient") 1277 .Input("input: T") 1278 .Output("output: T") 1279 .Attr("T: type") 1280 .Attr("message: string = ''") 1281 .SetShapeFn(shape_inference::UnchangedShape); 1282 1283 // -------------------------------------------------------------------------- 1284 REGISTER_OP("CheckNumerics") 1285 .Input("tensor: T") 1286 .Output("output: T") 1287 .Attr("T: {bfloat16, half, float, double}") 1288 .Attr("message: string") 1289 .SetShapeFn(shape_inference::UnchangedShape); 1290 1291 // -------------------------------------------------------------------------- 1292 REGISTER_OP("Reshape") 1293 .Input("tensor: T") 1294 .Input("shape: Tshape") 1295 .Output("output: T") 1296 .Attr("T: type") 1297 .Attr("Tshape: {int32, int64} = DT_INT32") 1298 .SetShapeFn([](InferenceContext* c) { 1299 return SetOutputShapeForReshape(c); 1300 }); 1301 1302 #ifdef INTEL_MKL 1303 REGISTER_OP("_MklReshape") 1304 .Input("tensor: T") 1305 .Input("shape: Tshape") 1306 .Input("mkl_tensor: uint8") 1307 .Input("mkl_shape: uint8") 1308 .Output("output: T") 1309 .Output("mkl_output: uint8") 1310 .Attr("T: type") 1311 .Attr("Tshape: {int32, int64} = DT_INT32") 1312 .SetShapeFn([](InferenceContext* c) { return SetOutputShapeForReshape(c); }) 1313 .Doc(R"Doc( MKL implementation of ReshapeOp. 1314 )Doc"); 1315 #endif // INTEL_MKL 1316 1317 // -------------------------------------------------------------------------- 1318 REGISTER_OP("InvertPermutation") 1319 .Input("x: T") 1320 .Output("y: T") 1321 .Attr("T: {int32, int64} = DT_INT32") 1322 .SetShapeFn([](InferenceContext* c) { 1323 ShapeHandle x; 1324 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &x)); 1325 c->set_output(0, x); 1326 return Status::OK(); 1327 }); 1328 1329 // -------------------------------------------------------------------------- 1330 REGISTER_OP("Transpose") 1331 .Input("x: T") 1332 .Input("perm: Tperm") 1333 .Output("y: T") 1334 .Attr("T: type") 1335 .Attr("Tperm: {int32, int64} = DT_INT32") 1336 .SetShapeFn(TransposeShapeFn); 1337 1338 // -------------------------------------------------------------------------- 1339 REGISTER_OP("ConjugateTranspose") 1340 .Input("x: T") 1341 .Input("perm: Tperm") 1342 .Output("y: T") 1343 .Attr("T: type") 1344 .Attr("Tperm: {int32, int64} = DT_INT32") 1345 .SetShapeFn(TransposeShapeFn); 1346 1347 // -------------------------------------------------------------------------- 1348 REGISTER_OP("Unique") 1349 .Input("x: T") 1350 .Output("y: T") 1351 .Output("idx: out_idx") 1352 .Attr("T: type") 1353 .Attr("out_idx: {int32, int64} = DT_INT32") 1354 .SetShapeFn([](InferenceContext* c) { 1355 c->set_output(0, c->Vector(InferenceContext::kUnknownDim)); 1356 c->set_output(1, c->input(0)); 1357 // Assert that the input rank is 1. 1358 ShapeHandle dummy; 1359 return c->WithRank(c->input(0), 1, &dummy); 1360 }); 1361 1362 REGISTER_OP("UniqueV2") 1363 .Input("x: T") 1364 .Input("axis: Taxis") 1365 .Output("y: T") 1366 .Output("idx: out_idx") 1367 .Attr("T: type") 1368 .Attr("Taxis: {int32,int64} = DT_INT64") 1369 .Attr("out_idx: {int32, int64} = DT_INT32") 1370 .SetShapeFn([](InferenceContext* c) { 1371 c->set_output(0, c->Vector(InferenceContext::kUnknownDim)); 1372 c->set_output(1, c->input(0)); 1373 return Status::OK(); 1374 }); 1375 1376 // -------------------------------------------------------------------------- 1377 REGISTER_OP("UniqueWithCounts") 1378 .Input("x: T") 1379 .Output("y: T") 1380 .Output("idx: out_idx") 1381 .Output("count: out_idx") 1382 .Attr("T: type") 1383 .Attr("out_idx: {int32, int64} = DT_INT32") 1384 .SetShapeFn([](InferenceContext* c) { 1385 auto uniq = c->Vector(InferenceContext::kUnknownDim); 1386 c->set_output(0, uniq); 1387 c->set_output(1, c->input(0)); 1388 c->set_output(2, uniq); 1389 return Status::OK(); 1390 }); 1391 1392 REGISTER_OP("UniqueWithCountsV2") 1393 .Input("x: T") 1394 .Input("axis: Taxis") 1395 .Output("y: T") 1396 .Output("idx: out_idx") 1397 .Output("count: out_idx") 1398 .Attr("T: type") 1399 .Attr("Taxis: {int32,int64} = DT_INT64") 1400 .Attr("out_idx: {int32, int64} = DT_INT32") 1401 .SetShapeFn([](InferenceContext* c) { 1402 auto uniq = c->Vector(InferenceContext::kUnknownDim); 1403 c->set_output(0, uniq); 1404 c->set_output(1, c->input(0)); 1405 c->set_output(2, uniq); 1406 return Status::OK(); 1407 }); 1408 1409 namespace { 1410 1411 Status ShapeShapeFn(InferenceContext* c) { 1412 for (int i = 0; i < c->num_inputs(); ++i) { 1413 DimensionHandle dim; 1414 if (c->RankKnown(c->input(i))) { 1415 dim = c->MakeDim(c->Rank(c->input(i))); 1416 } else { 1417 dim = c->UnknownDim(); 1418 } 1419 c->set_output(i, c->Vector(dim)); 1420 } 1421 return Status::OK(); 1422 } 1423 1424 } // namespace 1425 1426 // -------------------------------------------------------------------------- 1427 REGISTER_OP("Shape") 1428 .Input("input: T") 1429 .Output("output: out_type") 1430 .Attr("T: type") 1431 .Attr("out_type: {int32, int64} = DT_INT32") 1432 .SetShapeFn(ShapeShapeFn); 1433 1434 REGISTER_OP("ShapeN") 1435 .Input("input: N * T") 1436 .Output("output: N * out_type") 1437 .Attr("N: int") 1438 .Attr("T: type") 1439 .Attr("out_type: {int32, int64} = DT_INT32") 1440 .SetShapeFn(ShapeShapeFn); 1441 1442 REGISTER_OP("EnsureShape") 1443 .Input("input: T") 1444 .Output("output: T") 1445 .Attr("shape: shape") 1446 .Attr("T: type") 1447 .SetShapeFn([](InferenceContext* c) { 1448 // Merges desired shape and statically known shape of input 1449 PartialTensorShape desired_shape; 1450 TF_RETURN_IF_ERROR(c->GetAttr("shape", &desired_shape)); 1451 1452 int rank = desired_shape.dims(); 1453 ShapeHandle input_shape_handle; 1454 ShapeHandle desired_shape_handle; 1455 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape_handle)); 1456 TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape( 1457 desired_shape, &desired_shape_handle)); 1458 1459 ShapeHandle merged_shape; 1460 TF_RETURN_IF_ERROR( 1461 c->Merge(desired_shape_handle, input_shape_handle, &merged_shape)); 1462 c->set_output(0, merged_shape); 1463 return Status::OK(); 1464 }); 1465 1466 // -------------------------------------------------------------------------- 1467 REGISTER_OP("ReverseSequence") 1468 .Input("input: T") 1469 .Input("seq_lengths: Tlen") 1470 .Output("output: T") 1471 .Attr("seq_dim: int") 1472 .Attr("batch_dim: int = 0") 1473 .Attr("T: type") 1474 .Attr("Tlen: {int32, int64} = DT_INT64") 1475 .SetShapeFn([](InferenceContext* c) { 1476 ShapeHandle input = c->input(0); 1477 ShapeHandle seq_lens_shape; 1478 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &seq_lens_shape)); 1479 1480 int64 seq_dim; 1481 TF_RETURN_IF_ERROR(c->GetAttr("seq_dim", &seq_dim)); 1482 int64 batch_dim; 1483 TF_RETURN_IF_ERROR(c->GetAttr("batch_dim", &batch_dim)); 1484 1485 if (!c->RankKnown(input)) { 1486 return shape_inference::UnknownShape(c); 1487 } 1488 1489 // Validate batch_dim and seq_dim against input. 1490 const int32 input_rank = c->Rank(input); 1491 if (batch_dim >= input_rank) { 1492 return errors::InvalidArgument( 1493 "batch_dim must be < input rank: ", batch_dim, " vs. ", input_rank); 1494 } 1495 if (seq_dim >= input_rank) { 1496 return errors::InvalidArgument( 1497 "seq_dim must be < input rank: ", seq_dim, " vs. ", input_rank); 1498 } 1499 1500 DimensionHandle batch_dim_dim = c->Dim(input, batch_dim); 1501 TF_RETURN_IF_ERROR( 1502 c->Merge(batch_dim_dim, c->Dim(seq_lens_shape, 0), &batch_dim_dim)); 1503 1504 // Replace batch_dim of input with batch_size 1505 ShapeHandle output_shape; 1506 TF_RETURN_IF_ERROR( 1507 c->ReplaceDim(input, batch_dim, batch_dim_dim, &output_shape)); 1508 c->set_output(0, output_shape); 1509 return Status::OK(); 1510 }); 1511 1512 // -------------------------------------------------------------------------- 1513 REGISTER_OP("Rank") 1514 .Input("input: T") 1515 .Output("output: int32") 1516 .Attr("T: type") 1517 .SetShapeFn(shape_inference::ScalarShape); 1518 1519 // -------------------------------------------------------------------------- 1520 REGISTER_OP("Size") 1521 .Input("input: T") 1522 .Output("output: out_type") 1523 .Attr("T: type") 1524 .Attr("out_type: {int32, int64} = DT_INT32") 1525 .SetShapeFn(shape_inference::ScalarShape); 1526 1527 // -------------------------------------------------------------------------- 1528 REGISTER_OP("Slice") 1529 .Input("input: T") 1530 .Input("begin: Index") 1531 .Input("size: Index") 1532 .Output("output: T") 1533 .Attr("T: type") 1534 .Attr("Index: {int32,int64}") 1535 .SetShapeFn(shape_inference::SliceShape); 1536 1537 #ifdef INTEL_MKL 1538 REGISTER_OP("_MklSlice") 1539 .Input("input: T") 1540 .Input("begin: Index") 1541 .Input("size: Index") 1542 .Input("mkl_input: uint8") 1543 .Input("mkl_begin: uint8") 1544 .Input("mkl_size: uint8") 1545 .Output("output: T") 1546 .Output("mkl_output: uint8") 1547 .Attr("T: type") 1548 .Attr("Index: {int32,int64}") 1549 .SetShapeFn(shape_inference::SliceShape); 1550 #endif 1551 1552 REGISTER_OP("StridedSlice") 1553 .Input("input: T") 1554 .Input("begin: Index") 1555 .Input("end: Index") 1556 .Input("strides: Index") 1557 .Output("output: T") 1558 .Attr("T: type") 1559 .Attr("Index: {int32, int64}") 1560 .Attr("begin_mask: int = 0") 1561 .Attr("end_mask: int = 0") 1562 .Attr("ellipsis_mask: int = 0") 1563 .Attr("new_axis_mask: int = 0") 1564 .Attr("shrink_axis_mask: int = 0") 1565 .SetShapeFn([](InferenceContext* c) { 1566 ShapeHandle input = c->input(0); 1567 ShapeHandle begin_shape, end_shape, strides_shape; 1568 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &begin_shape)); 1569 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &end_shape)); 1570 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &strides_shape)); 1571 TF_RETURN_IF_ERROR(c->Merge(begin_shape, end_shape, &begin_shape)); 1572 TF_RETURN_IF_ERROR(c->Merge(begin_shape, strides_shape, &begin_shape)); 1573 DimensionHandle sparse_dims_dim = c->Dim(begin_shape, 0); 1574 1575 const Tensor* strides_value = c->input_tensor(3); 1576 // TODO(aselle,allenl): If we had a stride_mask it would be possible to do 1577 // more shape inference here (e.g. for x[3, ::T]). 1578 if (!c->RankKnown(input) || !c->ValueKnown(sparse_dims_dim) || 1579 strides_value == nullptr) { 1580 c->set_output(0, c->UnknownShape()); 1581 return Status::OK(); 1582 } 1583 1584 PartialTensorShape input_shape({}); 1585 for (int i = 0; i < c->Rank(input); ++i) { 1586 auto dim = c->Dim(input, i); 1587 input_shape.AddDim(c->ValueKnown(dim) ? c->Value(dim) : -1); 1588 } 1589 1590 int32 begin_mask, end_mask, ellipsis_mask, new_axis_mask, 1591 shrink_axis_mask; 1592 TF_RETURN_IF_ERROR(c->GetAttr("begin_mask", &begin_mask)); 1593 TF_RETURN_IF_ERROR(c->GetAttr("end_mask", &end_mask)); 1594 TF_RETURN_IF_ERROR(c->GetAttr("ellipsis_mask", &ellipsis_mask)); 1595 TF_RETURN_IF_ERROR(c->GetAttr("new_axis_mask", &new_axis_mask)); 1596 TF_RETURN_IF_ERROR(c->GetAttr("shrink_axis_mask", &shrink_axis_mask)); 1597 1598 const Tensor* begin_value = c->input_tensor(1); 1599 const Tensor* end_value = c->input_tensor(2); 1600 1601 PartialTensorShape processing_shape, final_shape; 1602 bool is_identity, is_simple_slice, slice_dim0; 1603 gtl::InlinedVector<int64, 4> begin, end, strides; 1604 TF_RETURN_IF_ERROR(ValidateStridedSliceOp( 1605 begin_value, end_value, *strides_value, input_shape, begin_mask, 1606 end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask, 1607 &processing_shape, &final_shape, &is_identity, &is_simple_slice, 1608 &slice_dim0, &begin, &end, &strides)); 1609 1610 ShapeHandle out; 1611 TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(final_shape, &out)); 1612 c->set_output(0, out); 1613 1614 auto* shape_and_type = c->input_handle_shapes_and_types(0); 1615 if (shape_and_type) { 1616 c->set_output_handle_shapes_and_types(0, *shape_and_type); 1617 } 1618 1619 return Status::OK(); 1620 }); 1621 1622 REGISTER_OP("StridedSliceGrad") 1623 .Input("shape: Index") 1624 .Input("begin: Index") 1625 .Input("end: Index") 1626 .Input("strides: Index") 1627 .Input("dy: T") 1628 .Output("output: T") 1629 .Attr("T: type") 1630 .Attr("Index: {int32, int64}") 1631 .Attr("begin_mask: int = 0") 1632 .Attr("end_mask: int = 0") 1633 .Attr("ellipsis_mask: int = 0") 1634 .Attr("new_axis_mask: int = 0") 1635 .Attr("shrink_axis_mask: int = 0") 1636 .SetShapeFn([](InferenceContext* c) { 1637 ShapeHandle out; 1638 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out)); 1639 c->set_output(0, out); 1640 return Status::OK(); 1641 }); 1642 1643 REGISTER_OP("StridedSliceAssign") 1644 .Input("ref: Ref(T)") 1645 .Input("begin: Index") 1646 .Input("end: Index") 1647 .Input("strides: Index") 1648 .Input("value: T") 1649 .Output("output_ref: Ref(T)") 1650 .Attr("T: type") 1651 .Attr("Index: {int32, int64}") 1652 .Attr("begin_mask: int = 0") 1653 .Attr("end_mask: int = 0") 1654 .Attr("ellipsis_mask: int = 0") 1655 .Attr("new_axis_mask: int = 0") 1656 .Attr("shrink_axis_mask: int = 0") 1657 .SetShapeFn(shape_inference::UnchangedShape); 1658 // TODO(aselle): Fix this documentation once StridedSliceAssign Supports 1659 // broadcasting. 1660 // -------------------------------------------------------------------------- 1661 1662 REGISTER_OP("ResourceStridedSliceAssign") 1663 .Input("ref: resource") 1664 .Input("begin: Index") 1665 .Input("end: Index") 1666 .Input("strides: Index") 1667 .Input("value: T") 1668 .Attr("T: type") 1669 .Attr("Index: {int32, int64}") 1670 .Attr("begin_mask: int = 0") 1671 .Attr("end_mask: int = 0") 1672 .Attr("ellipsis_mask: int = 0") 1673 .Attr("new_axis_mask: int = 0") 1674 .Attr("shrink_axis_mask: int = 0") 1675 .SetShapeFn(shape_inference::NoOutputs); 1676 1677 REGISTER_OP("Tile") 1678 .Input("input: T") 1679 .Input("multiples: Tmultiples") 1680 .Output("output: T") 1681 .Attr("T: type") 1682 .Attr("Tmultiples: {int32, int64} = DT_INT32") 1683 .SetShapeFn([](InferenceContext* c) { 1684 ShapeHandle input = c->input(0); 1685 // NOTE(mrry): Represent `multiples` as a `TensorShape` because (i) 1686 // it is a vector of non-negative integers, and (ii) doing so allows 1687 // us to handle partially-known multiples. 1688 ShapeHandle multiples; 1689 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &multiples)); 1690 if (c->RankKnown(input)) { 1691 TF_RETURN_IF_ERROR(c->WithRank(multiples, c->Rank(input), &multiples)); 1692 ShapeHandle dummy; 1693 TF_RETURN_IF_ERROR( 1694 c->Merge(c->input(1), c->Vector(c->Rank(input)), &dummy)); 1695 } 1696 1697 if (!c->RankKnown(multiples)) { 1698 return shape_inference::UnknownShape(c); 1699 } 1700 1701 int32 rank = c->Rank(multiples); 1702 TF_RETURN_IF_ERROR(c->WithRank(input, rank, &input)); 1703 std::vector<DimensionHandle> dims(rank); 1704 for (int i = 0; i < rank; ++i) { 1705 TF_RETURN_IF_ERROR( 1706 c->Multiply(c->Dim(input, i), c->Dim(multiples, i), &dims[i])); 1707 } 1708 c->set_output(0, c->MakeShape(dims)); 1709 return Status::OK(); 1710 }); 1711 1712 // -------------------------------------------------------------------------- 1713 REGISTER_OP("TileGrad") 1714 .Input("input: T") 1715 .Input("multiples: int32") 1716 .Output("output: T") 1717 .Attr("T: type") 1718 .Deprecated(3, "TileGrad has been replaced with reduce_sum") 1719 .SetShapeFn(tensorflow::shape_inference::UnknownShape); 1720 1721 // -------------------------------------------------------------------------- 1722 REGISTER_OP("Where") 1723 .Input("input: T") 1724 .Attr("T: {numbertype, bool} = DT_BOOL") 1725 .Output("index: int64") 1726 .SetShapeFn([](InferenceContext* c) { 1727 c->set_output(0, c->Matrix(c->UnknownDim(), c->Rank(c->input(0)))); 1728 return Status::OK(); 1729 }); 1730 1731 // -------------------------------------------------------------------------- 1732 REGISTER_OP("BroadcastArgs") 1733 .Input("s0: T") 1734 .Input("s1: T") 1735 .Output("r0: T") 1736 .Attr("T: {int32, int64} = DT_INT32") 1737 .SetShapeFn([](InferenceContext* c) { 1738 ShapeHandle unused; 1739 ShapeHandle shape_x = c->input(0); 1740 ShapeHandle shape_y = c->input(1); 1741 TF_RETURN_IF_ERROR(c->WithRank(shape_x, 1, &unused)); 1742 TF_RETURN_IF_ERROR(c->WithRank(shape_y, 1, &unused)); 1743 1744 if (!c->ValueKnown(c->Dim(shape_x, 0)) || 1745 !c->ValueKnown(c->Dim(shape_y, 0))) { 1746 c->set_output(0, c->Vector(InferenceContext::kUnknownDim)); 1747 return Status::OK(); 1748 } 1749 1750 int64 x_dim = c->Value(c->Dim(shape_x, 0)); 1751 int64 y_dim = c->Value(c->Dim(shape_y, 0)); 1752 1753 // Broadcasted shape is going to be as large as the largest dimension. 1754 c->set_output(0, c->Vector(std::max(x_dim, y_dim))); 1755 return Status::OK(); 1756 }); 1757 1758 // -------------------------------------------------------------------------- 1759 REGISTER_OP("BroadcastGradientArgs") 1760 .Input("s0: T") 1761 .Input("s1: T") 1762 .Output("r0: T") 1763 .Output("r1: T") 1764 .Attr("T: {int32, int64} = DT_INT32") 1765 .SetShapeFn([](InferenceContext* c) { 1766 // TODO(mrry): Implement constant_value for BroadcastGradientArgs? 1767 ShapeHandle unused; 1768 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused)); 1769 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); 1770 c->set_output(0, c->Vector(InferenceContext::kUnknownDim)); 1771 c->set_output(1, c->Vector(InferenceContext::kUnknownDim)); 1772 return Status::OK(); 1773 }); 1774 1775 // -------------------------------------------------------------------------- 1776 REGISTER_OP("Pad") 1777 .Input("input: T") 1778 .Input("paddings: Tpaddings") 1779 .Output("output: T") 1780 .Attr("T: type") 1781 .Attr("Tpaddings: {int32, int64} = DT_INT32") 1782 .SetShapeFn(PadShapeFn); 1783 1784 // -------------------------------------------------------------------------- 1785 REGISTER_OP("PadV2") 1786 .Input("input: T") 1787 .Input("paddings: Tpaddings") 1788 .Input("constant_values: T") 1789 .Output("output: T") 1790 .Attr("T: type") 1791 .Attr("Tpaddings: {int32, int64} = DT_INT32") 1792 .SetShapeFn(PadShapeFn); 1793 1794 // -------------------------------------------------------------------------- 1795 REGISTER_OP("MirrorPad") 1796 .Input("input: T") 1797 .Input("paddings: Tpaddings") 1798 .Output("output: T") 1799 .Attr("T: type") 1800 .Attr("Tpaddings: {int32, int64} = DT_INT32") 1801 .Attr(GetMirrorPadModeAttrString()) 1802 .SetShapeFn(PadShapeFn); 1803 1804 // -------------------------------------------------------------------------- 1805 namespace { 1806 template <typename T> 1807 Status MirrorPadKnown(InferenceContext* c, ShapeHandle input, 1808 const Tensor* paddings_t, int64 input_rank) { 1809 auto paddings_data = paddings_t->matrix<T>(); 1810 std::vector<DimensionHandle> dims(input_rank); 1811 for (int64 i = 0; i < input_rank; ++i) { 1812 const int64 pad0 = static_cast<int64>(paddings_data(i, 0)); 1813 const int64 pad1 = static_cast<int64>(paddings_data(i, 1)); 1814 if (pad0 < 0 || pad1 < 0) { 1815 return errors::InvalidArgument("Paddings must be non-negative"); 1816 } 1817 1818 TF_RETURN_IF_ERROR(c->Subtract(c->Dim(input, i), pad0 + pad1, &dims[i])); 1819 } 1820 c->set_output(0, c->MakeShape(dims)); 1821 return Status::OK(); 1822 } 1823 1824 } // namespace 1825 1826 REGISTER_OP("MirrorPadGrad") 1827 .Input("input: T") 1828 .Input("paddings: Tpaddings") 1829 .Output("output: T") 1830 .Attr("T: type") 1831 .Attr("Tpaddings: {int32, int64} = DT_INT32") 1832 .Attr(GetMirrorPadModeAttrString()) 1833 .SetShapeFn([](InferenceContext* c) { 1834 ShapeHandle paddings; 1835 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &paddings)); 1836 DimensionHandle pad_0 = c->Dim(paddings, 0); 1837 if (!c->ValueKnown(pad_0)) { 1838 // We don't know the rank of the output since the first 1839 // padding dimension is unknown. 1840 c->set_output(0, c->UnknownShape()); 1841 return Status::OK(); 1842 } 1843 1844 int64 input_rank = c->Value(pad_0); 1845 ShapeHandle input; 1846 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), input_rank, &input)); 1847 TF_RETURN_IF_ERROR( 1848 c->Merge(paddings, c->Matrix(input_rank, 2), &paddings)); 1849 1850 const Tensor* paddings_t = c->input_tensor(1); 1851 if (paddings_t == nullptr) { 1852 // Values of 'paddings' is not available, but we know the 1853 // input rank, so return the rank of the output with unknown 1854 // dimensions. 1855 c->set_output(0, c->UnknownShapeOfRank(input_rank)); 1856 return Status::OK(); 1857 } 1858 1859 if (paddings_t->dtype() == DT_INT32) { 1860 return MirrorPadKnown<int32>(c, input, paddings_t, input_rank); 1861 } else { 1862 return MirrorPadKnown<int64>(c, input, paddings_t, input_rank); 1863 } 1864 }); 1865 1866 // -------------------------------------------------------------------------- 1867 REGISTER_OP("Placeholder") 1868 .Output("output: dtype") 1869 .Attr("dtype: type") 1870 .Attr("shape: shape = { unknown_rank: true }") 1871 .SetShapeFn([](InferenceContext* c) { 1872 PartialTensorShape shape; 1873 TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape)); 1874 1875 // Placeholder has legacy behavior where we cannot tell the difference 1876 // between a scalar shape attribute and 'unknown shape'. So if the shape 1877 // is a scalar, we return an unknown shape. 1878 if (c->graph_def_version() <= 21 && shape.dims() <= 0) { 1879 return shape_inference::UnknownShape(c); 1880 } 1881 1882 ShapeHandle out; 1883 TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &out)); 1884 c->set_output(0, out); 1885 return Status::OK(); 1886 }); 1887 1888 // Placeholder was modified in a backwards compatible way to do what 1889 // PlaceholderV2 did, so we have deprecated V2 (no one was really 1890 // using it). 1891 REGISTER_OP("PlaceholderV2") 1892 .Output("output: dtype") 1893 .Attr("dtype: type") 1894 .Attr("shape: shape") 1895 .SetShapeFn(shape_inference::ExplicitShape) 1896 .Deprecated(23, "Placeholder now behaves the same as PlaceholderV2."); 1897 1898 // -------------------------------------------------------------------------- 1899 REGISTER_OP("PlaceholderWithDefault") 1900 .Input("input: dtype") 1901 .Output("output: dtype") 1902 .Attr("dtype: type") 1903 .Attr("shape: shape") 1904 .SetShapeFn([](InferenceContext* c) { 1905 ShapeHandle input = c->input(0); 1906 PartialTensorShape shape; 1907 TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape)); 1908 ShapeHandle out; 1909 TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &out)); 1910 1911 // We merge for compatibility checking, but return the output, 1912 // since output_shape may be less precise than input_shape. 1913 ShapeHandle unused; 1914 TF_RETURN_IF_ERROR(c->Merge(input, out, &unused)); 1915 c->set_output(0, out); 1916 return Status::OK(); 1917 }); 1918 1919 // -------------------------------------------------------------------------- 1920 REGISTER_OP("ExpandDims") 1921 .Input("input: T") 1922 .Input("dim: Tdim") 1923 .Output("output: T") 1924 .Attr("T: type") 1925 .Attr("Tdim: {int32, int64} = DT_INT32") 1926 .SetShapeFn([](InferenceContext* c) { 1927 ShapeHandle input = c->input(0); 1928 1929 const Tensor* dim_t = c->input_tensor(1); 1930 if (dim_t != nullptr && dim_t->NumElements() != 1) { 1931 return errors::InvalidArgument( 1932 "'dim' input must be a tensor with a single value"); 1933 } 1934 if (dim_t == nullptr || !c->RankKnown(input)) { 1935 c->set_output(0, c->UnknownShape()); 1936 return Status::OK(); 1937 } 1938 1939 int64 dim; 1940 if (dim_t->dtype() == DT_INT32) { 1941 dim = static_cast<int64>(dim_t->flat<int32>()(0)); 1942 } else { 1943 dim = dim_t->flat<int64>()(0); 1944 } 1945 1946 const int32 rank = c->Rank(input); 1947 const int32 min_dim = -1 * rank - 1; 1948 if (dim < min_dim || dim > rank) { 1949 return errors::InvalidArgument("dim ", dim, " not in the interval [", 1950 min_dim, ", ", rank, "]."); 1951 } 1952 1953 if (dim < 0) { 1954 dim += rank + 1; 1955 } 1956 1957 ShapeHandle end; 1958 TF_RETURN_IF_ERROR(c->Subshape(input, dim, &end)); 1959 1960 // Build output as start + 1 + end. 1961 ShapeHandle output; 1962 TF_RETURN_IF_ERROR(c->Subshape(input, 0, dim, &output)); 1963 TF_RETURN_IF_ERROR(c->Concatenate(output, c->Vector(1), &output)); 1964 TF_RETURN_IF_ERROR(c->Concatenate(output, end, &output)); 1965 c->set_output(0, output); 1966 return Status::OK(); 1967 }); 1968 1969 // -------------------------------------------------------------------------- 1970 REGISTER_OP("Squeeze") 1971 .Input("input: T") 1972 .Output("output: T") 1973 .Attr("T: type") 1974 .Attr("squeeze_dims: list(int) >= 0 = []") 1975 .SetShapeFn([](InferenceContext* c) { 1976 ShapeHandle input = c->input(0); 1977 if (!c->RankKnown(input)) { 1978 // Input shape unknown. 1979 return shape_inference::UnknownShape(c); 1980 } 1981 1982 const int32 input_rank = c->Rank(input); 1983 1984 // Validate and wrap squeeze dimensions. 1985 std::vector<int32> squeeze_dims; 1986 TF_RETURN_IF_ERROR(c->GetAttr("squeeze_dims", &squeeze_dims)); 1987 for (int i = 0; i < squeeze_dims.size(); ++i) { 1988 if (squeeze_dims[i] < -input_rank || squeeze_dims[i] >= input_rank) { 1989 return errors::InvalidArgument("squeeze_dims[", i, "] not in [", 1990 -input_rank, ",", input_rank, ")."); 1991 } 1992 1993 if (squeeze_dims[i] < 0) { 1994 squeeze_dims[i] += input_rank; 1995 } 1996 } 1997 1998 std::vector<DimensionHandle> result_shape; 1999 for (int i = 0; i < input_rank; ++i) { 2000 // True if squeeze_dims contains an entry to squeeze this 2001 // dimension. 2002 bool is_explicit_match = 2003 std::find(squeeze_dims.begin(), squeeze_dims.end(), i) != 2004 squeeze_dims.end(); 2005 2006 DimensionHandle dim = c->Dim(input, i); 2007 2008 if (!c->ValueKnown(dim)) { 2009 // Assume that the squeezed dimension will be 1 at runtime. 2010 if (is_explicit_match) continue; 2011 2012 // If squeezing all 1 dimensions, and we see an unknown value, 2013 // give up and return Unknown Shape. 2014 if (squeeze_dims.empty()) { 2015 c->set_output(0, c->UnknownShape()); 2016 return Status::OK(); 2017 } 2018 } else if (c->Value(dim) == 1) { 2019 if (is_explicit_match || squeeze_dims.empty()) { 2020 // If explicitly squeezing, or squeezing all 1s, remove 2021 // this dimension. 2022 continue; 2023 } 2024 } else if (is_explicit_match) { 2025 return errors::InvalidArgument("Can not squeeze dim[", i, 2026 "], expected a dimension of 1, got ", 2027 c->Value(c->Dim(input, i))); 2028 } 2029 2030 result_shape.emplace_back(dim); 2031 } 2032 2033 c->set_output(0, c->MakeShape(result_shape)); 2034 return Status::OK(); 2035 }); 2036 2037 // -------------------------------------------------------------------------- 2038 REGISTER_OP("ListDiff") 2039 .Input("x: T") 2040 .Input("y: T") 2041 .Output("out: T") 2042 .Output("idx: out_idx") 2043 .Attr("T: type") 2044 .Attr("out_idx: {int32, int64} = DT_INT32") 2045 .SetShapeFn([](InferenceContext* c) { 2046 ShapeHandle unused; 2047 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused)); 2048 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); 2049 // TODO(mrry): Indicate that the length falls within an interval? 2050 ShapeHandle out = c->Vector(InferenceContext::kUnknownDim); 2051 c->set_output(0, out); 2052 c->set_output(1, out); 2053 return Status::OK(); 2054 }); 2055 2056 namespace { 2057 2058 // Converts Tensor to flat std::vector<int64>. 2059 template <typename InputType> 2060 std::vector<int64> GetFlatInt64(const Tensor& t) { 2061 std::vector<int64> output(t.shape().num_elements()); 2062 auto eigen_vec = t.flat<InputType>(); 2063 std::copy_n(&eigen_vec(0), output.size(), output.begin()); 2064 return output; 2065 } 2066 2067 // Converts int32 or int64 Tensor to flat std::vector<int64>. 2068 std::vector<int64> GetFlatInt64(const Tensor& t) { 2069 if (t.dtype() == DT_INT32) { 2070 return GetFlatInt64<int32>(t); 2071 } else { 2072 return GetFlatInt64<int64>(t); 2073 } 2074 } 2075 2076 Status SpaceToBatchShapeHelper(InferenceContext* c, ShapeHandle input_shape, 2077 ShapeHandle block_shape_shape, 2078 const Tensor* block_shape_t, 2079 ShapeHandle paddings_shape, 2080 const Tensor* paddings_t) { 2081 if (c->Rank(block_shape_shape) != 1) { 2082 return errors::InvalidArgument("block_shape must have rank 1."); 2083 } 2084 2085 const DimensionHandle num_block_dims_handle = c->Dim(block_shape_shape, 0); 2086 if (!c->ValueKnown(num_block_dims_handle)) { 2087 return errors::InvalidArgument("block_shape must have known size."); 2088 } 2089 2090 const int64 num_block_dims = c->Value(num_block_dims_handle); 2091 2092 TF_RETURN_IF_ERROR( 2093 c->WithRankAtLeast(input_shape, num_block_dims + 1, &input_shape)); 2094 2095 TF_RETURN_IF_ERROR( 2096 c->Merge(paddings_shape, c->Matrix(num_block_dims, 2), &paddings_shape)); 2097 2098 DimensionHandle batch_size = c->Dim(input_shape, 0); 2099 std::vector<int64> block_shape_vec; 2100 if (block_shape_t) { 2101 block_shape_vec = GetFlatInt64(*block_shape_t); 2102 for (int64 dim = 0; dim < num_block_dims; ++dim) { 2103 const int64 block_shape_value = block_shape_vec[dim]; 2104 if (block_shape_value < 1) { 2105 return errors::InvalidArgument("block_shape must be positive"); 2106 } 2107 if (c->ValueKnown(batch_size)) { 2108 TF_RETURN_IF_ERROR( 2109 c->Multiply(batch_size, block_shape_value, &batch_size)); 2110 } else { 2111 batch_size = c->UnknownDim(); 2112 } 2113 } 2114 } else if (num_block_dims > 0) { 2115 batch_size = c->UnknownDim(); 2116 } 2117 2118 std::vector<DimensionHandle> output_dims{batch_size}; 2119 output_dims.resize(num_block_dims + 1, c->UnknownDim()); 2120 2121 if (paddings_t) { 2122 const std::vector<int64> paddings_vec = GetFlatInt64(*paddings_t); 2123 for (int64 dim = 0; dim < num_block_dims; ++dim) { 2124 const int64 pad_start = paddings_vec[dim * 2], 2125 pad_end = paddings_vec[dim * 2 + 1]; 2126 if (pad_start < 0 || pad_end < 0) { 2127 return errors::InvalidArgument("paddings cannot be negative"); 2128 } 2129 if (block_shape_t) { 2130 DimensionHandle padded_size; 2131 TF_RETURN_IF_ERROR( 2132 c->Add(c->Dim(input_shape, dim + 1), pad_start, &padded_size)); 2133 TF_RETURN_IF_ERROR(c->Add(padded_size, pad_end, &padded_size)); 2134 TF_RETURN_IF_ERROR(c->Divide(padded_size, block_shape_vec[dim], 2135 /*evenly_divisible=*/true, 2136 &output_dims[dim + 1])); 2137 } 2138 } 2139 } 2140 2141 ShapeHandle remaining_input_shape; 2142 TF_RETURN_IF_ERROR( 2143 c->Subshape(input_shape, 1 + num_block_dims, &remaining_input_shape)); 2144 2145 ShapeHandle result; 2146 TF_RETURN_IF_ERROR(c->Concatenate(c->MakeShape(output_dims), 2147 remaining_input_shape, &result)); 2148 c->set_output(0, result); 2149 return Status::OK(); 2150 } 2151 2152 Status BatchToSpaceShapeHelper(InferenceContext* c, ShapeHandle input_shape, 2153 ShapeHandle block_shape_shape, 2154 const Tensor* block_shape_t, 2155 ShapeHandle crops_shape, const Tensor* crops_t) { 2156 if (c->Rank(block_shape_shape) != 1) { 2157 return errors::InvalidArgument("block_shape must have rank 1."); 2158 } 2159 2160 const DimensionHandle num_block_dims_handle = c->Dim(block_shape_shape, 0); 2161 if (!c->ValueKnown(num_block_dims_handle)) { 2162 return errors::InvalidArgument("block_shape must have known size."); 2163 } 2164 2165 const int64 num_block_dims = c->Value(num_block_dims_handle); 2166 2167 TF_RETURN_IF_ERROR( 2168 c->WithRankAtLeast(input_shape, num_block_dims + 1, &input_shape)); 2169 2170 TF_RETURN_IF_ERROR( 2171 c->Merge(crops_shape, c->Matrix(num_block_dims, 2), &crops_shape)); 2172 2173 DimensionHandle batch_size = c->Dim(input_shape, 0); 2174 std::vector<int64> block_shape_vec; 2175 if (block_shape_t) { 2176 block_shape_vec = GetFlatInt64(*block_shape_t); 2177 for (int64 dim = 0; dim < num_block_dims; ++dim) { 2178 const int64 block_shape_value = block_shape_vec[dim]; 2179 if (block_shape_value < 1) { 2180 return errors::InvalidArgument("block_shape must be positive"); 2181 } 2182 if (c->ValueKnown(batch_size)) { 2183 TF_RETURN_IF_ERROR(c->Divide(batch_size, block_shape_value, 2184 /*evenly_divisible=*/true, &batch_size)); 2185 } else { 2186 batch_size = c->UnknownDim(); 2187 } 2188 } 2189 } else if (num_block_dims > 0) { 2190 batch_size = c->UnknownDim(); 2191 } 2192 2193 std::vector<DimensionHandle> output_dims{batch_size}; 2194 output_dims.resize(num_block_dims + 1, c->UnknownDim()); 2195 2196 if (crops_t) { 2197 const std::vector<int64> crops_vec = GetFlatInt64(*crops_t); 2198 for (int64 dim = 0; dim < num_block_dims; ++dim) { 2199 const int64 crop_start = crops_vec[dim * 2], 2200 crop_end = crops_vec[dim * 2 + 1]; 2201 if (crop_start < 0 || crop_end < 0) { 2202 return errors::InvalidArgument("crops cannot be negative"); 2203 } 2204 if (block_shape_t) { 2205 DimensionHandle cropped_size; 2206 TF_RETURN_IF_ERROR(c->Multiply(c->Dim(input_shape, dim + 1), 2207 block_shape_vec[dim], &cropped_size)); 2208 TF_RETURN_IF_ERROR( 2209 c->Subtract(cropped_size, crop_start, &cropped_size)); 2210 TF_RETURN_IF_ERROR( 2211 c->Subtract(cropped_size, crop_end, &output_dims[dim + 1])); 2212 } 2213 } 2214 } 2215 2216 ShapeHandle remaining_input_shape; 2217 TF_RETURN_IF_ERROR( 2218 c->Subshape(input_shape, 1 + num_block_dims, &remaining_input_shape)); 2219 2220 ShapeHandle result; 2221 TF_RETURN_IF_ERROR(c->Concatenate(c->MakeShape(output_dims), 2222 remaining_input_shape, &result)); 2223 c->set_output(0, result); 2224 return Status::OK(); 2225 } 2226 2227 } // namespace 2228 2229 // -------------------------------------------------------------------------- 2230 REGISTER_OP("SpaceToBatchND") 2231 .Input("input: T") 2232 .Input("block_shape: Tblock_shape") 2233 .Input("paddings: Tpaddings") 2234 .Output("output: T") 2235 .Attr("T: type") 2236 .Attr("Tblock_shape: {int32, int64} = DT_INT32") 2237 .Attr("Tpaddings: {int32, int64} = DT_INT32") 2238 .SetShapeFn([](InferenceContext* c) { 2239 return SpaceToBatchShapeHelper(c, c->input(0), c->input(1), 2240 c->input_tensor(1), c->input(2), 2241 c->input_tensor(2)); 2242 }); 2243 2244 // -------------------------------------------------------------------------- 2245 REGISTER_OP("SpaceToBatch") 2246 .Input("input: T") 2247 .Input("paddings: Tpaddings") 2248 .Output("output: T") 2249 .Attr("T: type") 2250 .Attr("Tpaddings: {int32, int64} = DT_INT32") 2251 .Attr("block_size: int >= 2") 2252 .SetShapeFn([](InferenceContext* c) { 2253 ShapeHandle input_shape; 2254 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape)); 2255 2256 int32 block_size; 2257 TF_RETURN_IF_ERROR(c->GetAttr("block_size", &block_size)); 2258 2259 Tensor block_shape(tensorflow::DT_INT64, TensorShape({2})); 2260 auto block_shape_vec = block_shape.vec<int64>(); 2261 block_shape_vec(0) = block_size; 2262 block_shape_vec(1) = block_size; 2263 2264 return SpaceToBatchShapeHelper(c, input_shape, c->MakeShape({2}), 2265 &block_shape, c->input(1), 2266 c->input_tensor(1)); 2267 }); 2268 2269 // -------------------------------------------------------------------------- 2270 REGISTER_OP("BatchToSpaceND") 2271 .Input("input: T") 2272 .Input("block_shape: Tblock_shape") 2273 .Input("crops: Tcrops") 2274 .Output("output: T") 2275 .Attr("T: type") 2276 .Attr("Tblock_shape: {int32, int64} = DT_INT32") 2277 .Attr("Tcrops: {int32, int64} = DT_INT32") 2278 .SetShapeFn([](InferenceContext* c) { 2279 return BatchToSpaceShapeHelper(c, c->input(0), c->input(1), 2280 c->input_tensor(1), c->input(2), 2281 c->input_tensor(2)); 2282 }); 2283 2284 // -------------------------------------------------------------------------- 2285 REGISTER_OP("BatchToSpace") 2286 .Input("input: T") 2287 .Input("crops: Tidx") 2288 .Output("output: T") 2289 .Attr("T: type") 2290 .Attr("block_size: int >= 2") 2291 .Attr("Tidx: {int32, int64} = DT_INT32") 2292 .SetShapeFn([](InferenceContext* c) { 2293 ShapeHandle input_shape; 2294 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape)); 2295 2296 int32 block_size; 2297 TF_RETURN_IF_ERROR(c->GetAttr("block_size", &block_size)); 2298 2299 Tensor block_shape(tensorflow::DT_INT64, TensorShape({2})); 2300 auto block_shape_vec = block_shape.vec<int64>(); 2301 block_shape_vec(0) = block_size; 2302 block_shape_vec(1) = block_size; 2303 2304 return BatchToSpaceShapeHelper(c, input_shape, c->MakeShape({2}), 2305 &block_shape, c->input(1), 2306 c->input_tensor(1)); 2307 }); 2308 2309 // -------------------------------------------------------------------------- 2310 REGISTER_OP("SpaceToDepth") 2311 .Input("input: T") 2312 .Output("output: T") 2313 .Attr("T: type") 2314 .Attr("block_size: int >= 2") 2315 .Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'") 2316 // TODO(pauldonnelly): Implement GPU kernels for NCHW_VECT_C. 2317 .SetShapeFn([](InferenceContext* c) { 2318 string data_format_str; 2319 TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str)); 2320 TensorFormat data_format; 2321 FormatFromString(data_format_str, &data_format); 2322 2323 constexpr int num_spatial_dims = 2; 2324 const int dims = 2325 GetTensorDimsFromSpatialDims(num_spatial_dims, data_format); 2326 ShapeHandle input; 2327 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), dims, &input)); 2328 2329 int32 block_size; 2330 TF_RETURN_IF_ERROR(c->GetAttr("block_size", &block_size)); 2331 2332 DimensionHandle batch_size = 2333 c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'N')); 2334 DimensionHandle input_height = 2335 c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'H')); 2336 DimensionHandle input_width = 2337 c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'W')); 2338 DimensionHandle input_depth = 2339 c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'C')); 2340 2341 DimensionHandle output_height; 2342 DimensionHandle output_width; 2343 DimensionHandle output_depth; 2344 // Will return an error if input height or width are not evenly divisible. 2345 TF_RETURN_IF_ERROR(c->Divide(input_height, block_size, 2346 true /* evenly_divisible */, 2347 &output_height)); 2348 TF_RETURN_IF_ERROR(c->Divide(input_width, block_size, 2349 true /* evenly_divisible */, &output_width)); 2350 2351 TF_RETURN_IF_ERROR( 2352 c->Multiply(input_depth, block_size * block_size, &output_depth)); 2353 2354 ShapeHandle output_shape; 2355 TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size, 2356 {output_height, output_width}, 2357 output_depth, &output_shape, c)); 2358 2359 c->set_output(0, output_shape); 2360 return Status::OK(); 2361 }); 2362 2363 // -------------------------------------------------------------------------- 2364 REGISTER_OP("DepthToSpace") 2365 .Input("input: T") 2366 .Output("output: T") 2367 .Attr("T: type") 2368 .Attr("block_size: int >= 2") 2369 .Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'") 2370 // TODO(pauldonnelly): Implement GPU kernels for NCHW and NCHW_VECT_C. 2371 .SetShapeFn([](InferenceContext* c) { 2372 string data_format_str; 2373 TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str)); 2374 TensorFormat data_format; 2375 FormatFromString(data_format_str, &data_format); 2376 2377 constexpr int num_spatial_dims = 2; 2378 const int dims = 2379 GetTensorDimsFromSpatialDims(num_spatial_dims, data_format); 2380 2381 ShapeHandle input; 2382 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), dims, &input)); 2383 2384 int32 block_size; 2385 TF_RETURN_IF_ERROR(c->GetAttr("block_size", &block_size)); 2386 2387 DimensionHandle batch_size = 2388 c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'N')); 2389 DimensionHandle input_height = 2390 c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'H')); 2391 DimensionHandle input_width = 2392 c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'W')); 2393 DimensionHandle input_depth = 2394 c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'C')); 2395 2396 DimensionHandle output_height; 2397 DimensionHandle output_width; 2398 DimensionHandle output_depth; 2399 TF_RETURN_IF_ERROR(c->Multiply(input_height, block_size, &output_height)); 2400 TF_RETURN_IF_ERROR(c->Multiply(input_width, block_size, &output_width)); 2401 2402 // Will return an error if input_depth is not evenly divisible. 2403 TF_RETURN_IF_ERROR(c->Divide(input_depth, block_size * block_size, 2404 true /* evenly_divisible */, &output_depth)); 2405 2406 ShapeHandle output_shape; 2407 TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size, 2408 {output_height, output_width}, 2409 output_depth, &output_shape, c)); 2410 2411 c->set_output(0, output_shape); 2412 return Status::OK(); 2413 }); 2414 2415 // -------------------------------------------------------------------------- 2416 2417 REGISTER_OP("ExtractImagePatches") 2418 .Input("images: T") 2419 .Output("patches: T") 2420 .Attr("ksizes: list(int) >= 4") 2421 .Attr("strides: list(int) >= 4") 2422 .Attr("rates: list(int) >= 4") 2423 .Attr("T: realnumbertype") 2424 .Attr(GetPaddingAttrString()) 2425 .SetShapeFn([](InferenceContext* c) { 2426 ShapeHandle input_shape; 2427 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape)); 2428 2429 std::vector<int32> ksizes; 2430 TF_RETURN_IF_ERROR(c->GetAttr("ksizes", &ksizes)); 2431 if (ksizes.size() != 4) { 2432 return errors::InvalidArgument( 2433 "ExtractImagePatches requires the ksizes attribute to contain 4 " 2434 "values, but got: ", 2435 ksizes.size()); 2436 } 2437 2438 std::vector<int32> strides; 2439 TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); 2440 if (strides.size() != 4) { 2441 return errors::InvalidArgument( 2442 "ExtractImagePatches requires the stride attribute to contain 4 " 2443 "values, but got: ", 2444 strides.size()); 2445 } 2446 2447 std::vector<int32> rates; 2448 TF_RETURN_IF_ERROR(c->GetAttr("rates", &rates)); 2449 if (rates.size() != 4) { 2450 return errors::InvalidArgument( 2451 "ExtractImagePatches requires the rates attribute to contain 4 " 2452 "values, but got: ", 2453 rates.size()); 2454 } 2455 2456 int32 ksize_rows = ksizes[1]; 2457 int32 ksize_cols = ksizes[2]; 2458 2459 int32 stride_rows = strides[1]; 2460 int32 stride_cols = strides[2]; 2461 2462 int32 rate_rows = rates[1]; 2463 int32 rate_cols = rates[2]; 2464 2465 int32 ksize_rows_eff = ksize_rows + (ksize_rows - 1) * (rate_rows - 1); 2466 int32 ksize_cols_eff = ksize_cols + (ksize_cols - 1) * (rate_cols - 1); 2467 2468 DimensionHandle batch_size_dim = c->Dim(input_shape, 0); 2469 DimensionHandle in_rows_dim = c->Dim(input_shape, 1); 2470 DimensionHandle in_cols_dim = c->Dim(input_shape, 2); 2471 DimensionHandle output_depth_dim; 2472 TF_RETURN_IF_ERROR(c->Multiply( 2473 c->Dim(input_shape, 3), ksize_rows * ksize_cols, &output_depth_dim)); 2474 2475 if (!c->ValueKnown(in_rows_dim) || !c->ValueKnown(in_cols_dim)) { 2476 ShapeHandle output_shape = 2477 c->MakeShape({batch_size_dim, InferenceContext::kUnknownDim, 2478 InferenceContext::kUnknownDim, output_depth_dim}); 2479 c->set_output(0, output_shape); 2480 return Status::OK(); 2481 } 2482 auto in_rows = c->Value(in_rows_dim); 2483 auto in_cols = c->Value(in_cols_dim); 2484 2485 Padding padding; 2486 TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); 2487 2488 int64 output_rows, output_cols; 2489 int64 padding_before, padding_after; 2490 TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( 2491 in_rows, ksize_rows_eff, stride_rows, padding, &output_rows, 2492 &padding_before, &padding_after)); 2493 TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( 2494 in_cols, ksize_cols_eff, stride_cols, padding, &output_cols, 2495 &padding_before, &padding_after)); 2496 ShapeHandle output_shape = c->MakeShape( 2497 {batch_size_dim, output_rows, output_cols, output_depth_dim}); 2498 c->set_output(0, output_shape); 2499 return Status::OK(); 2500 }); 2501 2502 // -------------------------------------------------------------------------- 2503 2504 // To enable rates, uncomment all lines commented below and use ksize_*_eff 2505 // as the second parameter of all GetWindowedOutputSizeVerbose calls instead 2506 // of ksize_*. 2507 REGISTER_OP("ExtractVolumePatches") 2508 .Input("input: T") 2509 .Output("patches: T") 2510 .Attr("ksizes: list(int) >= 5") 2511 .Attr("strides: list(int) >= 5") 2512 /* .Attr("rates: list(int) >= 5") */ 2513 .Attr("T: realnumbertype") 2514 .Attr(GetPaddingAttrString()) 2515 .SetShapeFn([](InferenceContext* c) { 2516 ShapeHandle input_shape; 2517 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape)); 2518 2519 std::vector<int32> ksizes; 2520 TF_RETURN_IF_ERROR(c->GetAttr("ksizes", &ksizes)); 2521 if (ksizes.size() != 5) { 2522 return errors::InvalidArgument( 2523 "ExtractVolumePatches requires the ksizes attribute to contain 5 " 2524 "values, but got: ", 2525 ksizes.size()); 2526 } 2527 2528 std::vector<int32> strides; 2529 TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); 2530 if (strides.size() != 5) { 2531 return errors::InvalidArgument( 2532 "ExtractVolumePatches requires the stride attribute to contain 5 " 2533 "values, but got: ", 2534 strides.size()); 2535 } 2536 2537 /* 2538 // TODO(hsgkim): Enable rates. 2539 // See extract_volume_patches_op.cc for why rates are disabled now. 2540 2541 std::vector<int32> rates; 2542 TF_RETURN_IF_ERROR(c->GetAttr("rates", &rates)); 2543 if (rates.size() != 5) { 2544 return errors::InvalidArgument( 2545 "ExtractVolumePatches requires the rates attribute to contain 5 " 2546 "values, but got: ", 2547 rates.size()); 2548 } 2549 */ 2550 2551 int32 ksize_planes = ksizes[1]; 2552 int32 ksize_rows = ksizes[2]; 2553 int32 ksize_cols = ksizes[3]; 2554 2555 int32 stride_planes = strides[1]; 2556 int32 stride_rows = strides[2]; 2557 int32 stride_cols = strides[3]; 2558 2559 /* 2560 int32 rate_planes = rates[1]; 2561 int32 rate_rows = rates[2]; 2562 int32 rate_cols = rates[3]; 2563 2564 int32 ksize_planes_eff = ksize_planes + 2565 (ksize_planes - 1) * (rate_planes - 1); 2566 int32 ksize_rows_eff = ksize_rows + (ksize_rows - 1) * (rate_rows - 1); 2567 int32 ksize_cols_eff = ksize_cols + (ksize_cols - 1) * (rate_cols - 1); 2568 */ 2569 2570 DimensionHandle batch_size_dim = c->Dim(input_shape, 0); 2571 DimensionHandle in_planes_dim = c->Dim(input_shape, 1); 2572 DimensionHandle in_rows_dim = c->Dim(input_shape, 2); 2573 DimensionHandle in_cols_dim = c->Dim(input_shape, 3); 2574 DimensionHandle output_depth_dim; 2575 TF_RETURN_IF_ERROR(c->Multiply(c->Dim(input_shape, 4), 2576 ksize_planes * ksize_rows * ksize_cols, 2577 &output_depth_dim)); 2578 2579 if (!c->ValueKnown(in_planes_dim) || !c->ValueKnown(in_rows_dim) || 2580 !c->ValueKnown(in_cols_dim)) { 2581 ShapeHandle output_shape = 2582 c->MakeShape({batch_size_dim, InferenceContext::kUnknownDim, 2583 InferenceContext::kUnknownDim, output_depth_dim}); 2584 c->set_output(0, output_shape); 2585 return Status::OK(); 2586 } 2587 auto in_planes = c->Value(in_planes_dim); 2588 auto in_rows = c->Value(in_rows_dim); 2589 auto in_cols = c->Value(in_cols_dim); 2590 2591 Padding padding; 2592 TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); 2593 2594 int64 output_planes, output_rows, output_cols; 2595 int64 padding_before, padding_after; 2596 TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( 2597 in_planes, ksize_planes, stride_planes, padding, &output_planes, 2598 &padding_before, &padding_after)); 2599 TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( 2600 in_rows, ksize_rows, stride_rows, padding, &output_rows, 2601 &padding_before, &padding_after)); 2602 TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( 2603 in_cols, ksize_cols, stride_cols, padding, &output_cols, 2604 &padding_before, &padding_after)); 2605 ShapeHandle output_shape = 2606 c->MakeShape({batch_size_dim, output_planes, output_rows, output_cols, 2607 output_depth_dim}); 2608 c->set_output(0, output_shape); 2609 return Status::OK(); 2610 }); 2611 2612 // -------------------------------------------------------------------------- 2613 2614 REGISTER_OP("Bitcast") 2615 .Input("input: T") 2616 .Output("output: type") 2617 // All supported dtypes are listed here to include qint16, quint16, uint32, 2618 // and uint64. 2619 .Attr( 2620 "T: {bfloat16, half, float, double, int64, int32, uint8, uint16, " 2621 "uint32, uint64, int8, int16, complex64, complex128, qint8, quint8, " 2622 "qint16, quint16, qint32}") 2623 .Attr( 2624 "type: {bfloat16, half, float, double, int64, int32, uint8, uint16, " 2625 "uint32, uint64, int8, int16, complex64, complex128, qint8, quint8, " 2626 "qint16, quint16, qint32}") 2627 .SetShapeFn([](InferenceContext* c) { 2628 ShapeHandle input = c->input(0); 2629 if (!c->RankKnown(input)) { 2630 // Input shape unknown. 2631 return shape_inference::UnknownShape(c); 2632 } 2633 2634 // Find the size of the input and output data types. 2635 DataType input_type; 2636 DataType output_type; 2637 TF_RETURN_IF_ERROR(c->GetAttr("T", &input_type)); 2638 TF_RETURN_IF_ERROR(c->GetAttr("type", &output_type)); 2639 const int input_type_size = DataTypeSize(input_type); 2640 const int output_type_size = DataTypeSize(output_type); 2641 2642 if (input_type_size == 0 || output_type_size == 0) { 2643 return errors::InvalidArgument("Cannot bitcast types ", 2644 DataTypeString(input_type), " to ", 2645 DataTypeString(output_type), 2646 " because " 2647 "one of the type sizes is zero."); 2648 } 2649 2650 ShapeHandle new_shape; 2651 if (input_type_size == output_type_size) { 2652 // No change in size. 2653 new_shape = input; 2654 } else if (input_type_size < output_type_size) { 2655 TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, 1, &new_shape)); 2656 2657 int64 divisor_val = output_type_size / input_type_size; 2658 DimensionHandle last_dim = c->Dim(new_shape, -1); 2659 if (!c->ValueKnown(last_dim) || c->Value(last_dim) == divisor_val) { 2660 TF_RETURN_IF_ERROR(c->Subshape(new_shape, 0, -1, &new_shape)); 2661 } else { 2662 return errors::InvalidArgument("Cannot bitcast due to shape. ", 2663 c->Value(last_dim), " does not match ", 2664 divisor_val); 2665 } 2666 } else { 2667 // Input type size is larger than output type size. 2668 int64 divisor_val = input_type_size / output_type_size; 2669 ShapeHandle extension = c->Vector(divisor_val); 2670 TF_RETURN_IF_ERROR(c->Concatenate(input, extension, &new_shape)); 2671 } 2672 2673 c->set_output(0, new_shape); 2674 return Status::OK(); 2675 }); 2676 2677 REGISTER_OP("OneHot") 2678 .Input("indices: TI") 2679 .Input("depth: int32") 2680 .Input("on_value: T") 2681 .Input("off_value: T") 2682 .Attr("axis: int = -1") 2683 .Output("output: T") 2684 .Attr("T: type") 2685 .Attr("TI: {uint8, int32, int64} = DT_INT64") 2686 .SetShapeFn([](InferenceContext* c) { 2687 int32 axis; 2688 TF_RETURN_IF_ERROR(c->GetAttr("axis", &axis)); 2689 if (axis < -1) return errors::InvalidArgument("axis must be >= -1"); 2690 2691 DimensionHandle depth; 2692 TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(1, &depth)); 2693 2694 ShapeHandle indices = c->input(0); 2695 if (!c->RankKnown(indices)) return shape_inference::UnknownShape(c); 2696 2697 int32 new_rank = c->Rank(indices) + 1; 2698 // We need to add new_rank to axis in the case the axis is -1 because 2699 // C++ returns negative values from % if the dividend is negative. 2700 int32 depth_index = (axis + new_rank) % new_rank; 2701 // Out shape is indices[0:depth_index] + [depth] + indices[depth_index:]. 2702 ShapeHandle front; 2703 ShapeHandle back; 2704 ShapeHandle out; 2705 TF_RETURN_IF_ERROR(c->Subshape(indices, 0, depth_index, &front)); 2706 TF_RETURN_IF_ERROR(c->Subshape(indices, depth_index, &back)); 2707 TF_RETURN_IF_ERROR(c->Concatenate(front, c->Vector(depth), &front)); 2708 TF_RETURN_IF_ERROR(c->Concatenate(front, back, &out)); 2709 c->set_output(0, out); 2710 return Status::OK(); 2711 }); 2712 2713 // EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET. 2714 REGISTER_OP("QuantizeAndDequantize") 2715 .Input("input: T") 2716 .Attr("signed_input: bool = true") 2717 .Attr("num_bits: int = 8") 2718 .Attr("range_given: bool = false") 2719 .Attr("input_min: float = 0") 2720 .Attr("input_max: float = 0") 2721 .Output("output: T") 2722 .Attr("T: {bfloat16, half, float, double}") 2723 .SetShapeFn(shape_inference::UnchangedShape) 2724 .Deprecated(22, "Replaced by QuantizeAndDequantizeV2"); 2725 2726 // TODO(suharshs): Deprecate QuantizeAndDequantizeV2. 2727 REGISTER_OP("QuantizeAndDequantizeV2") 2728 .Input("input: T") 2729 .Input("input_min: T") 2730 .Input("input_max: T") 2731 .Attr("signed_input: bool = true") 2732 .Attr("num_bits: int = 8") 2733 .Attr("range_given: bool = false") 2734 .Output("output: T") 2735 .Attr("T: {bfloat16, half, float, double}") 2736 .Attr( 2737 "round_mode: {'HALF_TO_EVEN', 'HALF_UP'} = " 2738 "'HALF_TO_EVEN'") 2739 .SetShapeFn([](InferenceContext* c) { 2740 ShapeHandle unused; 2741 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 2742 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 2743 c->set_output(0, c->input(0)); 2744 return Status::OK(); 2745 }); 2746 2747 REGISTER_OP("QuantizeAndDequantizeV3") 2748 .Input("input: T") 2749 .Input("input_min: T") 2750 .Input("input_max: T") 2751 .Input("num_bits: int32") 2752 .Attr("signed_input: bool = true") 2753 .Attr("range_given: bool = true") 2754 .Output("output: T") 2755 .Attr("T: {bfloat16, half, float, double}") 2756 .SetShapeFn([](InferenceContext* c) { 2757 ShapeHandle unused; 2758 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 2759 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 2760 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); 2761 c->set_output(0, c->input(0)); 2762 return Status::OK(); 2763 }); 2764 2765 REGISTER_OP("QuantizeV2") 2766 .Input("input: float") 2767 .Input("min_range: float") 2768 .Input("max_range: float") 2769 .Output("output: T") 2770 .Output("output_min: float") 2771 .Output("output_max: float") 2772 .Attr("T: quantizedtype") 2773 .Attr("mode: {'MIN_COMBINED', 'MIN_FIRST', 'SCALED'} = 'MIN_COMBINED'") 2774 .Attr( 2775 "round_mode: {'HALF_AWAY_FROM_ZERO', 'HALF_TO_EVEN'} = " 2776 "'HALF_AWAY_FROM_ZERO'") 2777 .SetShapeFn([](InferenceContext* c) { 2778 TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c)); 2779 ShapeHandle unused; 2780 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 2781 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 2782 c->set_output(1, c->Scalar()); 2783 c->set_output(2, c->Scalar()); 2784 return Status::OK(); 2785 }); 2786 2787 REGISTER_OP("Dequantize") 2788 .Input("input: T") 2789 .Input("min_range: float") 2790 .Input("max_range: float") 2791 .Output("output: float") 2792 .Attr("T: quantizedtype") 2793 .Attr("mode: {'MIN_COMBINED', 'MIN_FIRST', 'SCALED'} = 'MIN_COMBINED'") 2794 .SetShapeFn([](InferenceContext* c) { 2795 TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c)); 2796 ShapeHandle unused; 2797 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 2798 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 2799 return Status::OK(); 2800 }); 2801 2802 REGISTER_OP("QuantizedConcat") 2803 .Input("concat_dim: int32") 2804 .Input("values: N * T") 2805 .Input("input_mins: N * float32") 2806 .Input("input_maxes: N * float32") 2807 .Output("output: T") 2808 .Output("output_min: float") 2809 .Output("output_max: float") 2810 .Attr("N: int >= 2") 2811 .Attr("T: type") 2812 .SetShapeFn([](InferenceContext* c) { 2813 const int n = (c->num_inputs() - 1) / 3; 2814 TF_RETURN_IF_ERROR(shape_inference::ConcatShape(c, n)); 2815 ShapeHandle unused; 2816 for (int i = n + 1; i < c->num_inputs(); ++i) { 2817 TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 0, &unused)); 2818 } 2819 c->set_output(1, c->Scalar()); 2820 c->set_output(2, c->Scalar()); 2821 return Status::OK(); 2822 }); 2823 2824 REGISTER_OP("QuantizedReshape") 2825 .Input("tensor: T") 2826 .Input("shape: Tshape") 2827 .Input("input_min: float") 2828 .Input("input_max: float") 2829 .Output("output: T") 2830 .Output("output_min: float") 2831 .Output("output_max: float") 2832 .Attr("T: type") 2833 .Attr("Tshape: {int32, int64} = DT_INT32") 2834 .SetShapeFn([](InferenceContext* c) { 2835 TF_RETURN_IF_ERROR(SetOutputShapeForReshape(c)); 2836 ShapeHandle unused; 2837 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 2838 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); 2839 c->set_output(1, c->Scalar()); 2840 c->set_output(2, c->Scalar()); 2841 return Status::OK(); 2842 }); 2843 2844 REGISTER_OP("QuantizedInstanceNorm") 2845 .Input("x: T") 2846 .Input("x_min: float") 2847 .Input("x_max: float") 2848 .Output("y: T") 2849 .Output("y_min: float") 2850 .Output("y_max: float") 2851 .Attr("T: quantizedtype") 2852 .Attr("output_range_given: bool = false") 2853 .Attr("given_y_min: float = 0") 2854 .Attr("given_y_max: float = 0") 2855 .Attr("variance_epsilon: float = 1e-5") 2856 .Attr("min_separation: float = 1e-3") 2857 .SetShapeFn([](shape_inference::InferenceContext* c) { 2858 shape_inference::ShapeHandle unused; 2859 // x should be a rank 4 tensor. 2860 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &unused)); 2861 // Assert x_min and x_max are scalars (rank 0). 2862 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 2863 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 2864 // y has the same shape as x. 2865 TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c)); 2866 // y_min and y_max are scalars. 2867 c->set_output(1, c->Scalar()); 2868 c->set_output(2, c->Scalar()); 2869 return Status::OK(); 2870 }); 2871 2872 namespace { 2873 2874 Status ScatterNdShapeHelper(InferenceContext* c, ShapeHandle indices_shape, 2875 ShapeHandle updates_shape, 2876 ShapeHandle output_shape) { 2877 if (c->Value(c->NumElements(output_shape)) == 0 && 2878 (c->Value(c->NumElements(indices_shape)) > 0 || 2879 c->Value(c->NumElements(updates_shape)) > 0)) { 2880 return errors::InvalidArgument( 2881 "Indices and updates specified for empty output shape"); 2882 } 2883 2884 if (c->RankKnown(indices_shape) && c->RankKnown(updates_shape)) { 2885 const int64 outer_dims = c->Rank(indices_shape) - 1; 2886 const DimensionHandle ixdim = c->Dim(indices_shape, -1); 2887 2888 // We can only do more validation if the last dimension of indices 2889 // is a known value. 2890 if (c->ValueKnown(ixdim)) { 2891 int64 ix = c->Value(ixdim); 2892 ShapeHandle unused; 2893 ShapeHandle prefix_indices; 2894 TF_RETURN_IF_ERROR( 2895 c->Subshape(indices_shape, 0, outer_dims, &prefix_indices)); 2896 ShapeHandle prefix_updates; 2897 TF_RETURN_IF_ERROR( 2898 c->Subshape(updates_shape, 0, outer_dims, &prefix_updates)); 2899 2900 Status s = c->Merge(prefix_indices, prefix_updates, &unused); 2901 if (!s.ok()) { 2902 return errors::InvalidArgument( 2903 "The outer ", outer_dims, 2904 " dimensions of indices.shape=", c->DebugString(indices_shape), 2905 " must match the outer ", outer_dims, 2906 " dimensions of updates.shape=", c->DebugString(updates_shape), 2907 ": ", s.error_message()); 2908 } 2909 2910 ShapeHandle suffix_output; 2911 TF_RETURN_IF_ERROR(c->Subshape(output_shape, ix, &suffix_output)); 2912 ShapeHandle suffix_updates; 2913 TF_RETURN_IF_ERROR( 2914 c->Subshape(updates_shape, outer_dims, &suffix_updates)); 2915 s = c->Merge(suffix_output, suffix_updates, &unused); 2916 if (!s.ok()) { 2917 return errors::InvalidArgument( 2918 "The inner ", c->Rank(output_shape) - ix, 2919 " dimensions of output.shape=", c->DebugString(output_shape), 2920 " must match the inner ", c->Rank(updates_shape) - outer_dims, 2921 " dimensions of updates.shape=", c->DebugString(updates_shape), 2922 ": ", s.error_message()); 2923 } 2924 } 2925 } 2926 2927 c->set_output(0, output_shape); 2928 return Status::OK(); 2929 } 2930 2931 Status ScatterNdShape(InferenceContext* c) { 2932 ShapeHandle indices_shape; 2933 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &indices_shape)); 2934 ShapeHandle updates_shape; 2935 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &updates_shape)); 2936 ShapeHandle output_shape; 2937 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &output_shape)); 2938 return ScatterNdShapeHelper(c, indices_shape, updates_shape, output_shape); 2939 } 2940 2941 Status ScatterNdTensorShape(InferenceContext* c) { 2942 ShapeHandle output_shape; 2943 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &output_shape)); 2944 ShapeHandle indices_shape; 2945 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices_shape)); 2946 ShapeHandle updates_shape; 2947 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 1, &updates_shape)); 2948 return ScatterNdShapeHelper(c, indices_shape, updates_shape, output_shape); 2949 } 2950 2951 } // namespace 2952 2953 REGISTER_OP("UpperBound") 2954 .Input("sorted_inputs: T") 2955 .Input("values: T") 2956 .Output("output: out_type") 2957 .Attr("T: type") 2958 .Attr("out_type: {int32, int64} = DT_INT32") 2959 .SetShapeFn([](InferenceContext* c) { 2960 ShapeHandle unused_shape; 2961 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused_shape)); 2962 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &unused_shape)); 2963 c->set_output(0, c->input(1)); 2964 return Status::OK(); 2965 }); 2966 2967 REGISTER_OP("LowerBound") 2968 .Input("sorted_inputs: T") 2969 .Input("values: T") 2970 .Output("output: out_type") 2971 .Attr("T: type") 2972 .Attr("out_type: {int32, int64} = DT_INT32") 2973 .SetShapeFn([](InferenceContext* c) { 2974 ShapeHandle unused_shape; 2975 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused_shape)); 2976 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &unused_shape)); 2977 c->set_output(0, c->input(1)); 2978 return Status::OK(); 2979 }); 2980 2981 REGISTER_OP("ScatterNd") 2982 .Input("indices: Tindices") 2983 .Input("updates: T") 2984 .Input("shape: Tindices") 2985 .Output("output: T") 2986 .Attr("T: type") 2987 .Attr("Tindices: {int32, int64}") 2988 .SetShapeFn(ScatterNdShape); 2989 2990 REGISTER_OP("TensorScatterUpdate") 2991 .Input("tensor: T") 2992 .Input("indices: Tindices") 2993 .Input("updates: T") 2994 .Output("output: T") 2995 .Attr("T: type") 2996 .Attr("Tindices: {int32, int64}") 2997 .SetShapeFn(ScatterNdTensorShape); 2998 2999 REGISTER_OP("TensorScatterAdd") 3000 .Input("tensor: T") 3001 .Input("indices: Tindices") 3002 .Input("updates: T") 3003 .Output("output: T") 3004 .Attr("T: type") 3005 .Attr("Tindices: {int32, int64}") 3006 .SetShapeFn(ScatterNdTensorShape); 3007 3008 REGISTER_OP("TensorScatterSub") 3009 .Input("tensor: T") 3010 .Input("indices: Tindices") 3011 .Input("updates: T") 3012 .Output("output: T") 3013 .Attr("T: type") 3014 .Attr("Tindices: {int32, int64}") 3015 .SetShapeFn(ScatterNdTensorShape); 3016 3017 REGISTER_OP("ScatterNdNonAliasingAdd") 3018 .Input("input: T") 3019 .Input("indices: Tindices") 3020 .Input("updates: T") 3021 .Output("output: T") 3022 .Attr("T: {numbertype, bool}") 3023 .Attr("Tindices: {int32, int64}") 3024 .SetShapeFn(shape_inference::ScatterNdUpdateShape); 3025 3026 REGISTER_OP("FakeQuantWithMinMaxArgs") 3027 .Attr("min: float = -6.0") 3028 .Attr("max: float = 6.0") 3029 .Attr("num_bits: int = 8") 3030 .Attr("narrow_range: bool = false") 3031 .Input("inputs: float") 3032 .Output("outputs: float") 3033 .SetShapeFn(shape_inference::UnchangedShape); 3034 3035 REGISTER_OP("FakeQuantWithMinMaxArgsGradient") 3036 .Attr("min: float = -6.0") 3037 .Attr("max: float = 6.0") 3038 .Attr("num_bits: int = 8") 3039 .Attr("narrow_range: bool = false") 3040 .Input("gradients: float") 3041 .Input("inputs: float") 3042 .Output("backprops: float") 3043 .SetShapeFn(shape_inference::UnchangedShape); 3044 3045 REGISTER_OP("FakeQuantWithMinMaxVars") 3046 .Attr("num_bits: int = 8") 3047 .Attr("narrow_range: bool = false") 3048 .Input("inputs: float") 3049 .Input("min: float") 3050 .Input("max: float") 3051 .Output("outputs: float") 3052 .SetShapeFn([](InferenceContext* c) { 3053 TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c)); 3054 ShapeHandle unused; 3055 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 3056 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 3057 return Status::OK(); 3058 }); 3059 3060 REGISTER_OP("FakeQuantWithMinMaxVarsGradient") 3061 .Attr("num_bits: int = 8") 3062 .Attr("narrow_range: bool = false") 3063 .Input("gradients: float") 3064 .Input("inputs: float") 3065 .Input("min: float") 3066 .Input("max: float") 3067 .Output("backprops_wrt_input: float") 3068 .Output("backprop_wrt_min: float") 3069 .Output("backprop_wrt_max: float") 3070 .SetShapeFn([](InferenceContext* c) { 3071 // gradients and inputs are same size. 3072 ShapeHandle inputs; 3073 TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(1), &inputs)); 3074 3075 // min and max are scalars 3076 ShapeHandle min_max; 3077 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &min_max)); 3078 TF_RETURN_IF_ERROR(c->Merge(min_max, c->input(3), &min_max)); 3079 3080 c->set_output(0, inputs); 3081 c->set_output(1, min_max); 3082 c->set_output(2, min_max); 3083 return Status::OK(); 3084 }); 3085 3086 REGISTER_OP("FakeQuantWithMinMaxVarsPerChannel") 3087 .Attr("num_bits: int = 8") 3088 .Attr("narrow_range: bool = false") 3089 .Input("inputs: float") 3090 .Input("min: float") 3091 .Input("max: float") 3092 .Output("outputs: float") 3093 .SetShapeFn([](InferenceContext* c) { 3094 ShapeHandle input, min, max; 3095 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input)); 3096 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &min)); 3097 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &max)); 3098 3099 DimensionHandle unused; 3100 TF_RETURN_IF_ERROR(c->Merge(c->Dim(input, -1), c->Dim(min, 0), &unused)); 3101 TF_RETURN_IF_ERROR(c->Merge(c->Dim(input, -1), c->Dim(max, 0), &unused)); 3102 TF_RETURN_IF_ERROR(c->Merge(c->Dim(min, 0), c->Dim(max, 0), &unused)); 3103 3104 c->set_output(0, input); 3105 return Status::OK(); 3106 }); 3107 3108 REGISTER_OP("FakeQuantWithMinMaxVarsPerChannelGradient") 3109 .Attr("num_bits: int = 8") 3110 .Attr("narrow_range: bool = false") 3111 .Input("gradients: float") 3112 .Input("inputs: float") 3113 .Input("min: float") 3114 .Input("max: float") 3115 .Output("backprops_wrt_input: float") 3116 .Output("backprop_wrt_min: float") 3117 .Output("backprop_wrt_max: float") 3118 .SetShapeFn([](InferenceContext* c) { 3119 ShapeHandle inputs; 3120 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &inputs)); 3121 TF_RETURN_IF_ERROR(c->WithRankAtMost(inputs, 4, &inputs)); 3122 TF_RETURN_IF_ERROR(c->Merge(inputs, c->input(1), &inputs)); 3123 3124 ShapeHandle last_dim = c->Vector(c->Dim(inputs, -1)); 3125 3126 ShapeHandle min_max; 3127 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &min_max)); 3128 TF_RETURN_IF_ERROR(c->Merge(min_max, last_dim, &min_max)); 3129 TF_RETURN_IF_ERROR(c->Merge(c->input(3), min_max, &min_max)); 3130 3131 c->set_output(0, inputs); 3132 c->set_output(1, min_max); 3133 c->set_output(2, min_max); 3134 return Status::OK(); 3135 }); 3136 3137 #ifdef INTEL_MKL 3138 REGISTER_OP("_MklConcat") 3139 .Input("concat_dim: int32") 3140 .Input("values: N * T") 3141 .Input("mkl_concat_dim: uint8") 3142 .Input("mkl_values: N * uint8") 3143 .Output("output: T") 3144 .Output("mkl_output: uint8") 3145 .Attr("N: int >= 2") 3146 .Attr("T: type") 3147 .SetShapeFn([](InferenceContext* c) { 3148 return shape_inference::ConcatShape(c, c->num_inputs() - 3); 3149 }) 3150 .Doc(R"doc( 3151 MKL version of Concat operator. Uses MKL DNN APIs to perform concatenation. 3152 3153 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is 3154 expected to invoke these operators. 3155 )doc"); 3156 #endif 3157 3158 // Deprecated op registrations: 3159 3160 // The following can be deleted after 10mar2017. 3161 REGISTER_OP("BatchMatrixDiag") 3162 .Input("diagonal: T") 3163 .Output("output: T") 3164 .Attr("T: type") 3165 .Deprecated(14, "Use MatrixDiag") 3166 .SetShapeFn(shape_inference::UnknownShape); 3167 REGISTER_OP("BatchMatrixSetDiag") 3168 .Input("input: T") 3169 .Input("diagonal: T") 3170 .Output("output: T") 3171 .Attr("T: type") 3172 .Deprecated(14, "Use MatrixSetDiag") 3173 .SetShapeFn(shape_inference::UnknownShape); 3174 REGISTER_OP("BatchMatrixDiagPart") 3175 .Input("input: T") 3176 .Output("diagonal: T") 3177 .Attr("T: type") 3178 .Deprecated(14, "Use MatrixDiagPart") 3179 .SetShapeFn(shape_inference::UnknownShape); 3180 REGISTER_OP("BatchMatrixBandPart") 3181 .Input("input: T") 3182 .Input("num_lower: int64") 3183 .Input("num_upper: int64") 3184 .Output("band: T") 3185 .Attr("T: type") 3186 .Deprecated(14, "Use MatrixBandPart") 3187 .SetShapeFn(shape_inference::UnknownShape); 3188 3189 } // namespace tensorflow 3190