/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/util/mirror_pad_mode.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/strided_slice_op.h"
#include "tensorflow/core/util/tensor_format.h"

namespace tensorflow {

using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
using shape_inference::UnchangedShape;

namespace {

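// Validates the "axis" attr against the rank the output will have after
// packing (or that the input has, for unpacking) and canonicalizes negative
// values; e.g. axis = -1 with rank_after_pack = 3 becomes axis = 2.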
Status GetAxisForPackAndUnpack(InferenceContext* c, int32 rank_after_pack,
                               int32* axis) {
  TF_RETURN_IF_ERROR(c->GetAttr("axis", axis));
  if (*axis < -1 * rank_after_pack || *axis >= rank_after_pack) {
    return errors::InvalidArgument("Invalid axis: ", *axis, "; must be in [",
                                   -1 * rank_after_pack, ",", rank_after_pack,
                                   ")");
  }
  if (*axis < 0) *axis = (rank_after_pack + *axis);
  return Status::OK();
}

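// Copies num_elements values from a rank-1 tensor of type T into a
// std::vector<int64>, widening int32 values where necessary.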
template <typename T>
std::vector<int64> AsInt64(const Tensor* tensor, int64 num_elements) {
  std::vector<int64> ret(num_elements);
  auto data = tensor->vec<T>();
  for (int64 i = 0; i < num_elements; ++i) {
    ret[i] = data(i);
  }
  return ret;
}

template <typename T>
Status PadKnown(InferenceContext* c, ShapeHandle input,
                const Tensor* paddings_t, int64 num_dims) {
  // paddings_t is known.
  std::vector<DimensionHandle> dims(num_dims);
  auto paddings_data = paddings_t->matrix<T>();
  for (int64 i = 0; i < num_dims; ++i) {
    const T pad0 = paddings_data(i, 0);
    const T pad1 = paddings_data(i, 1);
    if (pad0 < 0 || pad1 < 0) {
      return errors::InvalidArgument("Paddings must be non-negative");
    }
    TF_RETURN_IF_ERROR(c->Add(c->Dim(input, i), pad0 + pad1, &dims[i]));
  }
  c->set_output(0, c->MakeShape(dims));
  return Status::OK();
}

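// Common shape function for ops whose second input is an [input_rank, 2]
// paddings matrix: each output dimension is the corresponding input dimension
// grown by that axis's pair of paddings, e.g. an input of shape [2, 3] with
// paddings [[1, 1], [2, 0]] produces a [4, 5] output.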
Status PadShapeFn(InferenceContext* c) {
  // Paddings is a matrix of [input_rank, 2].
  ShapeHandle paddings;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &paddings));
  DimensionHandle unused;
  TF_RETURN_IF_ERROR(c->WithValue(c->Dim(paddings, 1), 2, &unused));

  // n_dim and input.rank are equivalent.
  ShapeHandle input = c->input(0);
  DimensionHandle n_dim = c->Dim(paddings, 0);
  if (c->ValueKnown(n_dim)) {
    TF_RETURN_IF_ERROR(c->WithRank(input, c->Value(n_dim), &input));
  } else if (c->RankKnown(input)) {
    TF_RETURN_IF_ERROR(c->WithValue(n_dim, c->Rank(input), &n_dim));
  }

  const Tensor* paddings_t = c->input_tensor(1);

  // paddings_t is unknown
  if (paddings_t == nullptr) {
    if (c->ValueKnown(n_dim)) {
      // Make output with n_dim unknown dims.
      c->set_output(0, c->UnknownShapeOfRank(c->Value(n_dim)));
    } else {
      c->set_output(0, c->UnknownShape());
    }
    return Status::OK();
  }

  const int64 num_dims = paddings_t->shape().dim_size(0);
  TF_RETURN_IF_ERROR(c->WithRank(input, num_dims, &input));
  TF_RETURN_IF_ERROR(c->WithValue(n_dim, num_dims, &n_dim));

  if (paddings_t->dtype() == DT_INT32) {
    return PadKnown<int32>(c, input, paddings_t, num_dims);
  } else {
    return PadKnown<int64>(c, input, paddings_t, num_dims);
  }
}

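// Shape function for Transpose-style ops: when the permutation is known,
// output dimension i is input dimension perm[i]; e.g. an input of shape
// [2, 3, 5] with perm = [2, 0, 1] produces a [5, 2, 3] output.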
Status TransposeShapeFn(InferenceContext* c) {
  ShapeHandle input = c->input(0);
  ShapeHandle perm_shape = c->input(1);
  const Tensor* perm = c->input_tensor(1);
  DimensionHandle perm_elems = c->NumElements(perm_shape);
  // If we don't have rank information on the input or value information on
  // perm we can't return any shape information, otherwise we have enough
  // information to at least find the rank of the output.
  if (!c->RankKnown(input) && !c->ValueKnown(perm_elems) && perm == nullptr) {
    c->set_output(0, c->UnknownShape());
    return Status::OK();
  }

  // Find our value of the rank.
  int64 rank;
  if (c->RankKnown(input)) {
    rank = c->Rank(input);
  } else if (c->ValueKnown(perm_elems)) {
    rank = c->Value(perm_elems);
  } else {
    rank = perm->NumElements();
  }
  if (!c->RankKnown(input) && rank < 2) {
    // A permutation array containing a single element is ambiguous. It could
    // indicate either a scalar or a 1-dimensional array, both of which the
    // transpose op returns unchanged.
    c->set_output(0, input);
    return Status::OK();
  }

  std::vector<DimensionHandle> dims;
  dims.resize(rank);
  TF_RETURN_IF_ERROR(c->WithRank(input, rank, &input));
  // Ensure that perm is a vector and has rank elements.
  TF_RETURN_IF_ERROR(c->WithRank(perm_shape, 1, &perm_shape));
  TF_RETURN_IF_ERROR(c->WithValue(perm_elems, rank, &perm_elems));

  // If we know the rank of the input and the value of perm, we can return
  // all shape information; otherwise we can only return rank information,
  // but no information for the dimensions.
  if (perm != nullptr) {
    std::vector<int64> data;
    if (perm->dtype() == DT_INT32) {
      data = AsInt64<int32>(perm, rank);
    } else {
      data = AsInt64<int64>(perm, rank);
    }

    for (int32 i = 0; i < rank; ++i) {
      int64 in_idx = data[i];
      if (in_idx >= rank) {
        return errors::InvalidArgument("perm dim ", in_idx,
                                       " is out of range of input rank ", rank);
      }
      dims[i] = c->Dim(input, in_idx);
    }
  } else {
    for (int i = 0; i < rank; ++i) {
      dims[i] = c->UnknownDim();
    }
  }

  c->set_output(0, c->MakeShape(dims));
  return Status::OK();
}

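// Shape function for Reshape: if both element counts are known they must
// match, and at most one unknown dimension on either side can be inferred by
// dividing the known element counts; e.g. reshaping a [2, 6] tensor to shape
// [3, -1] infers the missing dimension as 4.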
Status SetOutputShapeForReshape(InferenceContext* c) {
  ShapeHandle in = c->input(0);
  ShapeHandle out;
  TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &out));

  if (!c->RankKnown(out)) {
    // We have no information about the shape of the output.
    c->set_output(0, out);
    return Status::OK();
  }

  if (c->RankKnown(out) && c->RankKnown(in)) {
    // We don't know the number of output elements, but we can try to infer
    // the missing dimension.
    bool too_many_unknown = false;
    int32 out_unknown_idx = -1;

    DimensionHandle known_out_elems = c->NumElements(out);
    if (!c->ValueKnown(known_out_elems)) {
      known_out_elems = c->MakeDim(1);
      for (int32 i = 0; i < c->Rank(out); ++i) {
        DimensionHandle dim = c->Dim(out, i);
        if (!c->ValueKnown(dim)) {
          if (out_unknown_idx >= 0) {
            too_many_unknown = true;
            break;
          }
          out_unknown_idx = i;
        } else {
          TF_RETURN_IF_ERROR(
              c->Multiply(known_out_elems, dim, &known_out_elems));
        }
      }
    }
    int32 in_unknown_idx = -1;
    DimensionHandle known_in_elems = c->NumElements(in);
    if (!c->ValueKnown(known_in_elems)) {
      known_in_elems = c->MakeDim(1);
      for (int32 i = 0; i < c->Rank(in); ++i) {
        DimensionHandle dim = c->Dim(in, i);
        if (!c->ValueKnown(dim)) {
          if (in_unknown_idx >= 0) {
            too_many_unknown = true;
            break;
          }
          in_unknown_idx = i;
        } else {
          TF_RETURN_IF_ERROR(c->Multiply(known_in_elems, dim, &known_in_elems));
        }
      }
    }

    if (!too_many_unknown) {
      if (in_unknown_idx < 0 && out_unknown_idx < 0) {
        // Just check that the dimensions match.
        if (c->Value(known_in_elems) != c->Value(known_out_elems)) {
          return errors::InvalidArgument(
              "Cannot reshape a tensor with ", c->DebugString(known_in_elems),
              " elements to shape ", c->DebugString(out), " (",
              c->DebugString(known_out_elems), " elements)");
        }
      } else if (in_unknown_idx < 0 && out_unknown_idx >= 0 &&
                 c->Value(known_out_elems) > 0) {
        // Input fully known, infer the one missing output dim
        DimensionHandle inferred_dim;
        TF_RETURN_IF_ERROR(c->Divide(known_in_elems, c->Value(known_out_elems),
                                     true /* evenly_divisible */,
                                     &inferred_dim));
        TF_RETURN_IF_ERROR(
            c->ReplaceDim(out, out_unknown_idx, inferred_dim, &out));

      } else if (in_unknown_idx >= 0 && out_unknown_idx < 0 &&
                 c->Value(known_in_elems) != 0) {
        // Output fully known, infer the one missing input dim
        DimensionHandle inferred_dim;
        TF_RETURN_IF_ERROR(c->Divide(known_out_elems, c->Value(known_in_elems),
                                     true /* evenly_divisible */,
                                     &inferred_dim));
        DimensionHandle unknown_in_dim = c->Dim(in, in_unknown_idx);
        TF_RETURN_IF_ERROR(
            c->Merge(unknown_in_dim, inferred_dim, &unknown_in_dim));
      } else if (in_unknown_idx >= 0 && out_unknown_idx >= 0) {
        // Exactly one unknown dimension in both input and output. These 2 are
        // equal iff the known elements are equal.
        if (c->Value(known_in_elems) == c->Value(known_out_elems)) {
          DimensionHandle unknown_in_dim = c->Dim(in, in_unknown_idx);
          TF_RETURN_IF_ERROR(
              c->ReplaceDim(out, out_unknown_idx, unknown_in_dim, &out));
        }
      }
    }
  }
  c->set_output(0, out);
  return Status::OK();
}

}  // namespace

REGISTER_OP("ParallelConcat")
    .Input("values: N * T")
    .Output("output: T")
    .Attr("N: int >= 1")
    .Attr("T: type")
    .Attr("shape: shape")
    .SetShapeFn([](InferenceContext* c) {
      // Validate that the shape attr is correct.
      PartialTensorShape shape;
      TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
      ShapeHandle passed_shape;
      TF_RETURN_IF_ERROR(
          c->MakeShapeFromPartialTensorShape(shape, &passed_shape));
      if (!c->FullyDefined(passed_shape)) {
        return errors::InvalidArgument("shape attr must be fully defined.");
      }
      ShapeHandle cur;
      TF_RETURN_IF_ERROR(c->ReplaceDim(
          passed_shape, 0, c->MakeDim(shape_inference::DimensionOrConstant(1)),
          &cur));
      for (int i = 0; i < c->num_inputs(); ++i) {
        if (!c->FullyDefined(c->input(i))) {
          return errors::InvalidArgument(
              "All input shapes must be fully defined.");
        }
        DimensionHandle unused;
        if (!c->WithValue(c->Dim(c->input(i), 0), 1, &unused).ok()) {
          return errors::InvalidArgument("Size of first dimension must be 1.");
        }
        TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur),
                                        "From merging shape ", i,
                                        " with other shapes.");
      }

      c->set_output(0, passed_shape);

      return Status::OK();
    });

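// Pack stacks N rank-R tensors into one rank-(R+1) tensor; e.g. packing
// N = 4 inputs of shape [2, 3] with axis = 1 produces a [2, 4, 3] output.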
REGISTER_OP("Pack")
    .Input("values: N * T")
    .Output("output: T")
    .Attr("N: int >= 1")
    .Attr("T: type")
    .Attr("axis: int = 0")
    .SetShapeFn([](InferenceContext* c) {
      // Validate shapes of all inputs are compatible
      ShapeHandle cur = c->input(c->num_inputs() - 1);
      for (int i = c->num_inputs() - 2; i >= 0; --i) {
        TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur),
                                        "From merging shape ", i,
                                        " with other shapes.");
      }
      if (!c->RankKnown(cur)) {
        c->set_output(0, c->UnknownShape());
        return Status::OK();
      }
      // Determine the axis that will be added, converting from negative
      // axes to a positive point per negative indexing rules.
      int32 rank = c->Rank(cur);
      int32 axis;
      TF_RETURN_IF_ERROR(GetAxisForPackAndUnpack(c, rank + 1, &axis));

      // Copy all dimensions over, inserting a dimension of value #inputs
      // at <axis>.
      std::vector<DimensionHandle> dims;
      int index = 0;
      while (index < axis) dims.push_back(c->Dim(cur, index++));
      dims.push_back(c->MakeDim(c->num_inputs()));
      while (index < rank) dims.push_back(c->Dim(cur, index++));

      c->set_output(0, c->MakeShape(dims));
      for (int i = 0; i < c->num_inputs(); ++i) {
        auto* shape_and_type = c->input_handle_shapes_and_types(i);
        if (shape_and_type) {
          if (!c->RelaxOutputHandleShapesAndMergeTypes(0, *shape_and_type)) {
            c->set_output_handle_shapes_and_types(
                0, std::vector<shape_inference::ShapeAndType>({}));
            break;
          }
        }
      }
      return Status::OK();
    });

REGISTER_OP("DeepCopy")
    .Input("x: T")
    .Output("y: T")
    .Attr("T: type")
    .SetIsStateful()
    .SetShapeFn(UnchangedShape);

REGISTER_OP("InplaceUpdate")
    .Input("x: T")
    .Input("i: int32")
    .Input("v: T")
    .Output("y: T")
    .Attr("T: type")
    .SetShapeFn(UnchangedShape);

REGISTER_OP("InplaceAdd")
    .Input("x: T")
    .Input("i: int32")
    .Input("v: T")
    .Output("y: T")
    .Attr("T: type")
    .SetShapeFn(UnchangedShape);

REGISTER_OP("InplaceSub")
    .Input("x: T")
    .Input("i: int32")
    .Input("v: T")
    .Output("y: T")
    .Attr("T: type")
    .SetShapeFn(UnchangedShape);

REGISTER_OP("Empty")
    .Input("shape: int32")
    .Output("output: dtype")
    .Attr("dtype: type")
    .Attr("init: bool = false")
    .SetIsStateful()
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle out;
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
      c->set_output(0, out);
      return Status::OK();
    });

// --------------------------------------------------------------------------
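// Unpack is the inverse of Pack: it removes the <axis> dimension and emits
// num outputs; e.g. unpacking a [3, 4, 5] tensor with num = 3, axis = 0
// produces three [4, 5] outputs.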
REGISTER_OP("Unpack")
    .Input("value: T")
    .Output("output: num * T")
    .Attr("num: int >= 0")
    .Attr("T: type")
    .Attr("axis: int = 0")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle s = c->input(0);
      ShapeHandle out;
      if (c->RankKnown(s)) {
        // Determine the axis that will be removed, converting from negative
        // axes to a positive point per negative indexing rules.
        int32 rank = c->Rank(s);
        int32 axis;
        TF_RETURN_IF_ERROR(GetAxisForPackAndUnpack(c, rank, &axis));

        // The axis dim matches the number of outputs.
        DimensionHandle unused;
        TF_RETURN_IF_ERROR(
            c->WithValue(c->Dim(s, axis), c->num_outputs(), &unused));

        // Copy all dimensions, removing the <axis> dimension.
        std::vector<DimensionHandle> dims;
        for (int i = 0; i < rank; ++i) {
          if (i != axis) dims.push_back(c->Dim(s, i));
        }
        out = c->MakeShape(dims);
      } else {
        // All outputs are the same shape, but it's not known.
        out = c->UnknownShape();
      }
      for (int i = 0; i < c->num_outputs(); ++i) c->set_output(i, out);
      return Status::OK();
    });

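// UnravelIndex converts flat indices into coordinate vectors: a scalar index
// yields a vector of length rank(dims), and a vector of n indices yields a
// [rank(dims), n] matrix of coordinates.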
REGISTER_OP("UnravelIndex")
    .Input("indices: Tidx")
    .Input("dims: Tidx")
    .Output("output: Tidx")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle indices = c->input(0);
      ShapeHandle dims;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &dims));
      if (c->RankKnown(indices) && c->Rank(indices) == 0) {
        c->set_output(0, c->Vector(c->Dim(dims, 0)));
      } else if (c->RankKnown(indices)) {
        c->set_output(0, c->Matrix(c->Dim(dims, 0), c->NumElements(indices)));
      } else {
        c->set_output(0, c->UnknownShape());
      }
      return Status::OK();
    });

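// BroadcastTo broadcasts the input against the requested shape starting from
// the trailing dimensions; input dimensions of size greater than 1 must match
// the corresponding output dimension, e.g. a [3, 1] input can be broadcast to
// [2, 3, 4] but not to [2, 4, 4].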
REGISTER_OP("BroadcastTo")
    .Input("input: T")
    .Input("shape: Tidx")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle shape_in = c->input(1);
      TF_RETURN_IF_ERROR(c->WithRank(shape_in, 1, &shape_in));
      ShapeHandle out;
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &out));
      if (!c->RankKnown(out)) {
        // We have no information about the shape of the output.
        c->set_output(0, out);
        return Status::OK();
      }

      ShapeHandle in = c->input(0);
      if (!c->RankKnown(in)) {
        // We have no information about the shape of the input,
        // nothing to do here.
        c->set_output(0, out);
        return Status::OK();
      }
      int out_rank = c->Rank(out);
      TF_RETURN_IF_ERROR(c->WithRankAtMost(in, out_rank, &in));
      int in_rank = c->Rank(in);
      for (int i = 0; i < in_rank; ++i) {
        auto in_dim = c->Dim(in, in_rank - i - 1);
        if (c->Value(in_dim) > 1) {
          // If the input dimension is greater than 1 then the output dimension
          // must be equal to it, since we only broadcast "from left to right".
          auto out_dim = c->Dim(out, out_rank - i - 1);
          TF_RETURN_IF_ERROR(c->Merge(in_dim, out_dim, &out_dim));
          TF_RETURN_IF_ERROR(
              c->ReplaceDim(out, out_rank - i - 1, out_dim, &out));
        }
      }
      c->set_output(0, out);
      return Status::OK();
    });

// --------------------------------------------------------------------------
// TODO(josh11b): Remove the >= 2 constraint, once we can rewrite the graph
// in the N == 1 case to remove the node.
REGISTER_OP("Concat")
    .Input("concat_dim: int32")
    .Input("values: N * T")
    .Output("output: T")
    .Attr("N: int >= 2")
    .Attr("T: type")
    .SetShapeFn([](InferenceContext* c) {
      return shape_inference::ConcatShape(c, c->num_inputs() - 1);
    });

REGISTER_OP("ConcatV2")
    .Input("values: N * T")
    .Input("axis: Tidx")
    .Output("output: T")
    .Attr("N: int >= 2")
    .Attr("T: type")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::ConcatV2Shape);

// TODO(vivek.v.rane@intel.com): Prefix the op names with underscore if the ops
// are not to be made user-accessible.
#ifdef INTEL_MKL
REGISTER_OP("_MklConcatV2")
    .Input("values: N * T")
    .Input("axis: Tidx")
    .Input("mkl_values: N * uint8")
    .Input("mkl_axis: uint8")
    .Output("output: T")
    .Output("mkl_output: uint8")
    .Attr("N: int >= 2")
    .Attr("T: type")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::ConcatV2Shape)
    .Doc(R"doc(
MKL version of ConcatV2 operator. Uses MKL DNN APIs to perform concatenation.

NOTE: Do not invoke this operator directly in Python. The graph rewrite pass is
expected to invoke these operators.
)doc");
#endif

REGISTER_OP("ConcatOffset")
    .Input("concat_dim: int32")
    .Input("shape: N * int32")
    .Output("offset: N * int32")
    .Attr("N: int >= 2")
    .SetShapeFn([](InferenceContext* c) {
      for (int i = 1; i < c->num_inputs(); ++i) {
        c->set_output(i - 1, c->input(i));
      }
      return Status::OK();
    });

// --------------------------------------------------------------------------
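// Split divides value into num_split equal pieces along split_dim, which must
// divide the dimension evenly; e.g. splitting a [4, 30] tensor 3 ways along
// dim 1 yields three [4, 10] outputs.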
REGISTER_OP("Split")
    .Input("split_dim: int32")
    .Input("value: T")
    .Output("output: num_split * T")
    .Attr("num_split: int >= 1")
    .Attr("T: type")
    .SetShapeFn([](InferenceContext* c) {
      DimensionHandle split_dimension;
      ShapeHandle input = c->input(1);
      TF_RETURN_IF_ERROR(c->MakeDimForScalarInputWithNegativeIndexing(
          0, c->Rank(input), &split_dimension));
      int num_split = c->num_outputs();
      ShapeHandle out;
      if (!c->ValueKnown(split_dimension)) {
        if (c->RankKnown(input)) {
          out = c->UnknownShapeOfRank(c->Rank(input));
        } else {
          out = c->UnknownShape();
        }
      } else {
        int64 split_dim = c->Value(split_dimension);
        TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, split_dim + 1, &input));
        DimensionHandle split_dim_size;
        TF_RETURN_WITH_CONTEXT_IF_ERROR(
            c->Divide(c->Dim(input, split_dim), num_split,
                      true /* evenly_divisible */, &split_dim_size),
            "Number of ways to split should evenly divide the split dimension");
        TF_RETURN_IF_ERROR(
            c->ReplaceDim(input, split_dim, split_dim_size, &out));
      }
      for (int i = 0; i < num_split; ++i) c->set_output(i, out);
      return Status::OK();
    });

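// SplitV splits value along split_dim into pieces whose sizes are given by
// size_splits, with at most one -1 entry inferred from the remainder; e.g.
// size_splits = [10, 15, -1] on a [4, 30] tensor along dim 1 yields outputs
// of shape [4, 10], [4, 15], and [4, 5].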
REGISTER_OP("SplitV")
    .Input("value: T")
    .Input("size_splits: Tlen")
    .Input("split_dim: int32")
    .Output("output: num_split * T")
    .Attr("num_split: int >= 1")
    .Attr("T: type")
    .Attr("Tlen: {int32, int64} = DT_INT64")
    .SetShapeFn([](InferenceContext* c) {
      DimensionHandle split_dimension;
      ShapeHandle input = c->input(0);
      TF_RETURN_IF_ERROR(c->MakeDimForScalarInputWithNegativeIndexing(
          2, c->Rank(input), &split_dimension));
      int32 num_outputs = c->num_outputs();
      int32 rank = c->Rank(input);
      ShapeHandle output_shape;
      const Tensor* size_splits = c->input_tensor(1);
      if (rank == InferenceContext::kUnknownRank) {
        // If the rank of input tensor is unknown, then return unknown shapes.
        // Note that the shape of each output can be different.
        for (int i = 0; i < num_outputs; ++i) {
          c->set_output(i, c->UnknownShape());
        }
      } else if (rank == 0) {
        // Throw error if input is a scalar.
        return errors::InvalidArgument("Can't split scalars");
      } else if (size_splits == nullptr && c->ValueKnown(split_dimension)) {
        // If split dimension is known, but the sizes are unknown, then
        // only the split dimension is unknown
        output_shape = input;
        for (int i = 0; i < num_outputs; ++i) {
          TF_RETURN_IF_ERROR(c->ReplaceDim(output_shape,
                                           c->Value(split_dimension),
                                           c->UnknownDim(), &output_shape));
          c->set_output(i, output_shape);
        }
      } else if (size_splits == nullptr && !c->ValueKnown(split_dimension)) {
        // If split dimension or tensor containing the split sizes is unknown,
        // then return unknown shapes of same rank as input. Note that each
        // output shape can be different since splitv doesn't always split
        // tensors evenly.
        for (int i = 0; i < num_outputs; ++i) {
          c->set_output(i, c->UnknownShapeOfRank(rank));
        }
      } else {
        // Determine the output shape if split dimension and split sizes are
        // known.
        int64 split_dim = c->Value(split_dimension);
        TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, split_dim + 1, &input));
        std::vector<int64> data;
        if (size_splits->dtype() == DT_INT32) {
          data = AsInt64<int32>(size_splits, size_splits->shape().dim_size(0));
        } else {
          data = AsInt64<int64>(size_splits, size_splits->shape().dim_size(0));
        }
        if (num_outputs != data.size()) {
          return errors::InvalidArgument(
              "Length of size_splits should be equal to num_outputs");
        }
        int64_t total_size = 0;
        bool has_neg_one = false;
        for (const auto size : data) {
          if (size == -1) {
            if (has_neg_one) {
              return errors::InvalidArgument(
                  "size_splits can only have one -1");
            }
            has_neg_one = true;
          } else {
            total_size += size;
          }
        }
        auto split_dim_size = c->Value(c->Dim(input, split_dim));
        // If the sizes of the splits are known, then
        // make sure that the sizes add up to the expected
        // dimension size, with the possibility of a -1.
        // Specify the full output shapes.
        for (int i = 0; i < num_outputs; ++i) {
          auto size = data[i];
          if (data[i] == -1 && c->ValueKnown(split_dim_size)) {
            size = split_dim_size - total_size;
          }
          TF_RETURN_IF_ERROR(
              c->ReplaceDim(input, split_dim, c->MakeDim(size), &output_shape));
          c->set_output(i, output_shape);
        }
        if (c->ValueKnown(split_dim_size)) {
          if (has_neg_one ? total_size > split_dim_size
                          : total_size != split_dim_size) {
            return errors::InvalidArgument(
                "can't split axis of size ", split_dim_size,
                " into pieces of size [", str_util::Join(data, ","), "]");
          }
        }
      }

      return Status::OK();
    });

// --------------------------------------------------------------------------
REGISTER_OP("Const")
    .Output("output: dtype")
    .Attr("value: tensor")
    .Attr("dtype: type")
    .SetShapeFn([](InferenceContext* c) {
      const TensorProto* proto = nullptr;
      TF_RETURN_IF_ERROR(c->GetAttr("value", &proto));
      TF_RETURN_IF_ERROR(TensorShape::IsValidShape(proto->tensor_shape()));
      TensorShape shape(proto->tensor_shape());
      std::vector<DimensionHandle> dims;
      dims.reserve(shape.dims());
      for (int i = 0; i < shape.dims(); ++i) {
        dims.push_back(c->MakeDim(shape.dim_size(i)));
      }
      c->set_output(0, c->MakeShape(dims));
      return Status::OK();
    });

// Returns a constant tensor on the host.  Useful for writing C++ tests
// and benchmarks which run on GPU but require arguments pinned to the host.
// Used by test::graph::HostConstant.
// value: Attr `value` is the tensor to return.
REGISTER_OP("HostConst")
    .Output("output: dtype")
    .Attr("value: tensor")
    .Attr("dtype: type")
    .SetShapeFn(shape_inference::UnknownShape);

// --------------------------------------------------------------------------
// TODO(mgubin): Update the doc when the freeze_graph script supports converting
// into memmapped format.
REGISTER_OP("ImmutableConst")
    .Attr("dtype: type")
    .Attr("shape: shape")
    .Attr("memory_region_name: string")
    .Output("tensor: dtype")
    .SetShapeFn(shape_inference::ExplicitShape);

REGISTER_OP("GuaranteeConst")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: type")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      return UnchangedShape(c);
    })
    // We don't want this to be optimized away.
    .SetIsStateful();

// --------------------------------------------------------------------------
REGISTER_OP("ZerosLike")
    .Input("x: T")
    .Output("y: T")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnchangedShape);

// --------------------------------------------------------------------------
REGISTER_OP("OnesLike")
    .Input("x: T")
    .Output("y: T")
    .Attr(
        "T: {bfloat16, half, float, double, int8, uint8, int16, uint16, int32, "
        "int64, complex64, complex128, bool}")
    .SetShapeFn(shape_inference::UnchangedShape);

// --------------------------------------------------------------------------
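// Diag builds a tensor whose shape is the diagonal's shape concatenated with
// itself; e.g. a [2, 3] diagonal produces a [2, 3, 2, 3] output, and DiagPart
// below performs the inverse reduction.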
REGISTER_OP("Diag")
    .Input("diagonal: T")
    .Output("output: T")
    .Attr(
        "T: {bfloat16, half, float, double, int32, int64, complex64, "
        "complex128}")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle in = c->input(0);
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(in, 1, &in));
      // Output shape is original concatenated with itself.
      ShapeHandle out;
      TF_RETURN_IF_ERROR(c->Concatenate(in, in, &out));
      c->set_output(0, out);
      return Status::OK();
    });

// --------------------------------------------------------------------------
REGISTER_OP("DiagPart")
    .Input("input: T")
    .Output("diagonal: T")
    .Attr(
        "T: {bfloat16, half, float, double, int32, int64, complex64, "
        "complex128}")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle in = c->input(0);
      if (!c->RankKnown(in)) {
        c->set_output(0, c->UnknownShape());
        return Status::OK();
      }
      // Rank must be even, and result will have rank <rank/2>.
      const int32 rank = c->Rank(in);
      if ((rank % 2) != 0 || rank <= 0) {
        return errors::InvalidArgument(
            "Input must have even and non-zero rank, input rank is ", rank);
      }
      const int32 mid = rank / 2;

      // output dim[i] is the merge of in.dim[i] and in.dim[i+mid].
      std::vector<DimensionHandle> dims(mid);
      for (int i = 0; i < mid; ++i) {
        TF_RETURN_IF_ERROR(
            c->Merge(c->Dim(in, i), c->Dim(in, i + mid), &dims[i]));
      }
      c->set_output(0, c->MakeShape(dims));
      return Status::OK();
    });

// --------------------------------------------------------------------------
REGISTER_OP("MatrixDiag")
    .Input("diagonal: T")
    .Output("output: T")
    .Attr("T: type")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle in;
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &in));
      if (!c->RankKnown(in)) {
        c->set_output(0, c->UnknownShape());
        return Status::OK();
      }
      const int32 rank = c->Rank(in);
      ShapeHandle out;
      TF_RETURN_IF_ERROR(
          c->Concatenate(in, c->Vector(c->Dim(in, rank - 1)), &out));
      c->set_output(0, out);
      return Status::OK();
    });

// --------------------------------------------------------------------------
REGISTER_OP("MatrixSetDiag")
    .Input("input: T")
    .Input("diagonal: T")
    .Output("output: T")
    .Attr("T: type")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle input;
      ShapeHandle diag;
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &diag));
      if (c->RankKnown(input)) {
        TF_RETURN_IF_ERROR(c->WithRank(c->input(1), c->Rank(input) - 1, &diag));
      }
      DimensionHandle smallest_dim;
      TF_RETURN_IF_ERROR(
          c->Min(c->Dim(input, -2), c->Dim(input, -1), &smallest_dim));
      TF_RETURN_IF_ERROR(
          c->Merge(smallest_dim, c->Dim(diag, -1), &smallest_dim));

      ShapeHandle output = input;
      if (c->RankKnown(diag) && !c->FullyDefined(input)) {
        // Try to infer parts of shape from diag.
        ShapeHandle diag_prefix;
        TF_RETURN_IF_ERROR(c->Subshape(diag, 0, -1, &diag_prefix));
        TF_RETURN_IF_ERROR(
            c->Concatenate(diag_prefix, c->UnknownShapeOfRank(2), &diag));
        TF_RETURN_IF_ERROR(c->Merge(input, diag, &output));
      }
      c->set_output(0, output);
      return Status::OK();
    });

// --------------------------------------------------------------------------
REGISTER_OP("MatrixDiagPart")
    .Input("input: T")
    .Output("diagonal: T")
    .Attr("T: type")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle in;
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &in));
      if (!c->RankKnown(in)) {
        c->set_output(0, c->UnknownShape());
        return Status::OK();
      }
      const int32 rank = c->Rank(in);
      std::vector<DimensionHandle> dims;
      dims.reserve(rank - 2);
      for (int i = 0; i < rank - 2; ++i) dims.push_back(c->Dim(in, i));

      DimensionHandle min_dim;
      TF_RETURN_IF_ERROR(
          c->Min(c->Dim(in, rank - 2), c->Dim(in, rank - 1), &min_dim));
      dims.push_back(min_dim);
      c->set_output(0, c->MakeShape(dims));
      return Status::OK();
    });

// --------------------------------------------------------------------------
REGISTER_OP("MatrixBandPart")
    .Input("input: T")
    .Input("num_lower: Tindex")
    .Input("num_upper: Tindex")
    .Output("band: T")
    .Attr("T: type")
    .Attr("Tindex: {int32, int64} = DT_INT64")
    .SetShapeFn(shape_inference::UnchangedShape);

// --------------------------------------------------------------------------
REGISTER_OP("Reverse")
    .Input("tensor: T")
    .Input("dims: bool")
    .Output("output: T")
    .Attr(
        "T: {uint8, int8, uint16, int16, int32, int64, bool, half, "
        "float, double, complex64, complex128, string}")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle input = c->input(0);
      ShapeHandle dims;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &dims));
      DimensionHandle dims_dim = c->Dim(dims, 0);
      if (c->ValueKnown(dims_dim)) {
        TF_RETURN_IF_ERROR(c->WithRank(input, c->Value(dims_dim), &input));
      }
      if (c->Rank(input) > 8) {
        return errors::InvalidArgument(
            "reverse does not work on tensors with more than 8 dimensions");
      }
      c->set_output(0, input);
      return Status::OK();
    });

// --------------------------------------------------------------------------
REGISTER_OP("ReverseV2")
    .Input("tensor: T")
    .Input("axis: Tidx")
    .Output("output: T")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .Attr(
        "T: {uint8, int8, uint16, int16, int32, int64, bool, bfloat16, half, "
        "float, double, complex64, complex128, string}")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle input = c->input(0);
      ShapeHandle axis;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &axis));
      if (c->Rank(input) > 8) {
        return errors::InvalidArgument(
            "reverse does not work on tensors with more than 8 dimensions");
      }
      const Tensor* axis_tensor = c->input_tensor(1);
      if (axis_tensor != nullptr && c->RankKnown(input)) {
        int32 rank = c->Rank(input);
        std::vector<int64> axis_value;
        if (axis_tensor->dtype() == DT_INT32) {
          axis_value = AsInt64<int32>(axis_tensor, axis_tensor->NumElements());
        } else {
          axis_value = AsInt64<int64>(axis_tensor, axis_tensor->NumElements());
        }
        std::vector<bool> axes_dense(c->Rank(input), false);
        for (int i = 0; i < axis_value.size(); i++) {
          int64 canonical_axis =
              axis_value[i] < 0 ? rank + axis_value[i] : axis_value[i];
          if (canonical_axis < 0 || canonical_axis >= rank) {
            return errors::InvalidArgument("'axis'[", i, "] = ", axis_value[i],
                                           " is out of valid range [", 0, ", ",
                                           rank - 1, "]");
          }
          if (axes_dense[canonical_axis]) {
            return errors::InvalidArgument("axis ", canonical_axis,
                                           " specified more than once.");
          }
          axes_dense[canonical_axis] = true;
        }
      }
      c->set_output(0, input);
      return Status::OK();
    });

// --------------------------------------------------------------------------
REGISTER_OP("EditDistance")
    .Input("hypothesis_indices: int64")
    .Input("hypothesis_values: T")
    .Input("hypothesis_shape: int64")
    .Input("truth_indices: int64")
    .Input("truth_values: T")
    .Input("truth_shape: int64")
    .Attr("normalize: bool = true")
    .Attr("T: type")
    .Output("output: float")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor(
          c, c->input(0), c->input(1), c->input(2)));
      TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor(
          c, c->input(3), c->input(4), c->input(5)));
      const Tensor* hypothesis_shape_t = c->input_tensor(2);
      const Tensor* truth_shape_t = c->input_tensor(5);
      if (hypothesis_shape_t == nullptr || truth_shape_t == nullptr) {
        // We need to know the runtime shape of the two tensors,
        // or else the output shape is unknown.
        return shape_inference::UnknownShape(c);
      }

      if (hypothesis_shape_t->NumElements() != truth_shape_t->NumElements()) {
        return errors::InvalidArgument(
            "Num elements of hypothesis_shape does not match truth_shape: ",
            hypothesis_shape_t->NumElements(), " vs. ",
            truth_shape_t->NumElements());
      }

      auto h_values = hypothesis_shape_t->flat<int64>();
      auto t_values = truth_shape_t->flat<int64>();
      std::vector<DimensionHandle> dims(hypothesis_shape_t->NumElements() - 1);
      for (int i = 0; i < dims.size(); ++i) {
        dims[i] = c->MakeDim(std::max(h_values(i), t_values(i)));
      }

      c->set_output(0, c->MakeShape(dims));
      return Status::OK();
    });

// --------------------------------------------------------------------------
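// Fill produces a tensor whose shape is given by the dims input and whose
// elements are all equal to value; e.g. dims = [2, 3] yields a [2, 3] output.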
REGISTER_OP("Fill")
    .Input("dims: index_type")
    .Input("value: T")
    .Output("output: T")
    .Attr("T: type")
    .Attr("index_type: {int32, int64} = DT_INT32")
    .SetShapeFn([](InferenceContext* c) {
      DataType index_type = DT_INT32;
      Status s = c->GetAttr("index_type", &index_type);
      if (!s.ok() && s.code() != error::NOT_FOUND) {
        return s;
      }
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));

      const Tensor* t = c->input_tensor(0);
      if (t != nullptr) {
        for (int i = 0; i < t->NumElements(); ++i) {
          if ((index_type == DT_INT32 && t->vec<int32>()(i) < 0) ||
              (index_type == DT_INT64 && t->vec<int64>()(i) < 0)) {
            return errors::InvalidArgument("Fill dimensions must be >= 0");
          }
        }
      }

      ShapeHandle out;
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
      c->set_output(0, out);

      auto* shape_and_type = c->input_handle_shapes_and_types(1);
      if (shape_and_type) {
        c->set_output_handle_shapes_and_types(0, *shape_and_type);
      }

      return Status::OK();
    });

// --------------------------------------------------------------------------
REGISTER_OP("_ParallelConcatStart")
    .Output("output: dtype")
    .Attr("shape: shape")
    .Attr("dtype: type")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ExplicitShape)
    .Doc(R"doc(
Creates an empty Tensor with shape `shape` and type `dtype`.

The memory can optionally be initialized. This is usually useful in
conjunction with inplace operations.

shape: 1-D `Tensor` indicating the shape of the output.
dtype: The element type of the returned tensor.
output: An empty Tensor of the specified type.
)doc");

// --------------------------------------------------------------------------
REGISTER_OP("_ParallelConcatUpdate")
    .Input("value: T")
    .Input("update: T")
    .Output("output: T")
    .Attr("T: type")
    .Attr("loc: int")
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
Updates input `value` at `loc` with `update`.

If you use this function you will almost certainly want to add
a control dependency as done in the implementation of parallel_stack to
avoid race conditions.

value: A `Tensor` object that will be updated in-place.
loc: A scalar indicating the index of the first dimension such that
         value[loc, :] is updated.
update: A `Tensor` of rank one less than `value` if `loc` is a scalar,
        otherwise of rank equal to `value` that contains the new values
        for `value`.
output: `value` that has been updated accordingly.
)doc");

// --------------------------------------------------------------------------
REGISTER_OP("Gather")
    .Input("params: Tparams")
    .Input("indices: Tindices")
    .Attr("validate_indices: bool = true")
    .Output("output: Tparams")
    .Attr("Tparams: type")
    .Attr("Tindices: {int32,int64}")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &unused));
      ShapeHandle params_subshape;
      TF_RETURN_IF_ERROR(c->Subshape(c->input(0), 1, &params_subshape));
      ShapeHandle indices_shape = c->input(1);
      ShapeHandle out;
      TF_RETURN_IF_ERROR(c->Concatenate(indices_shape, params_subshape, &out));
      c->set_output(0, out);
      return Status::OK();
    });

// --------------------------------------------------------------------------
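// GatherV2's output shape is params.shape[:axis] + indices.shape +
// params.shape[axis + 1:]; e.g. params [5, 6, 7], indices [2, 3], axis = 1
// produces a [5, 2, 3, 7] output.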
REGISTER_OP("GatherV2")
    .Input("params: Tparams")
    .Input("indices: Tindices")
    .Input("axis: Taxis")
    .Output("output: Tparams")
    .Attr("Tparams: type")
    .Attr("Tindices: {int32,int64}")
    .Attr("Taxis: {int32,int64}")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle params_shape;
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &params_shape));

      ShapeHandle indices_shape = c->input(1);
      ShapeHandle unused_axis_shape;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_axis_shape));
      const Tensor* axis_t = c->input_tensor(2);

      // If axis is unknown, we can only infer that the result is params_rank +
      // indices_rank - 1.
      if (axis_t == nullptr) {
        if (c->RankKnown(params_shape) && c->RankKnown(indices_shape)) {
          c->set_output(0, c->UnknownShapeOfRank(c->Rank(params_shape) +
                                                 c->Rank(indices_shape) - 1));
        } else {
          c->set_output(0, c->UnknownShape());
        }
        return Status::OK();
      }

      // Note, axis can be negative.
      int64 axis = 0;
      if (axis_t->dtype() == DT_INT32) {
        axis = axis_t->scalar<int32>()();
      } else {
        axis = axis_t->scalar<int64>()();
      }

      // Check that params has rank of at least axis + 1.
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(
          params_shape, axis < 0 ? -axis : axis + 1, &unused));

      ShapeHandle params_outer_subshape;
      TF_RETURN_IF_ERROR(
          c->Subshape(params_shape, 0, axis, &params_outer_subshape));

      ShapeHandle out;
      TF_RETURN_IF_ERROR(
          c->Concatenate(params_outer_subshape, indices_shape, &out));

      // Slice from axis + 1 to the end of params_shape to collect the inner
      // dimensions of the result. Special case -1 here since -1 + 1 wraps, and
      // we slice from 0 to the end of shape. Subshape() handles all other
      // out-of-bounds checking.
      if (axis != -1) {
        ShapeHandle params_inner_subshape;
        TF_RETURN_IF_ERROR(
            c->Subshape(params_shape, axis + 1, &params_inner_subshape));
        TF_RETURN_IF_ERROR(c->Concatenate(out, params_inner_subshape, &out));
      }

      c->set_output(0, out);
      return Status::OK();
    });

// --------------------------------------------------------------------------
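// GatherNd's output shape is indices.shape[:-1] + params.shape[r:], where
// r = indices.shape[-1]; e.g. params [4, 5, 6] and indices [2, 2] produce a
// [2, 6] output.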
   1175 REGISTER_OP("GatherNd")
   1176     .Input("params: Tparams")
   1177     .Input("indices: Tindices")
   1178     .Output("output: Tparams")
   1179     .Attr("Tparams: type")
   1180     .Attr("Tindices: {int32,int64}")
   1181     .SetShapeFn([](InferenceContext* c) {
   1182       ShapeHandle params = c->input(0);
   1183       ShapeHandle indices;
   1184       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices));
   1185       DimensionHandle r_dim = c->Dim(indices, -1);
   1186 
   1187       if (!c->RankKnown(params) || !c->ValueKnown(r_dim)) {
   1188         c->set_output(0, c->UnknownShape());
   1189         return Status::OK();
   1190       }
   1191 
   1192       if (c->Value(r_dim) > c->Rank(params)) {
   1193         return errors::InvalidArgument(
   1194             "indices.shape[-1] must be <= params.rank, but saw indices shape: ",
   1195             c->DebugString(indices),
   1196             " and params shape: ", c->DebugString(params));
   1197       }
   1198 
   1199       // Remove r_dim from indices to get output.
   1200       ShapeHandle indices_slice;
   1201       ShapeHandle params_slice;
   1202       TF_RETURN_IF_ERROR(c->Subshape(indices, 0, -1, &indices_slice));
   1203       TF_RETURN_IF_ERROR(c->Subshape(params, c->Value(r_dim), &params_slice));
   1204       ShapeHandle out;
   1205       TF_RETURN_IF_ERROR(c->Concatenate(indices_slice, params_slice, &out));
   1206       c->set_output(0, out);
   1207       return Status::OK();
   1208     });
   1209 
   1210 // --------------------------------------------------------------------------
   1211 REGISTER_OP("Identity")
   1212     .Input("input: T")
   1213     .Output("output: T")
   1214     .Attr("T: type")
   1215     .SetShapeFn(shape_inference::UnchangedShape);
   1216 
   1217 REGISTER_OP("Snapshot")
   1218     .Input("input: T")
   1219     .Output("output: T")
   1220     .Attr("T: type")
   1221     .SetShapeFn(shape_inference::UnchangedShape);
   1222 
   1223 #ifdef INTEL_MKL
   1224 REGISTER_OP("_MklIdentity")
   1225     .Input("input: T")
   1226     .Input("mkl_input: uint8")
   1227     .Output("output: T")
   1228     .Output("mkl_output: uint8")
   1229     .Attr("T: type")
   1230     .SetShapeFn(shape_inference::UnchangedShape)
   1231     .Doc(R"Doc( Mkl implementation of IdentityOp
   1232 )Doc");
   1233 #endif
   1234 
   1235 REGISTER_OP("IdentityN")
   1236     .Input("input: T")
   1237     .Output("output: T")
   1238     .Attr("T: list(type)")
   1239     .SetShapeFn([](shape_inference::InferenceContext* c) {
   1240       std::vector<ShapeHandle> input;
   1241       TF_RETURN_IF_ERROR(c->input("input", &input));
   1242       TF_RETURN_IF_ERROR(c->set_output("output", input));
   1243       return Status::OK();
   1244     });
   1245 
   1246 // --------------------------------------------------------------------------
   1247 REGISTER_OP("RefIdentity")
   1248     .Input("input: Ref(T)")
   1249     .Output("output: Ref(T)")
   1250     .Attr("T: type")
   1251     .SetShapeFn(shape_inference::UnchangedShape)
   1252     .SetAllowsUninitializedInput();
   1253 
   1254 // --------------------------------------------------------------------------
   1255 REGISTER_OP("DebugGradientIdentity")
   1256     .Input("input: T")
   1257     .Output("output: T")
   1258     .Attr("T: type")
   1259     .SetShapeFn(shape_inference::UnchangedShape)
   1260     .SetAllowsUninitializedInput();
   1261 
   1262 REGISTER_OP("DebugGradientRefIdentity")
   1263     .Input("input: Ref(T)")
   1264     .Output("output: Ref(T)")
   1265     .Attr("T: type")
   1266     .SetShapeFn(shape_inference::UnchangedShape)
   1267     .SetAllowsUninitializedInput();
   1268 
   1269 // --------------------------------------------------------------------------
   1270 REGISTER_OP("StopGradient")
   1271     .Input("input: T")
   1272     .Output("output: T")
   1273     .Attr("T: type")
   1274     .SetShapeFn(shape_inference::UnchangedShape);
   1275 
   1276 REGISTER_OP("PreventGradient")
   1277     .Input("input: T")
   1278     .Output("output: T")
   1279     .Attr("T: type")
   1280     .Attr("message: string = ''")
   1281     .SetShapeFn(shape_inference::UnchangedShape);
   1282 
   1283 // --------------------------------------------------------------------------
   1284 REGISTER_OP("CheckNumerics")
   1285     .Input("tensor: T")
   1286     .Output("output: T")
   1287     .Attr("T: {bfloat16, half, float, double}")
   1288     .Attr("message: string")
   1289     .SetShapeFn(shape_inference::UnchangedShape);
   1290 
   1291 // --------------------------------------------------------------------------
   1292 REGISTER_OP("Reshape")
   1293     .Input("tensor: T")
   1294     .Input("shape: Tshape")
   1295     .Output("output: T")
   1296     .Attr("T: type")
   1297     .Attr("Tshape: {int32, int64} = DT_INT32")
   1298     .SetShapeFn([](InferenceContext* c) {
   1299       return SetOutputShapeForReshape(c);
   1300     });
   1301 
   1302 #ifdef INTEL_MKL
   1303 REGISTER_OP("_MklReshape")
   1304     .Input("tensor: T")
   1305     .Input("shape: Tshape")
   1306     .Input("mkl_tensor: uint8")
   1307     .Input("mkl_shape: uint8")
   1308     .Output("output: T")
   1309     .Output("mkl_output: uint8")
   1310     .Attr("T: type")
   1311     .Attr("Tshape: {int32, int64} = DT_INT32")
   1312     .SetShapeFn([](InferenceContext* c) { return SetOutputShapeForReshape(c); })
    1313     .Doc(R"Doc(MKL implementation of ReshapeOp.
   1314 )Doc");
   1315 #endif  // INTEL_MKL
   1316 
   1317 // --------------------------------------------------------------------------
   1318 REGISTER_OP("InvertPermutation")
   1319     .Input("x: T")
   1320     .Output("y: T")
   1321     .Attr("T: {int32, int64} = DT_INT32")
   1322     .SetShapeFn([](InferenceContext* c) {
   1323       ShapeHandle x;
   1324       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &x));
   1325       c->set_output(0, x);
   1326       return Status::OK();
   1327     });
   1328 
   1329 // --------------------------------------------------------------------------
   1330 REGISTER_OP("Transpose")
   1331     .Input("x: T")
   1332     .Input("perm: Tperm")
   1333     .Output("y: T")
   1334     .Attr("T: type")
   1335     .Attr("Tperm: {int32, int64} = DT_INT32")
   1336     .SetShapeFn(TransposeShapeFn);
   1337 
   1338 // --------------------------------------------------------------------------
   1339 REGISTER_OP("ConjugateTranspose")
   1340     .Input("x: T")
   1341     .Input("perm: Tperm")
   1342     .Output("y: T")
   1343     .Attr("T: type")
   1344     .Attr("Tperm: {int32, int64} = DT_INT32")
   1345     .SetShapeFn(TransposeShapeFn);
   1346 
   1347 // --------------------------------------------------------------------------
   1348 REGISTER_OP("Unique")
   1349     .Input("x: T")
   1350     .Output("y: T")
   1351     .Output("idx: out_idx")
   1352     .Attr("T: type")
   1353     .Attr("out_idx: {int32, int64} = DT_INT32")
   1354     .SetShapeFn([](InferenceContext* c) {
   1355       c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
   1356       c->set_output(1, c->input(0));
   1357       // Assert that the input rank is 1.
   1358       ShapeHandle dummy;
   1359       return c->WithRank(c->input(0), 1, &dummy);
   1360     });
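
// Illustrative example: a length-6 input vector with 4 distinct values yields
// y with runtime shape [4] -- statically unknown, hence the kUnknownDim vector
// above -- and idx with shape [6], the same as the input.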
   1361 
   1362 REGISTER_OP("UniqueV2")
   1363     .Input("x: T")
   1364     .Input("axis: Taxis")
   1365     .Output("y: T")
   1366     .Output("idx: out_idx")
   1367     .Attr("T: type")
   1368     .Attr("Taxis: {int32,int64} = DT_INT64")
   1369     .Attr("out_idx: {int32, int64} = DT_INT32")
   1370     .SetShapeFn([](InferenceContext* c) {
   1371       c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
   1372       c->set_output(1, c->input(0));
   1373       return Status::OK();
   1374     });
   1375 
   1376 // --------------------------------------------------------------------------
   1377 REGISTER_OP("UniqueWithCounts")
   1378     .Input("x: T")
   1379     .Output("y: T")
   1380     .Output("idx: out_idx")
   1381     .Output("count: out_idx")
   1382     .Attr("T: type")
   1383     .Attr("out_idx: {int32, int64} = DT_INT32")
   1384     .SetShapeFn([](InferenceContext* c) {
   1385       auto uniq = c->Vector(InferenceContext::kUnknownDim);
   1386       c->set_output(0, uniq);
   1387       c->set_output(1, c->input(0));
   1388       c->set_output(2, uniq);
   1389       return Status::OK();
   1390     });
   1391 
   1392 REGISTER_OP("UniqueWithCountsV2")
   1393     .Input("x: T")
   1394     .Input("axis: Taxis")
   1395     .Output("y: T")
   1396     .Output("idx: out_idx")
   1397     .Output("count: out_idx")
   1398     .Attr("T: type")
   1399     .Attr("Taxis: {int32,int64} = DT_INT64")
   1400     .Attr("out_idx: {int32, int64} = DT_INT32")
   1401     .SetShapeFn([](InferenceContext* c) {
   1402       auto uniq = c->Vector(InferenceContext::kUnknownDim);
   1403       c->set_output(0, uniq);
   1404       c->set_output(1, c->input(0));
   1405       c->set_output(2, uniq);
   1406       return Status::OK();
   1407     });
   1408 
   1409 namespace {
   1410 
   1411 Status ShapeShapeFn(InferenceContext* c) {
   1412   for (int i = 0; i < c->num_inputs(); ++i) {
   1413     DimensionHandle dim;
   1414     if (c->RankKnown(c->input(i))) {
   1415       dim = c->MakeDim(c->Rank(c->input(i)));
   1416     } else {
   1417       dim = c->UnknownDim();
   1418     }
   1419     c->set_output(i, c->Vector(dim));
   1420   }
   1421   return Status::OK();
   1422 }
   1423 
   1424 }  // namespace
   1425 
   1426 // --------------------------------------------------------------------------
   1427 REGISTER_OP("Shape")
   1428     .Input("input: T")
   1429     .Output("output: out_type")
   1430     .Attr("T: type")
   1431     .Attr("out_type: {int32, int64} = DT_INT32")
   1432     .SetShapeFn(ShapeShapeFn);
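
// Illustrative example: an input with statically known shape [2, 3, 5]
// produces the vector [2, 3, 5], so the inferred output shape is [3]; if the
// input rank is unknown, the output is a vector of unknown length.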
   1433 
   1434 REGISTER_OP("ShapeN")
   1435     .Input("input: N * T")
   1436     .Output("output: N * out_type")
   1437     .Attr("N: int")
   1438     .Attr("T: type")
   1439     .Attr("out_type: {int32, int64} = DT_INT32")
   1440     .SetShapeFn(ShapeShapeFn);
   1441 
   1442 REGISTER_OP("EnsureShape")
   1443     .Input("input: T")
   1444     .Output("output: T")
   1445     .Attr("shape: shape")
   1446     .Attr("T: type")
   1447     .SetShapeFn([](InferenceContext* c) {
    1448       // Merges the desired shape (from the "shape" attr) with the
    1449       // statically known shape of the input.
   1449       PartialTensorShape desired_shape;
   1450       TF_RETURN_IF_ERROR(c->GetAttr("shape", &desired_shape));
   1451 
   1452       int rank = desired_shape.dims();
   1453       ShapeHandle input_shape_handle;
   1454       ShapeHandle desired_shape_handle;
   1455       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape_handle));
   1456       TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
   1457           desired_shape, &desired_shape_handle));
   1458 
   1459       ShapeHandle merged_shape;
   1460       TF_RETURN_IF_ERROR(
   1461           c->Merge(desired_shape_handle, input_shape_handle, &merged_shape));
   1462       c->set_output(0, merged_shape);
   1463       return Status::OK();
   1464     });
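
// Illustrative example: a "shape" attr of [?, 3] merged with an input of
// statically known shape [2, ?] yields an output shape of [2, 3]; shapes that
// cannot be merged (e.g. [2, 3] vs. [4, 3]) fail shape inference here.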
   1465 
   1466 // --------------------------------------------------------------------------
   1467 REGISTER_OP("ReverseSequence")
   1468     .Input("input: T")
   1469     .Input("seq_lengths: Tlen")
   1470     .Output("output: T")
   1471     .Attr("seq_dim: int")
   1472     .Attr("batch_dim: int = 0")
   1473     .Attr("T: type")
   1474     .Attr("Tlen: {int32, int64} = DT_INT64")
   1475     .SetShapeFn([](InferenceContext* c) {
   1476       ShapeHandle input = c->input(0);
   1477       ShapeHandle seq_lens_shape;
   1478       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &seq_lens_shape));
   1479 
   1480       int64 seq_dim;
   1481       TF_RETURN_IF_ERROR(c->GetAttr("seq_dim", &seq_dim));
   1482       int64 batch_dim;
   1483       TF_RETURN_IF_ERROR(c->GetAttr("batch_dim", &batch_dim));
   1484 
   1485       if (!c->RankKnown(input)) {
   1486         return shape_inference::UnknownShape(c);
   1487       }
   1488 
   1489       // Validate batch_dim and seq_dim against input.
   1490       const int32 input_rank = c->Rank(input);
   1491       if (batch_dim >= input_rank) {
   1492         return errors::InvalidArgument(
   1493             "batch_dim must be < input rank: ", batch_dim, " vs. ", input_rank);
   1494       }
   1495       if (seq_dim >= input_rank) {
   1496         return errors::InvalidArgument(
   1497             "seq_dim must be < input rank: ", seq_dim, " vs. ", input_rank);
   1498       }
   1499 
   1500       DimensionHandle batch_dim_dim = c->Dim(input, batch_dim);
   1501       TF_RETURN_IF_ERROR(
   1502           c->Merge(batch_dim_dim, c->Dim(seq_lens_shape, 0), &batch_dim_dim));
   1503 
   1504       // Replace batch_dim of input with batch_size
   1505       ShapeHandle output_shape;
   1506       TF_RETURN_IF_ERROR(
   1507           c->ReplaceDim(input, batch_dim, batch_dim_dim, &output_shape));
   1508       c->set_output(0, output_shape);
   1509       return Status::OK();
   1510     });
   1511 
   1512 // --------------------------------------------------------------------------
   1513 REGISTER_OP("Rank")
   1514     .Input("input: T")
   1515     .Output("output: int32")
   1516     .Attr("T: type")
   1517     .SetShapeFn(shape_inference::ScalarShape);
   1518 
   1519 // --------------------------------------------------------------------------
   1520 REGISTER_OP("Size")
   1521     .Input("input: T")
   1522     .Output("output: out_type")
   1523     .Attr("T: type")
   1524     .Attr("out_type: {int32, int64} = DT_INT32")
   1525     .SetShapeFn(shape_inference::ScalarShape);
   1526 
   1527 // --------------------------------------------------------------------------
   1528 REGISTER_OP("Slice")
   1529     .Input("input: T")
   1530     .Input("begin: Index")
   1531     .Input("size: Index")
   1532     .Output("output: T")
   1533     .Attr("T: type")
   1534     .Attr("Index: {int32,int64}")
   1535     .SetShapeFn(shape_inference::SliceShape);
   1536 
   1537 #ifdef INTEL_MKL
   1538 REGISTER_OP("_MklSlice")
   1539     .Input("input: T")
   1540     .Input("begin: Index")
   1541     .Input("size: Index")
   1542     .Input("mkl_input: uint8")
   1543     .Input("mkl_begin: uint8")
   1544     .Input("mkl_size: uint8")
   1545     .Output("output: T")
   1546     .Output("mkl_output: uint8")
   1547     .Attr("T: type")
   1548     .Attr("Index: {int32,int64}")
   1549     .SetShapeFn(shape_inference::SliceShape);
   1550 #endif
   1551 
   1552 REGISTER_OP("StridedSlice")
   1553     .Input("input: T")
   1554     .Input("begin: Index")
   1555     .Input("end: Index")
   1556     .Input("strides: Index")
   1557     .Output("output: T")
   1558     .Attr("T: type")
   1559     .Attr("Index: {int32, int64}")
   1560     .Attr("begin_mask: int = 0")
   1561     .Attr("end_mask: int = 0")
   1562     .Attr("ellipsis_mask: int = 0")
   1563     .Attr("new_axis_mask: int = 0")
   1564     .Attr("shrink_axis_mask: int = 0")
   1565     .SetShapeFn([](InferenceContext* c) {
   1566       ShapeHandle input = c->input(0);
   1567       ShapeHandle begin_shape, end_shape, strides_shape;
   1568       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &begin_shape));
   1569       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &end_shape));
   1570       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &strides_shape));
   1571       TF_RETURN_IF_ERROR(c->Merge(begin_shape, end_shape, &begin_shape));
   1572       TF_RETURN_IF_ERROR(c->Merge(begin_shape, strides_shape, &begin_shape));
   1573       DimensionHandle sparse_dims_dim = c->Dim(begin_shape, 0);
   1574 
   1575       const Tensor* strides_value = c->input_tensor(3);
   1576       // TODO(aselle,allenl): If we had a stride_mask it would be possible to do
   1577       // more shape inference here (e.g. for x[3, ::T]).
   1578       if (!c->RankKnown(input) || !c->ValueKnown(sparse_dims_dim) ||
   1579           strides_value == nullptr) {
   1580         c->set_output(0, c->UnknownShape());
   1581         return Status::OK();
   1582       }
   1583 
   1584       PartialTensorShape input_shape({});
   1585       for (int i = 0; i < c->Rank(input); ++i) {
   1586         auto dim = c->Dim(input, i);
   1587         input_shape.AddDim(c->ValueKnown(dim) ? c->Value(dim) : -1);
   1588       }
   1589 
   1590       int32 begin_mask, end_mask, ellipsis_mask, new_axis_mask,
   1591           shrink_axis_mask;
   1592       TF_RETURN_IF_ERROR(c->GetAttr("begin_mask", &begin_mask));
   1593       TF_RETURN_IF_ERROR(c->GetAttr("end_mask", &end_mask));
   1594       TF_RETURN_IF_ERROR(c->GetAttr("ellipsis_mask", &ellipsis_mask));
   1595       TF_RETURN_IF_ERROR(c->GetAttr("new_axis_mask", &new_axis_mask));
   1596       TF_RETURN_IF_ERROR(c->GetAttr("shrink_axis_mask", &shrink_axis_mask));
   1597 
   1598       const Tensor* begin_value = c->input_tensor(1);
   1599       const Tensor* end_value = c->input_tensor(2);
   1600 
   1601       PartialTensorShape processing_shape, final_shape;
   1602       bool is_identity, is_simple_slice, slice_dim0;
   1603       gtl::InlinedVector<int64, 4> begin, end, strides;
   1604       TF_RETURN_IF_ERROR(ValidateStridedSliceOp(
   1605           begin_value, end_value, *strides_value, input_shape, begin_mask,
   1606           end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask,
   1607           &processing_shape, &final_shape, &is_identity, &is_simple_slice,
   1608           &slice_dim0, &begin, &end, &strides));
   1609 
   1610       ShapeHandle out;
   1611       TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(final_shape, &out));
   1612       c->set_output(0, out);
   1613 
   1614       auto* shape_and_type = c->input_handle_shapes_and_types(0);
   1615       if (shape_and_type) {
   1616         c->set_output_handle_shapes_and_types(0, *shape_and_type);
   1617       }
   1618 
   1619       return Status::OK();
   1620     });
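
// Worked example (illustrative, assuming constant begin/end/strides): for an
// input of shape [10, 10] with begin=[2, 1], end=[8, 2], strides=[2, 1] and
// shrink_axis_mask=0b10, ValidateStridedSliceOp reports a final shape of [3]:
// dimension 0 keeps indices 2, 4, 6 and dimension 1 is shrunk away. If any of
// begin, end, or strides is not a constant tensor, the branch above returns
// an unknown shape instead.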
   1621 
   1622 REGISTER_OP("StridedSliceGrad")
   1623     .Input("shape: Index")
   1624     .Input("begin: Index")
   1625     .Input("end: Index")
   1626     .Input("strides: Index")
   1627     .Input("dy: T")
   1628     .Output("output: T")
   1629     .Attr("T: type")
   1630     .Attr("Index: {int32, int64}")
   1631     .Attr("begin_mask: int = 0")
   1632     .Attr("end_mask: int = 0")
   1633     .Attr("ellipsis_mask: int = 0")
   1634     .Attr("new_axis_mask: int = 0")
   1635     .Attr("shrink_axis_mask: int = 0")
   1636     .SetShapeFn([](InferenceContext* c) {
   1637       ShapeHandle out;
   1638       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
   1639       c->set_output(0, out);
   1640       return Status::OK();
   1641     });
   1642 
   1643 REGISTER_OP("StridedSliceAssign")
   1644     .Input("ref: Ref(T)")
   1645     .Input("begin: Index")
   1646     .Input("end: Index")
   1647     .Input("strides: Index")
   1648     .Input("value: T")
   1649     .Output("output_ref: Ref(T)")
   1650     .Attr("T: type")
   1651     .Attr("Index: {int32, int64}")
   1652     .Attr("begin_mask: int = 0")
   1653     .Attr("end_mask: int = 0")
   1654     .Attr("ellipsis_mask: int = 0")
   1655     .Attr("new_axis_mask: int = 0")
   1656     .Attr("shrink_axis_mask: int = 0")
   1657     .SetShapeFn(shape_inference::UnchangedShape);
    1658 // TODO(aselle): Fix this documentation once StridedSliceAssign supports
    1659 // broadcasting.
   1660 // --------------------------------------------------------------------------
   1661 
   1662 REGISTER_OP("ResourceStridedSliceAssign")
   1663     .Input("ref: resource")
   1664     .Input("begin: Index")
   1665     .Input("end: Index")
   1666     .Input("strides: Index")
   1667     .Input("value: T")
   1668     .Attr("T: type")
   1669     .Attr("Index: {int32, int64}")
   1670     .Attr("begin_mask: int = 0")
   1671     .Attr("end_mask: int = 0")
   1672     .Attr("ellipsis_mask: int = 0")
   1673     .Attr("new_axis_mask: int = 0")
   1674     .Attr("shrink_axis_mask: int = 0")
   1675     .SetShapeFn(shape_inference::NoOutputs);
   1676 
   1677 REGISTER_OP("Tile")
   1678     .Input("input: T")
   1679     .Input("multiples: Tmultiples")
   1680     .Output("output: T")
   1681     .Attr("T: type")
   1682     .Attr("Tmultiples: {int32, int64} = DT_INT32")
   1683     .SetShapeFn([](InferenceContext* c) {
   1684       ShapeHandle input = c->input(0);
   1685       // NOTE(mrry): Represent `multiples` as a `TensorShape` because (i)
   1686       // it is a vector of non-negative integers, and (ii) doing so allows
   1687       // us to handle partially-known multiples.
   1688       ShapeHandle multiples;
   1689       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &multiples));
   1690       if (c->RankKnown(input)) {
   1691         TF_RETURN_IF_ERROR(c->WithRank(multiples, c->Rank(input), &multiples));
   1692         ShapeHandle dummy;
   1693         TF_RETURN_IF_ERROR(
   1694             c->Merge(c->input(1), c->Vector(c->Rank(input)), &dummy));
   1695       }
   1696 
   1697       if (!c->RankKnown(multiples)) {
   1698         return shape_inference::UnknownShape(c);
   1699       }
   1700 
   1701       int32 rank = c->Rank(multiples);
   1702       TF_RETURN_IF_ERROR(c->WithRank(input, rank, &input));
   1703       std::vector<DimensionHandle> dims(rank);
   1704       for (int i = 0; i < rank; ++i) {
   1705         TF_RETURN_IF_ERROR(
   1706             c->Multiply(c->Dim(input, i), c->Dim(multiples, i), &dims[i]));
   1707       }
   1708       c->set_output(0, c->MakeShape(dims));
   1709       return Status::OK();
   1710     });
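
// Illustrative example: an input of shape [2, 3] tiled with constant
// multiples [3, 2] gives an output shape of [2*3, 3*2] = [6, 6]; if the
// multiples tensor is not constant, only its length constrains the output
// rank.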
   1711 
   1712 // --------------------------------------------------------------------------
   1713 REGISTER_OP("TileGrad")
   1714     .Input("input: T")
   1715     .Input("multiples: int32")
   1716     .Output("output: T")
   1717     .Attr("T: type")
   1718     .Deprecated(3, "TileGrad has been replaced with reduce_sum")
   1719     .SetShapeFn(tensorflow::shape_inference::UnknownShape);
   1720 
   1721 // --------------------------------------------------------------------------
   1722 REGISTER_OP("Where")
   1723     .Input("input: T")
   1724     .Attr("T: {numbertype, bool} = DT_BOOL")
   1725     .Output("index: int64")
   1726     .SetShapeFn([](InferenceContext* c) {
   1727       c->set_output(0, c->Matrix(c->UnknownDim(), c->Rank(c->input(0))));
   1728       return Status::OK();
   1729     });
   1730 
   1731 // --------------------------------------------------------------------------
   1732 REGISTER_OP("BroadcastArgs")
   1733     .Input("s0: T")
   1734     .Input("s1: T")
   1735     .Output("r0: T")
   1736     .Attr("T: {int32, int64} = DT_INT32")
   1737     .SetShapeFn([](InferenceContext* c) {
   1738       ShapeHandle unused;
   1739       ShapeHandle shape_x = c->input(0);
   1740       ShapeHandle shape_y = c->input(1);
   1741       TF_RETURN_IF_ERROR(c->WithRank(shape_x, 1, &unused));
   1742       TF_RETURN_IF_ERROR(c->WithRank(shape_y, 1, &unused));
   1743 
   1744       if (!c->ValueKnown(c->Dim(shape_x, 0)) ||
   1745           !c->ValueKnown(c->Dim(shape_y, 0))) {
   1746         c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
   1747         return Status::OK();
   1748       }
   1749 
   1750       int64 x_dim = c->Value(c->Dim(shape_x, 0));
   1751       int64 y_dim = c->Value(c->Dim(shape_y, 0));
   1752 
    1753       // The broadcast result has as many dimensions as the higher-rank
    1754       // input, so the output vector length is the larger of the two lengths.
   1754       c->set_output(0, c->Vector(std::max(x_dim, y_dim)));
   1755       return Status::OK();
   1756     });
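
// Illustrative example: if s0 encodes the shape [2, 3] (a vector of length 2)
// and s1 encodes [5, 1, 3] (length 3), the broadcast shape [5, 2, 3] has rank
// 3, so r0 is inferred to be a vector of length max(2, 3) = 3.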
   1757 
   1758 // --------------------------------------------------------------------------
   1759 REGISTER_OP("BroadcastGradientArgs")
   1760     .Input("s0: T")
   1761     .Input("s1: T")
   1762     .Output("r0: T")
   1763     .Output("r1: T")
   1764     .Attr("T: {int32, int64} = DT_INT32")
   1765     .SetShapeFn([](InferenceContext* c) {
   1766       // TODO(mrry): Implement constant_value for BroadcastGradientArgs?
   1767       ShapeHandle unused;
   1768       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
   1769       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
   1770       c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
   1771       c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
   1772       return Status::OK();
   1773     });
   1774 
   1775 // --------------------------------------------------------------------------
   1776 REGISTER_OP("Pad")
   1777     .Input("input: T")
   1778     .Input("paddings: Tpaddings")
   1779     .Output("output: T")
   1780     .Attr("T: type")
   1781     .Attr("Tpaddings: {int32, int64} = DT_INT32")
   1782     .SetShapeFn(PadShapeFn);
   1783 
   1784 // --------------------------------------------------------------------------
   1785 REGISTER_OP("PadV2")
   1786     .Input("input: T")
   1787     .Input("paddings: Tpaddings")
   1788     .Input("constant_values: T")
   1789     .Output("output: T")
   1790     .Attr("T: type")
   1791     .Attr("Tpaddings: {int32, int64} = DT_INT32")
   1792     .SetShapeFn(PadShapeFn);
   1793 
   1794 // --------------------------------------------------------------------------
   1795 REGISTER_OP("MirrorPad")
   1796     .Input("input: T")
   1797     .Input("paddings: Tpaddings")
   1798     .Output("output: T")
   1799     .Attr("T: type")
   1800     .Attr("Tpaddings: {int32, int64} = DT_INT32")
   1801     .Attr(GetMirrorPadModeAttrString())
   1802     .SetShapeFn(PadShapeFn);
   1803 
   1804 // --------------------------------------------------------------------------
   1805 namespace {
   1806 template <typename T>
   1807 Status MirrorPadKnown(InferenceContext* c, ShapeHandle input,
   1808                       const Tensor* paddings_t, int64 input_rank) {
   1809   auto paddings_data = paddings_t->matrix<T>();
   1810   std::vector<DimensionHandle> dims(input_rank);
   1811   for (int64 i = 0; i < input_rank; ++i) {
   1812     const int64 pad0 = static_cast<int64>(paddings_data(i, 0));
   1813     const int64 pad1 = static_cast<int64>(paddings_data(i, 1));
   1814     if (pad0 < 0 || pad1 < 0) {
   1815       return errors::InvalidArgument("Paddings must be non-negative");
   1816     }
   1817 
   1818     TF_RETURN_IF_ERROR(c->Subtract(c->Dim(input, i), pad0 + pad1, &dims[i]));
   1819   }
   1820   c->set_output(0, c->MakeShape(dims));
   1821   return Status::OK();
   1822 }
   1823 
   1824 }  // namespace
   1825 
   1826 REGISTER_OP("MirrorPadGrad")
   1827     .Input("input: T")
   1828     .Input("paddings: Tpaddings")
   1829     .Output("output: T")
   1830     .Attr("T: type")
   1831     .Attr("Tpaddings: {int32, int64} = DT_INT32")
   1832     .Attr(GetMirrorPadModeAttrString())
   1833     .SetShapeFn([](InferenceContext* c) {
   1834       ShapeHandle paddings;
   1835       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &paddings));
   1836       DimensionHandle pad_0 = c->Dim(paddings, 0);
   1837       if (!c->ValueKnown(pad_0)) {
   1838         // We don't know the rank of the output since the first
   1839         // padding dimension is unknown.
   1840         c->set_output(0, c->UnknownShape());
   1841         return Status::OK();
   1842       }
   1843 
   1844       int64 input_rank = c->Value(pad_0);
   1845       ShapeHandle input;
   1846       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), input_rank, &input));
   1847       TF_RETURN_IF_ERROR(
   1848           c->Merge(paddings, c->Matrix(input_rank, 2), &paddings));
   1849 
   1850       const Tensor* paddings_t = c->input_tensor(1);
   1851       if (paddings_t == nullptr) {
    1852         // The values of 'paddings' are not available, but we know the
    1853         // input rank, so return an output shape with that rank and
    1854         // unknown dimensions.
   1855         c->set_output(0, c->UnknownShapeOfRank(input_rank));
   1856         return Status::OK();
   1857       }
   1858 
   1859       if (paddings_t->dtype() == DT_INT32) {
   1860         return MirrorPadKnown<int32>(c, input, paddings_t, input_rank);
   1861       } else {
   1862         return MirrorPadKnown<int64>(c, input, paddings_t, input_rank);
   1863       }
   1864     });
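
// Illustrative example: with constant paddings [[1, 2], [3, 3]] and an input
// (padded gradient) of shape [7, 9], MirrorPadKnown infers an output shape of
// [7 - (1 + 2), 9 - (3 + 3)] = [4, 3], undoing the padding added by MirrorPad.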
   1865 
   1866 // --------------------------------------------------------------------------
   1867 REGISTER_OP("Placeholder")
   1868     .Output("output: dtype")
   1869     .Attr("dtype: type")
   1870     .Attr("shape: shape = { unknown_rank: true }")
   1871     .SetShapeFn([](InferenceContext* c) {
   1872       PartialTensorShape shape;
   1873       TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
   1874 
   1875       // Placeholder has legacy behavior where we cannot tell the difference
   1876       // between a scalar shape attribute and 'unknown shape'.  So if the shape
   1877       // is a scalar, we return an unknown shape.
   1878       if (c->graph_def_version() <= 21 && shape.dims() <= 0) {
   1879         return shape_inference::UnknownShape(c);
   1880       }
   1881 
   1882       ShapeHandle out;
   1883       TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &out));
   1884       c->set_output(0, out);
   1885       return Status::OK();
   1886     });
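
// Illustrative example: a shape attr of [?, 128] is passed through as the
// output shape; on graphs at version 21 or older a rank-0 shape attr cannot
// be distinguished from "unknown rank", so the output is left fully unknown.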
   1887 
    1888 // Placeholder was modified in a backwards-compatible way to do what
   1889 // PlaceholderV2 did, so we have deprecated V2 (no one was really
   1890 // using it).
   1891 REGISTER_OP("PlaceholderV2")
   1892     .Output("output: dtype")
   1893     .Attr("dtype: type")
   1894     .Attr("shape: shape")
   1895     .SetShapeFn(shape_inference::ExplicitShape)
   1896     .Deprecated(23, "Placeholder now behaves the same as PlaceholderV2.");
   1897 
   1898 // --------------------------------------------------------------------------
   1899 REGISTER_OP("PlaceholderWithDefault")
   1900     .Input("input: dtype")
   1901     .Output("output: dtype")
   1902     .Attr("dtype: type")
   1903     .Attr("shape: shape")
   1904     .SetShapeFn([](InferenceContext* c) {
   1905       ShapeHandle input = c->input(0);
   1906       PartialTensorShape shape;
   1907       TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
   1908       ShapeHandle out;
   1909       TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &out));
   1910 
   1911       // We merge for compatibility checking, but return the output,
   1912       // since output_shape may be less precise than input_shape.
   1913       ShapeHandle unused;
   1914       TF_RETURN_IF_ERROR(c->Merge(input, out, &unused));
   1915       c->set_output(0, out);
   1916       return Status::OK();
   1917     });
   1918 
   1919 // --------------------------------------------------------------------------
   1920 REGISTER_OP("ExpandDims")
   1921     .Input("input: T")
   1922     .Input("dim: Tdim")
   1923     .Output("output: T")
   1924     .Attr("T: type")
   1925     .Attr("Tdim: {int32, int64} = DT_INT32")
   1926     .SetShapeFn([](InferenceContext* c) {
   1927       ShapeHandle input = c->input(0);
   1928 
   1929       const Tensor* dim_t = c->input_tensor(1);
   1930       if (dim_t != nullptr && dim_t->NumElements() != 1) {
   1931         return errors::InvalidArgument(
   1932             "'dim' input must be a tensor with a single value");
   1933       }
   1934       if (dim_t == nullptr || !c->RankKnown(input)) {
   1935         c->set_output(0, c->UnknownShape());
   1936         return Status::OK();
   1937       }
   1938 
   1939       int64 dim;
   1940       if (dim_t->dtype() == DT_INT32) {
   1941         dim = static_cast<int64>(dim_t->flat<int32>()(0));
   1942       } else {
   1943         dim = dim_t->flat<int64>()(0);
   1944       }
   1945 
   1946       const int32 rank = c->Rank(input);
   1947       const int32 min_dim = -1 * rank - 1;
   1948       if (dim < min_dim || dim > rank) {
   1949         return errors::InvalidArgument("dim ", dim, " not in the interval [",
   1950                                        min_dim, ", ", rank, "].");
   1951       }
   1952 
   1953       if (dim < 0) {
   1954         dim += rank + 1;
   1955       }
   1956 
   1957       ShapeHandle end;
   1958       TF_RETURN_IF_ERROR(c->Subshape(input, dim, &end));
   1959 
   1960       // Build output as start + 1 + end.
   1961       ShapeHandle output;
   1962       TF_RETURN_IF_ERROR(c->Subshape(input, 0, dim, &output));
   1963       TF_RETURN_IF_ERROR(c->Concatenate(output, c->Vector(1), &output));
   1964       TF_RETURN_IF_ERROR(c->Concatenate(output, end, &output));
   1965       c->set_output(0, output);
   1966       return Status::OK();
   1967     });
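
// Illustrative example: for an input of shape [2, 3], a constant dim of 0
// gives [1, 2, 3] and dim of -1 wraps to 2 and gives [2, 3, 1]; any dim
// outside [-3, 2] is rejected by the range check above.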
   1968 
   1969 // --------------------------------------------------------------------------
   1970 REGISTER_OP("Squeeze")
   1971     .Input("input: T")
   1972     .Output("output: T")
   1973     .Attr("T: type")
   1974     .Attr("squeeze_dims: list(int) >= 0 = []")
   1975     .SetShapeFn([](InferenceContext* c) {
   1976       ShapeHandle input = c->input(0);
   1977       if (!c->RankKnown(input)) {
   1978         // Input shape unknown.
   1979         return shape_inference::UnknownShape(c);
   1980       }
   1981 
   1982       const int32 input_rank = c->Rank(input);
   1983 
   1984       // Validate and wrap squeeze dimensions.
   1985       std::vector<int32> squeeze_dims;
   1986       TF_RETURN_IF_ERROR(c->GetAttr("squeeze_dims", &squeeze_dims));
   1987       for (int i = 0; i < squeeze_dims.size(); ++i) {
   1988         if (squeeze_dims[i] < -input_rank || squeeze_dims[i] >= input_rank) {
   1989           return errors::InvalidArgument("squeeze_dims[", i, "] not in [",
   1990                                          -input_rank, ",", input_rank, ").");
   1991         }
   1992 
   1993         if (squeeze_dims[i] < 0) {
   1994           squeeze_dims[i] += input_rank;
   1995         }
   1996       }
   1997 
   1998       std::vector<DimensionHandle> result_shape;
   1999       for (int i = 0; i < input_rank; ++i) {
   2000         // True if squeeze_dims contains an entry to squeeze this
   2001         // dimension.
   2002         bool is_explicit_match =
   2003             std::find(squeeze_dims.begin(), squeeze_dims.end(), i) !=
   2004             squeeze_dims.end();
   2005 
   2006         DimensionHandle dim = c->Dim(input, i);
   2007 
   2008         if (!c->ValueKnown(dim)) {
   2009           // Assume that the squeezed dimension will be 1 at runtime.
   2010           if (is_explicit_match) continue;
   2011 
    2012           // If squeezing all size-1 dimensions and this dimension's size
    2013           // is unknown, give up and return an unknown shape.
   2014           if (squeeze_dims.empty()) {
   2015             c->set_output(0, c->UnknownShape());
   2016             return Status::OK();
   2017           }
   2018         } else if (c->Value(dim) == 1) {
   2019           if (is_explicit_match || squeeze_dims.empty()) {
   2020             // If explicitly squeezing, or squeezing all 1s, remove
   2021             // this dimension.
   2022             continue;
   2023           }
   2024         } else if (is_explicit_match) {
   2025           return errors::InvalidArgument("Can not squeeze dim[", i,
   2026                                          "], expected a dimension of 1, got ",
   2027                                          c->Value(c->Dim(input, i)));
   2028         }
   2029 
   2030         result_shape.emplace_back(dim);
   2031       }
   2032 
   2033       c->set_output(0, c->MakeShape(result_shape));
   2034       return Status::OK();
   2035     });
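
// Illustrative example: an input of shape [1, 2, 1, 3] with squeeze_dims=[]
// infers [2, 3]; with squeeze_dims=[0] it infers [2, 1, 3]; squeeze_dims=[1]
// is an error because dimension 1 has size 2, not 1.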
   2036 
   2037 // --------------------------------------------------------------------------
   2038 REGISTER_OP("ListDiff")
   2039     .Input("x: T")
   2040     .Input("y: T")
   2041     .Output("out: T")
   2042     .Output("idx: out_idx")
   2043     .Attr("T: type")
   2044     .Attr("out_idx: {int32, int64} = DT_INT32")
   2045     .SetShapeFn([](InferenceContext* c) {
   2046       ShapeHandle unused;
   2047       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
   2048       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
   2049       // TODO(mrry): Indicate that the length falls within an interval?
   2050       ShapeHandle out = c->Vector(InferenceContext::kUnknownDim);
   2051       c->set_output(0, out);
   2052       c->set_output(1, out);
   2053       return Status::OK();
   2054     });
   2055 
   2056 namespace {
   2057 
   2058 // Converts Tensor to flat std::vector<int64>.
   2059 template <typename InputType>
   2060 std::vector<int64> GetFlatInt64(const Tensor& t) {
   2061   std::vector<int64> output(t.shape().num_elements());
   2062   auto eigen_vec = t.flat<InputType>();
   2063   std::copy_n(&eigen_vec(0), output.size(), output.begin());
   2064   return output;
   2065 }
   2066 
   2067 // Converts int32 or int64 Tensor to flat std::vector<int64>.
   2068 std::vector<int64> GetFlatInt64(const Tensor& t) {
   2069   if (t.dtype() == DT_INT32) {
   2070     return GetFlatInt64<int32>(t);
   2071   } else {
   2072     return GetFlatInt64<int64>(t);
   2073   }
   2074 }
   2075 
   2076 Status SpaceToBatchShapeHelper(InferenceContext* c, ShapeHandle input_shape,
   2077                                ShapeHandle block_shape_shape,
   2078                                const Tensor* block_shape_t,
   2079                                ShapeHandle paddings_shape,
   2080                                const Tensor* paddings_t) {
   2081   if (c->Rank(block_shape_shape) != 1) {
   2082     return errors::InvalidArgument("block_shape must have rank 1.");
   2083   }
   2084 
   2085   const DimensionHandle num_block_dims_handle = c->Dim(block_shape_shape, 0);
   2086   if (!c->ValueKnown(num_block_dims_handle)) {
   2087     return errors::InvalidArgument("block_shape must have known size.");
   2088   }
   2089 
   2090   const int64 num_block_dims = c->Value(num_block_dims_handle);
   2091 
   2092   TF_RETURN_IF_ERROR(
   2093       c->WithRankAtLeast(input_shape, num_block_dims + 1, &input_shape));
   2094 
   2095   TF_RETURN_IF_ERROR(
   2096       c->Merge(paddings_shape, c->Matrix(num_block_dims, 2), &paddings_shape));
   2097 
   2098   DimensionHandle batch_size = c->Dim(input_shape, 0);
   2099   std::vector<int64> block_shape_vec;
   2100   if (block_shape_t) {
   2101     block_shape_vec = GetFlatInt64(*block_shape_t);
   2102     for (int64 dim = 0; dim < num_block_dims; ++dim) {
   2103       const int64 block_shape_value = block_shape_vec[dim];
   2104       if (block_shape_value < 1) {
   2105         return errors::InvalidArgument("block_shape must be positive");
   2106       }
   2107       if (c->ValueKnown(batch_size)) {
   2108         TF_RETURN_IF_ERROR(
   2109             c->Multiply(batch_size, block_shape_value, &batch_size));
   2110       } else {
   2111         batch_size = c->UnknownDim();
   2112       }
   2113     }
   2114   } else if (num_block_dims > 0) {
   2115     batch_size = c->UnknownDim();
   2116   }
   2117 
   2118   std::vector<DimensionHandle> output_dims{batch_size};
   2119   output_dims.resize(num_block_dims + 1, c->UnknownDim());
   2120 
   2121   if (paddings_t) {
   2122     const std::vector<int64> paddings_vec = GetFlatInt64(*paddings_t);
   2123     for (int64 dim = 0; dim < num_block_dims; ++dim) {
   2124       const int64 pad_start = paddings_vec[dim * 2],
   2125                   pad_end = paddings_vec[dim * 2 + 1];
   2126       if (pad_start < 0 || pad_end < 0) {
   2127         return errors::InvalidArgument("paddings cannot be negative");
   2128       }
   2129       if (block_shape_t) {
   2130         DimensionHandle padded_size;
   2131         TF_RETURN_IF_ERROR(
   2132             c->Add(c->Dim(input_shape, dim + 1), pad_start, &padded_size));
   2133         TF_RETURN_IF_ERROR(c->Add(padded_size, pad_end, &padded_size));
   2134         TF_RETURN_IF_ERROR(c->Divide(padded_size, block_shape_vec[dim],
   2135                                      /*evenly_divisible=*/true,
   2136                                      &output_dims[dim + 1]));
   2137       }
   2138     }
   2139   }
   2140 
   2141   ShapeHandle remaining_input_shape;
   2142   TF_RETURN_IF_ERROR(
   2143       c->Subshape(input_shape, 1 + num_block_dims, &remaining_input_shape));
   2144 
   2145   ShapeHandle result;
   2146   TF_RETURN_IF_ERROR(c->Concatenate(c->MakeShape(output_dims),
   2147                                     remaining_input_shape, &result));
   2148   c->set_output(0, result);
   2149   return Status::OK();
   2150 }
   2151 
   2152 Status BatchToSpaceShapeHelper(InferenceContext* c, ShapeHandle input_shape,
   2153                                ShapeHandle block_shape_shape,
   2154                                const Tensor* block_shape_t,
   2155                                ShapeHandle crops_shape, const Tensor* crops_t) {
   2156   if (c->Rank(block_shape_shape) != 1) {
   2157     return errors::InvalidArgument("block_shape must have rank 1.");
   2158   }
   2159 
   2160   const DimensionHandle num_block_dims_handle = c->Dim(block_shape_shape, 0);
   2161   if (!c->ValueKnown(num_block_dims_handle)) {
   2162     return errors::InvalidArgument("block_shape must have known size.");
   2163   }
   2164 
   2165   const int64 num_block_dims = c->Value(num_block_dims_handle);
   2166 
   2167   TF_RETURN_IF_ERROR(
   2168       c->WithRankAtLeast(input_shape, num_block_dims + 1, &input_shape));
   2169 
   2170   TF_RETURN_IF_ERROR(
   2171       c->Merge(crops_shape, c->Matrix(num_block_dims, 2), &crops_shape));
   2172 
   2173   DimensionHandle batch_size = c->Dim(input_shape, 0);
   2174   std::vector<int64> block_shape_vec;
   2175   if (block_shape_t) {
   2176     block_shape_vec = GetFlatInt64(*block_shape_t);
   2177     for (int64 dim = 0; dim < num_block_dims; ++dim) {
   2178       const int64 block_shape_value = block_shape_vec[dim];
   2179       if (block_shape_value < 1) {
   2180         return errors::InvalidArgument("block_shape must be positive");
   2181       }
   2182       if (c->ValueKnown(batch_size)) {
   2183         TF_RETURN_IF_ERROR(c->Divide(batch_size, block_shape_value,
   2184                                      /*evenly_divisible=*/true, &batch_size));
   2185       } else {
   2186         batch_size = c->UnknownDim();
   2187       }
   2188     }
   2189   } else if (num_block_dims > 0) {
   2190     batch_size = c->UnknownDim();
   2191   }
   2192 
   2193   std::vector<DimensionHandle> output_dims{batch_size};
   2194   output_dims.resize(num_block_dims + 1, c->UnknownDim());
   2195 
   2196   if (crops_t) {
   2197     const std::vector<int64> crops_vec = GetFlatInt64(*crops_t);
   2198     for (int64 dim = 0; dim < num_block_dims; ++dim) {
   2199       const int64 crop_start = crops_vec[dim * 2],
   2200                   crop_end = crops_vec[dim * 2 + 1];
   2201       if (crop_start < 0 || crop_end < 0) {
   2202         return errors::InvalidArgument("crops cannot be negative");
   2203       }
   2204       if (block_shape_t) {
   2205         DimensionHandle cropped_size;
   2206         TF_RETURN_IF_ERROR(c->Multiply(c->Dim(input_shape, dim + 1),
   2207                                        block_shape_vec[dim], &cropped_size));
   2208         TF_RETURN_IF_ERROR(
   2209             c->Subtract(cropped_size, crop_start, &cropped_size));
   2210         TF_RETURN_IF_ERROR(
   2211             c->Subtract(cropped_size, crop_end, &output_dims[dim + 1]));
   2212       }
   2213     }
   2214   }
   2215 
   2216   ShapeHandle remaining_input_shape;
   2217   TF_RETURN_IF_ERROR(
   2218       c->Subshape(input_shape, 1 + num_block_dims, &remaining_input_shape));
   2219 
   2220   ShapeHandle result;
   2221   TF_RETURN_IF_ERROR(c->Concatenate(c->MakeShape(output_dims),
   2222                                     remaining_input_shape, &result));
   2223   c->set_output(0, result);
   2224   return Status::OK();
   2225 }
   2226 
   2227 }  // namespace
   2228 
   2229 // --------------------------------------------------------------------------
   2230 REGISTER_OP("SpaceToBatchND")
   2231     .Input("input: T")
   2232     .Input("block_shape: Tblock_shape")
   2233     .Input("paddings: Tpaddings")
   2234     .Output("output: T")
   2235     .Attr("T: type")
   2236     .Attr("Tblock_shape: {int32, int64} = DT_INT32")
   2237     .Attr("Tpaddings: {int32, int64} = DT_INT32")
   2238     .SetShapeFn([](InferenceContext* c) {
   2239       return SpaceToBatchShapeHelper(c, c->input(0), c->input(1),
   2240                                      c->input_tensor(1), c->input(2),
   2241                                      c->input_tensor(2));
   2242     });
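
// Illustrative example (constant inputs): an NHWC-like input of shape
// [1, 4, 4, 1] with block_shape=[2, 2] and zero paddings gives a batch of
// 1 * 2 * 2 = 4 and spatial dims of (4 + 0) / 2 = 2, i.e. output [4, 2, 2, 1].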
   2243 
   2244 // --------------------------------------------------------------------------
   2245 REGISTER_OP("SpaceToBatch")
   2246     .Input("input: T")
   2247     .Input("paddings: Tpaddings")
   2248     .Output("output: T")
   2249     .Attr("T: type")
   2250     .Attr("Tpaddings: {int32, int64} = DT_INT32")
   2251     .Attr("block_size: int >= 2")
   2252     .SetShapeFn([](InferenceContext* c) {
   2253       ShapeHandle input_shape;
   2254       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
   2255 
   2256       int32 block_size;
   2257       TF_RETURN_IF_ERROR(c->GetAttr("block_size", &block_size));
   2258 
   2259       Tensor block_shape(tensorflow::DT_INT64, TensorShape({2}));
   2260       auto block_shape_vec = block_shape.vec<int64>();
   2261       block_shape_vec(0) = block_size;
   2262       block_shape_vec(1) = block_size;
   2263 
   2264       return SpaceToBatchShapeHelper(c, input_shape, c->MakeShape({2}),
   2265                                      &block_shape, c->input(1),
   2266                                      c->input_tensor(1));
   2267     });
   2268 
   2269 // --------------------------------------------------------------------------
   2270 REGISTER_OP("BatchToSpaceND")
   2271     .Input("input: T")
   2272     .Input("block_shape: Tblock_shape")
   2273     .Input("crops: Tcrops")
   2274     .Output("output: T")
   2275     .Attr("T: type")
   2276     .Attr("Tblock_shape: {int32, int64} = DT_INT32")
   2277     .Attr("Tcrops: {int32, int64} = DT_INT32")
   2278     .SetShapeFn([](InferenceContext* c) {
   2279       return BatchToSpaceShapeHelper(c, c->input(0), c->input(1),
   2280                                      c->input_tensor(1), c->input(2),
   2281                                      c->input_tensor(2));
   2282     });
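
// Illustrative example (the inverse of the SpaceToBatchND example above): an
// input of shape [4, 2, 2, 1] with block_shape=[2, 2] and zero crops gives a
// batch of 4 / (2 * 2) = 1 and spatial dims of 2 * 2 - 0 = 4: [1, 4, 4, 1].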
   2283 
   2284 // --------------------------------------------------------------------------
   2285 REGISTER_OP("BatchToSpace")
   2286     .Input("input: T")
   2287     .Input("crops: Tidx")
   2288     .Output("output: T")
   2289     .Attr("T: type")
   2290     .Attr("block_size: int >= 2")
   2291     .Attr("Tidx: {int32, int64} = DT_INT32")
   2292     .SetShapeFn([](InferenceContext* c) {
   2293       ShapeHandle input_shape;
   2294       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
   2295 
   2296       int32 block_size;
   2297       TF_RETURN_IF_ERROR(c->GetAttr("block_size", &block_size));
   2298 
   2299       Tensor block_shape(tensorflow::DT_INT64, TensorShape({2}));
   2300       auto block_shape_vec = block_shape.vec<int64>();
   2301       block_shape_vec(0) = block_size;
   2302       block_shape_vec(1) = block_size;
   2303 
   2304       return BatchToSpaceShapeHelper(c, input_shape, c->MakeShape({2}),
   2305                                      &block_shape, c->input(1),
   2306                                      c->input_tensor(1));
   2307     });
   2308 
   2309 // --------------------------------------------------------------------------
   2310 REGISTER_OP("SpaceToDepth")
   2311     .Input("input: T")
   2312     .Output("output: T")
   2313     .Attr("T: type")
   2314     .Attr("block_size: int >= 2")
   2315     .Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'")
   2316     // TODO(pauldonnelly): Implement GPU kernels for NCHW_VECT_C.
   2317     .SetShapeFn([](InferenceContext* c) {
   2318       string data_format_str;
   2319       TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
   2320       TensorFormat data_format;
   2321       FormatFromString(data_format_str, &data_format);
   2322 
   2323       constexpr int num_spatial_dims = 2;
   2324       const int dims =
   2325           GetTensorDimsFromSpatialDims(num_spatial_dims, data_format);
   2326       ShapeHandle input;
   2327       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), dims, &input));
   2328 
   2329       int32 block_size;
   2330       TF_RETURN_IF_ERROR(c->GetAttr("block_size", &block_size));
   2331 
   2332       DimensionHandle batch_size =
   2333           c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'N'));
   2334       DimensionHandle input_height =
   2335           c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'H'));
   2336       DimensionHandle input_width =
   2337           c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'W'));
   2338       DimensionHandle input_depth =
   2339           c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'C'));
   2340 
   2341       DimensionHandle output_height;
   2342       DimensionHandle output_width;
   2343       DimensionHandle output_depth;
    2344       // Will return an error if the input height or width is not evenly
    2345       // divisible by block_size.
   2345       TF_RETURN_IF_ERROR(c->Divide(input_height, block_size,
   2346                                    true /* evenly_divisible */,
   2347                                    &output_height));
   2348       TF_RETURN_IF_ERROR(c->Divide(input_width, block_size,
   2349                                    true /* evenly_divisible */, &output_width));
   2350 
   2351       TF_RETURN_IF_ERROR(
   2352           c->Multiply(input_depth, block_size * block_size, &output_depth));
   2353 
   2354       ShapeHandle output_shape;
   2355       TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size,
   2356                                              {output_height, output_width},
   2357                                              output_depth, &output_shape, c));
   2358 
   2359       c->set_output(0, output_shape);
   2360       return Status::OK();
   2361     });
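
// Illustrative example: an NHWC input of shape [1, 4, 4, 3] with block_size=2
// maps to [1, 4/2, 4/2, 3*2*2] = [1, 2, 2, 12]; a height or width that is not
// divisible by block_size is rejected by the Divide calls above.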
   2362 
   2363 // --------------------------------------------------------------------------
   2364 REGISTER_OP("DepthToSpace")
   2365     .Input("input: T")
   2366     .Output("output: T")
   2367     .Attr("T: type")
   2368     .Attr("block_size: int >= 2")
   2369     .Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'")
   2370     // TODO(pauldonnelly): Implement GPU kernels for NCHW and NCHW_VECT_C.
   2371     .SetShapeFn([](InferenceContext* c) {
   2372       string data_format_str;
   2373       TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
   2374       TensorFormat data_format;
   2375       FormatFromString(data_format_str, &data_format);
   2376 
   2377       constexpr int num_spatial_dims = 2;
   2378       const int dims =
   2379           GetTensorDimsFromSpatialDims(num_spatial_dims, data_format);
   2380 
   2381       ShapeHandle input;
   2382       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), dims, &input));
   2383 
   2384       int32 block_size;
   2385       TF_RETURN_IF_ERROR(c->GetAttr("block_size", &block_size));
   2386 
   2387       DimensionHandle batch_size =
   2388           c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'N'));
   2389       DimensionHandle input_height =
   2390           c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'H'));
   2391       DimensionHandle input_width =
   2392           c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'W'));
   2393       DimensionHandle input_depth =
   2394           c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'C'));
   2395 
   2396       DimensionHandle output_height;
   2397       DimensionHandle output_width;
   2398       DimensionHandle output_depth;
   2399       TF_RETURN_IF_ERROR(c->Multiply(input_height, block_size, &output_height));
   2400       TF_RETURN_IF_ERROR(c->Multiply(input_width, block_size, &output_width));
   2401 
    2402       // Will return an error if input_depth is not evenly divisible by
    2403       // block_size * block_size.
   2403       TF_RETURN_IF_ERROR(c->Divide(input_depth, block_size * block_size,
   2404                                    true /* evenly_divisible */, &output_depth));
   2405 
   2406       ShapeHandle output_shape;
   2407       TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size,
   2408                                              {output_height, output_width},
   2409                                              output_depth, &output_shape, c));
   2410 
   2411       c->set_output(0, output_shape);
   2412       return Status::OK();
   2413     });
   2414 
   2415 // --------------------------------------------------------------------------
   2416 
   2417 REGISTER_OP("ExtractImagePatches")
   2418     .Input("images: T")
   2419     .Output("patches: T")
   2420     .Attr("ksizes: list(int) >= 4")
   2421     .Attr("strides: list(int) >= 4")
   2422     .Attr("rates: list(int) >= 4")
   2423     .Attr("T: realnumbertype")
   2424     .Attr(GetPaddingAttrString())
   2425     .SetShapeFn([](InferenceContext* c) {
   2426       ShapeHandle input_shape;
   2427       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
   2428 
   2429       std::vector<int32> ksizes;
   2430       TF_RETURN_IF_ERROR(c->GetAttr("ksizes", &ksizes));
   2431       if (ksizes.size() != 4) {
   2432         return errors::InvalidArgument(
   2433             "ExtractImagePatches requires the ksizes attribute to contain 4 "
   2434             "values, but got: ",
   2435             ksizes.size());
   2436       }
   2437 
   2438       std::vector<int32> strides;
   2439       TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
   2440       if (strides.size() != 4) {
   2441         return errors::InvalidArgument(
    2442             "ExtractImagePatches requires the strides attribute to contain 4 "
   2443             "values, but got: ",
   2444             strides.size());
   2445       }
   2446 
   2447       std::vector<int32> rates;
   2448       TF_RETURN_IF_ERROR(c->GetAttr("rates", &rates));
   2449       if (rates.size() != 4) {
   2450         return errors::InvalidArgument(
   2451             "ExtractImagePatches requires the rates attribute to contain 4 "
   2452             "values, but got: ",
   2453             rates.size());
   2454       }
   2455 
   2456       int32 ksize_rows = ksizes[1];
   2457       int32 ksize_cols = ksizes[2];
   2458 
   2459       int32 stride_rows = strides[1];
   2460       int32 stride_cols = strides[2];
   2461 
   2462       int32 rate_rows = rates[1];
   2463       int32 rate_cols = rates[2];
   2464 
   2465       int32 ksize_rows_eff = ksize_rows + (ksize_rows - 1) * (rate_rows - 1);
   2466       int32 ksize_cols_eff = ksize_cols + (ksize_cols - 1) * (rate_cols - 1);
   2467 
   2468       DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
   2469       DimensionHandle in_rows_dim = c->Dim(input_shape, 1);
   2470       DimensionHandle in_cols_dim = c->Dim(input_shape, 2);
   2471       DimensionHandle output_depth_dim;
   2472       TF_RETURN_IF_ERROR(c->Multiply(
   2473           c->Dim(input_shape, 3), ksize_rows * ksize_cols, &output_depth_dim));
   2474 
   2475       if (!c->ValueKnown(in_rows_dim) || !c->ValueKnown(in_cols_dim)) {
   2476         ShapeHandle output_shape =
   2477             c->MakeShape({batch_size_dim, InferenceContext::kUnknownDim,
   2478                           InferenceContext::kUnknownDim, output_depth_dim});
   2479         c->set_output(0, output_shape);
   2480         return Status::OK();
   2481       }
   2482       auto in_rows = c->Value(in_rows_dim);
   2483       auto in_cols = c->Value(in_cols_dim);
   2484 
   2485       Padding padding;
   2486       TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
   2487 
   2488       int64 output_rows, output_cols;
   2489       int64 padding_before, padding_after;
   2490       TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
   2491           in_rows, ksize_rows_eff, stride_rows, padding, &output_rows,
   2492           &padding_before, &padding_after));
   2493       TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
   2494           in_cols, ksize_cols_eff, stride_cols, padding, &output_cols,
   2495           &padding_before, &padding_after));
   2496       ShapeHandle output_shape = c->MakeShape(
   2497           {batch_size_dim, output_rows, output_cols, output_depth_dim});
   2498       c->set_output(0, output_shape);
   2499       return Status::OK();
   2500     });
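
// Illustrative example: with ksizes=[1, 3, 3, 1] and rates=[1, 2, 2, 1], the
// effective patch size is 3 + (3 - 1) * (2 - 1) = 5 along each spatial
// dimension, and the output depth is the input depth times 3 * 3 = 9.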
   2501 
   2502 // --------------------------------------------------------------------------
   2503 
   2504 // To enable rates, uncomment all lines commented below and use ksize_*_eff
   2505 // as the second parameter of all GetWindowedOutputSizeVerbose calls instead
   2506 // of ksize_*.
   2507 REGISTER_OP("ExtractVolumePatches")
   2508     .Input("input: T")
   2509     .Output("patches: T")
   2510     .Attr("ksizes: list(int) >= 5")
   2511     .Attr("strides: list(int) >= 5")
   2512     /* .Attr("rates: list(int) >= 5") */
   2513     .Attr("T: realnumbertype")
   2514     .Attr(GetPaddingAttrString())
   2515     .SetShapeFn([](InferenceContext* c) {
   2516       ShapeHandle input_shape;
   2517       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape));
   2518 
   2519       std::vector<int32> ksizes;
   2520       TF_RETURN_IF_ERROR(c->GetAttr("ksizes", &ksizes));
   2521       if (ksizes.size() != 5) {
   2522         return errors::InvalidArgument(
   2523             "ExtractVolumePatches requires the ksizes attribute to contain 5 "
   2524             "values, but got: ",
   2525             ksizes.size());
   2526       }
   2527 
   2528       std::vector<int32> strides;
   2529       TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
   2530       if (strides.size() != 5) {
   2531         return errors::InvalidArgument(
    2532             "ExtractVolumePatches requires the strides attribute to contain 5 "
   2533             "values, but got: ",
   2534             strides.size());
   2535       }
   2536 
   2537       /*
   2538       // TODO(hsgkim): Enable rates.
   2539       // See extract_volume_patches_op.cc for why rates are disabled now.
   2540 
   2541       std::vector<int32> rates;
   2542       TF_RETURN_IF_ERROR(c->GetAttr("rates", &rates));
   2543       if (rates.size() != 5) {
   2544         return errors::InvalidArgument(
   2545             "ExtractVolumePatches requires the rates attribute to contain 5 "
   2546             "values, but got: ",
   2547             rates.size());
   2548       }
   2549       */
   2550 
   2551       int32 ksize_planes = ksizes[1];
   2552       int32 ksize_rows = ksizes[2];
   2553       int32 ksize_cols = ksizes[3];
   2554 
   2555       int32 stride_planes = strides[1];
   2556       int32 stride_rows = strides[2];
   2557       int32 stride_cols = strides[3];
   2558 
   2559       /*
   2560       int32 rate_planes = rates[1];
   2561       int32 rate_rows = rates[2];
   2562       int32 rate_cols = rates[3];
   2563 
   2564       int32 ksize_planes_eff = ksize_planes +
   2565                                (ksize_planes - 1) * (rate_planes - 1);
   2566       int32 ksize_rows_eff = ksize_rows + (ksize_rows - 1) * (rate_rows - 1);
   2567       int32 ksize_cols_eff = ksize_cols + (ksize_cols - 1) * (rate_cols - 1);
   2568       */
   2569 
   2570       DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
   2571       DimensionHandle in_planes_dim = c->Dim(input_shape, 1);
   2572       DimensionHandle in_rows_dim = c->Dim(input_shape, 2);
   2573       DimensionHandle in_cols_dim = c->Dim(input_shape, 3);
   2574       DimensionHandle output_depth_dim;
   2575       TF_RETURN_IF_ERROR(c->Multiply(c->Dim(input_shape, 4),
   2576                                      ksize_planes * ksize_rows * ksize_cols,
   2577                                      &output_depth_dim));
   2578 
   2579       if (!c->ValueKnown(in_planes_dim) || !c->ValueKnown(in_rows_dim) ||
   2580           !c->ValueKnown(in_cols_dim)) {
   2581         ShapeHandle output_shape =
    2582             c->MakeShape({batch_size_dim, InferenceContext::kUnknownDim,
    2583                           InferenceContext::kUnknownDim,
    2584                           InferenceContext::kUnknownDim, output_depth_dim});
   2584         c->set_output(0, output_shape);
   2585         return Status::OK();
   2586       }
   2587       auto in_planes = c->Value(in_planes_dim);
   2588       auto in_rows = c->Value(in_rows_dim);
   2589       auto in_cols = c->Value(in_cols_dim);
   2590 
   2591       Padding padding;
   2592       TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
   2593 
   2594       int64 output_planes, output_rows, output_cols;
   2595       int64 padding_before, padding_after;
   2596       TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
   2597           in_planes, ksize_planes, stride_planes, padding, &output_planes,
   2598           &padding_before, &padding_after));
   2599       TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
   2600           in_rows, ksize_rows, stride_rows, padding, &output_rows,
   2601           &padding_before, &padding_after));
   2602       TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
   2603           in_cols, ksize_cols, stride_cols, padding, &output_cols,
   2604           &padding_before, &padding_after));
   2605       ShapeHandle output_shape =
   2606           c->MakeShape({batch_size_dim, output_planes, output_rows, output_cols,
   2607                         output_depth_dim});
   2608       c->set_output(0, output_shape);
   2609       return Status::OK();
   2610     });
   2611 
   2612 // --------------------------------------------------------------------------
   2613 
   2614 REGISTER_OP("Bitcast")
   2615     .Input("input: T")
   2616     .Output("output: type")
   2617     // All supported dtypes are listed here to include qint16, quint16, uint32,
   2618     // and uint64.
   2619     .Attr(
   2620         "T: {bfloat16, half, float, double, int64, int32, uint8, uint16, "
   2621         "uint32, uint64, int8, int16, complex64, complex128, qint8, quint8, "
   2622         "qint16, quint16, qint32}")
   2623     .Attr(
   2624         "type: {bfloat16, half, float, double, int64, int32, uint8, uint16, "
   2625         "uint32, uint64, int8, int16, complex64, complex128, qint8, quint8, "
   2626         "qint16, quint16, qint32}")
   2627     .SetShapeFn([](InferenceContext* c) {
   2628       ShapeHandle input = c->input(0);
   2629       if (!c->RankKnown(input)) {
   2630         // Input shape unknown.
   2631         return shape_inference::UnknownShape(c);
   2632       }
   2633 
   2634       // Find the size of the input and output data types.
   2635       DataType input_type;
   2636       DataType output_type;
   2637       TF_RETURN_IF_ERROR(c->GetAttr("T", &input_type));
   2638       TF_RETURN_IF_ERROR(c->GetAttr("type", &output_type));
   2639       const int input_type_size = DataTypeSize(input_type);
   2640       const int output_type_size = DataTypeSize(output_type);
   2641 
   2642       if (input_type_size == 0 || output_type_size == 0) {
   2643         return errors::InvalidArgument("Cannot bitcast types ",
   2644                                        DataTypeString(input_type), " to ",
   2645                                        DataTypeString(output_type),
   2646                                        " because "
   2647                                        "one of the type sizes is zero.");
   2648       }
   2649 
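               // The output shape depends on the relative element sizes of T and type.
               // For example, bitcasting a float tensor of shape [..., 2] to complex64
               // drops the trailing dimension of size 2, while bitcasting int64 to
               // int32 appends a trailing dimension of size 2.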
   2650       ShapeHandle new_shape;
   2651       if (input_type_size == output_type_size) {
   2652         // No change in size.
   2653         new_shape = input;
   2654       } else if (input_type_size < output_type_size) {
   2655         TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, 1, &new_shape));
   2656 
   2657         int64 divisor_val = output_type_size / input_type_size;
   2658         DimensionHandle last_dim = c->Dim(new_shape, -1);
   2659         if (!c->ValueKnown(last_dim) || c->Value(last_dim) == divisor_val) {
   2660           TF_RETURN_IF_ERROR(c->Subshape(new_shape, 0, -1, &new_shape));
   2661         } else {
   2662           return errors::InvalidArgument("Cannot bitcast due to shape. ",
   2663                                          c->Value(last_dim), " does not match ",
   2664                                          divisor_val);
   2665         }
   2666       } else {
   2667         // Input type size is larger than output type size.
   2668         int64 divisor_val = input_type_size / output_type_size;
   2669         ShapeHandle extension = c->Vector(divisor_val);
   2670         TF_RETURN_IF_ERROR(c->Concatenate(input, extension, &new_shape));
   2671       }
   2672 
   2673       c->set_output(0, new_shape);
   2674       return Status::OK();
   2675     });
   2676 
   2677 REGISTER_OP("OneHot")
   2678     .Input("indices: TI")
   2679     .Input("depth: int32")
   2680     .Input("on_value: T")
   2681     .Input("off_value: T")
   2682     .Attr("axis: int = -1")
   2683     .Output("output: T")
   2684     .Attr("T: type")
   2685     .Attr("TI: {uint8, int32, int64} = DT_INT64")
   2686     .SetShapeFn([](InferenceContext* c) {
   2687       int32 axis;
   2688       TF_RETURN_IF_ERROR(c->GetAttr("axis", &axis));
   2689       if (axis < -1) return errors::InvalidArgument("axis must be >= -1");
   2690 
   2691       DimensionHandle depth;
   2692       TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(1, &depth));
   2693 
   2694       ShapeHandle indices = c->input(0);
   2695       if (!c->RankKnown(indices)) return shape_inference::UnknownShape(c);
   2696 
   2697       int32 new_rank = c->Rank(indices) + 1;
    2698       // We need to add new_rank to axis when axis is -1 because C++ returns
    2699       // negative values from % if the dividend is negative.
   2700       int32 depth_index = (axis + new_rank) % new_rank;
   2701       // Out shape is indices[0:depth_index] + [depth] + indices[depth_index:].
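               // For example, indices of shape [batch, features] with axis = -1 yield
               // an output of shape [batch, features, depth].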
   2702       ShapeHandle front;
   2703       ShapeHandle back;
   2704       ShapeHandle out;
   2705       TF_RETURN_IF_ERROR(c->Subshape(indices, 0, depth_index, &front));
   2706       TF_RETURN_IF_ERROR(c->Subshape(indices, depth_index, &back));
   2707       TF_RETURN_IF_ERROR(c->Concatenate(front, c->Vector(depth), &front));
   2708       TF_RETURN_IF_ERROR(c->Concatenate(front, back, &out));
   2709       c->set_output(0, out);
   2710       return Status::OK();
   2711     });
   2712 
   2713 // EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET.
   2714 REGISTER_OP("QuantizeAndDequantize")
   2715     .Input("input: T")
   2716     .Attr("signed_input: bool = true")
   2717     .Attr("num_bits: int = 8")
   2718     .Attr("range_given: bool = false")
   2719     .Attr("input_min: float = 0")
   2720     .Attr("input_max: float = 0")
   2721     .Output("output: T")
   2722     .Attr("T: {bfloat16, half, float, double}")
   2723     .SetShapeFn(shape_inference::UnchangedShape)
   2724     .Deprecated(22, "Replaced by QuantizeAndDequantizeV2");
   2725 
   2726 // TODO(suharshs): Deprecate QuantizeAndDequantizeV2.
   2727 REGISTER_OP("QuantizeAndDequantizeV2")
   2728     .Input("input: T")
   2729     .Input("input_min: T")
   2730     .Input("input_max: T")
   2731     .Attr("signed_input: bool = true")
   2732     .Attr("num_bits: int = 8")
   2733     .Attr("range_given: bool = false")
   2734     .Output("output: T")
   2735     .Attr("T: {bfloat16, half, float, double}")
   2736     .Attr(
   2737         "round_mode: {'HALF_TO_EVEN', 'HALF_UP'} = "
   2738         "'HALF_TO_EVEN'")
   2739     .SetShapeFn([](InferenceContext* c) {
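               // input_min and input_max must be scalars; the output has the same
               // shape as the input.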
   2740       ShapeHandle unused;
   2741       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
   2742       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
   2743       c->set_output(0, c->input(0));
   2744       return Status::OK();
   2745     });
   2746 
   2747 REGISTER_OP("QuantizeAndDequantizeV3")
   2748     .Input("input: T")
   2749     .Input("input_min: T")
   2750     .Input("input_max: T")
   2751     .Input("num_bits: int32")
   2752     .Attr("signed_input: bool = true")
   2753     .Attr("range_given: bool = true")
   2754     .Output("output: T")
   2755     .Attr("T: {bfloat16, half, float, double}")
   2756     .SetShapeFn([](InferenceContext* c) {
   2757       ShapeHandle unused;
   2758       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
   2759       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
   2760       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
   2761       c->set_output(0, c->input(0));
   2762       return Status::OK();
   2763     });
   2764 
   2765 REGISTER_OP("QuantizeV2")
   2766     .Input("input: float")
   2767     .Input("min_range: float")
   2768     .Input("max_range: float")
   2769     .Output("output: T")
   2770     .Output("output_min: float")
   2771     .Output("output_max: float")
   2772     .Attr("T: quantizedtype")
   2773     .Attr("mode: {'MIN_COMBINED', 'MIN_FIRST', 'SCALED'} = 'MIN_COMBINED'")
   2774     .Attr(
   2775         "round_mode: {'HALF_AWAY_FROM_ZERO', 'HALF_TO_EVEN'} = "
   2776         "'HALF_AWAY_FROM_ZERO'")
   2777     .SetShapeFn([](InferenceContext* c) {
   2778       TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
   2779       ShapeHandle unused;
   2780       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
   2781       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
   2782       c->set_output(1, c->Scalar());
   2783       c->set_output(2, c->Scalar());
   2784       return Status::OK();
   2785     });
   2786 
   2787 REGISTER_OP("Dequantize")
   2788     .Input("input: T")
   2789     .Input("min_range: float")
   2790     .Input("max_range: float")
   2791     .Output("output: float")
   2792     .Attr("T: quantizedtype")
   2793     .Attr("mode: {'MIN_COMBINED', 'MIN_FIRST', 'SCALED'} = 'MIN_COMBINED'")
   2794     .SetShapeFn([](InferenceContext* c) {
   2795       TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
   2796       ShapeHandle unused;
   2797       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
   2798       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
   2799       return Status::OK();
   2800     });
   2801 
   2802 REGISTER_OP("QuantizedConcat")
   2803     .Input("concat_dim: int32")
   2804     .Input("values: N * T")
   2805     .Input("input_mins: N * float32")
   2806     .Input("input_maxes: N * float32")
   2807     .Output("output: T")
   2808     .Output("output_min: float")
   2809     .Output("output_max: float")
   2810     .Attr("N: int >= 2")
   2811     .Attr("T: type")
   2812     .SetShapeFn([](InferenceContext* c) {
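               // Inputs are concat_dim followed by N values, N input_mins, and N
               // input_maxes, so N is (num_inputs() - 1) / 3; the mins and maxes must
               // all be scalars.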
   2813       const int n = (c->num_inputs() - 1) / 3;
   2814       TF_RETURN_IF_ERROR(shape_inference::ConcatShape(c, n));
   2815       ShapeHandle unused;
   2816       for (int i = n + 1; i < c->num_inputs(); ++i) {
   2817         TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 0, &unused));
   2818       }
   2819       c->set_output(1, c->Scalar());
   2820       c->set_output(2, c->Scalar());
   2821       return Status::OK();
   2822     });
   2823 
   2824 REGISTER_OP("QuantizedReshape")
   2825     .Input("tensor: T")
   2826     .Input("shape: Tshape")
   2827     .Input("input_min: float")
   2828     .Input("input_max: float")
   2829     .Output("output: T")
   2830     .Output("output_min: float")
   2831     .Output("output_max: float")
   2832     .Attr("T: type")
   2833     .Attr("Tshape: {int32, int64} = DT_INT32")
   2834     .SetShapeFn([](InferenceContext* c) {
   2835       TF_RETURN_IF_ERROR(SetOutputShapeForReshape(c));
   2836       ShapeHandle unused;
   2837       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
   2838       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
   2839       c->set_output(1, c->Scalar());
   2840       c->set_output(2, c->Scalar());
   2841       return Status::OK();
   2842     });
   2843 
   2844 REGISTER_OP("QuantizedInstanceNorm")
   2845     .Input("x: T")
   2846     .Input("x_min: float")
   2847     .Input("x_max: float")
   2848     .Output("y: T")
   2849     .Output("y_min: float")
   2850     .Output("y_max: float")
   2851     .Attr("T: quantizedtype")
   2852     .Attr("output_range_given: bool = false")
   2853     .Attr("given_y_min: float = 0")
   2854     .Attr("given_y_max: float = 0")
   2855     .Attr("variance_epsilon: float = 1e-5")
   2856     .Attr("min_separation: float = 1e-3")
   2857     .SetShapeFn([](shape_inference::InferenceContext* c) {
   2858       shape_inference::ShapeHandle unused;
   2859       // x should be a rank 4 tensor.
   2860       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &unused));
   2861       // Assert x_min and x_max are scalars (rank 0).
   2862       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
   2863       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
   2864       // y has the same shape as x.
   2865       TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
   2866       // y_min and y_max are scalars.
   2867       c->set_output(1, c->Scalar());
   2868       c->set_output(2, c->Scalar());
   2869       return Status::OK();
   2870     });
   2871 
   2872 namespace {
   2873 
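         // Shared shape function for ScatterNd-style ops: the outer dimensions of
         // indices and updates must match, and the inner dimensions of updates must
         // match the trailing dimensions of the output. For example, indices of shape
         // [8, 2], updates of shape [8, 4], and an output shape of [16, 16, 4] are
         // compatible: the outer [8] prefixes match and updates' inner [4] matches
         // output_shape[2:].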
   2874 Status ScatterNdShapeHelper(InferenceContext* c, ShapeHandle indices_shape,
   2875                             ShapeHandle updates_shape,
   2876                             ShapeHandle output_shape) {
   2877   if (c->Value(c->NumElements(output_shape)) == 0 &&
   2878       (c->Value(c->NumElements(indices_shape)) > 0 ||
   2879        c->Value(c->NumElements(updates_shape)) > 0)) {
   2880     return errors::InvalidArgument(
   2881         "Indices and updates specified for empty output shape");
   2882   }
   2883 
   2884   if (c->RankKnown(indices_shape) && c->RankKnown(updates_shape)) {
   2885     const int64 outer_dims = c->Rank(indices_shape) - 1;
   2886     const DimensionHandle ixdim = c->Dim(indices_shape, -1);
   2887 
   2888     // We can only do more validation if the last dimension of indices
   2889     // is a known value.
   2890     if (c->ValueKnown(ixdim)) {
   2891       int64 ix = c->Value(ixdim);
   2892       ShapeHandle unused;
   2893       ShapeHandle prefix_indices;
   2894       TF_RETURN_IF_ERROR(
   2895           c->Subshape(indices_shape, 0, outer_dims, &prefix_indices));
   2896       ShapeHandle prefix_updates;
   2897       TF_RETURN_IF_ERROR(
   2898           c->Subshape(updates_shape, 0, outer_dims, &prefix_updates));
   2899 
   2900       Status s = c->Merge(prefix_indices, prefix_updates, &unused);
   2901       if (!s.ok()) {
   2902         return errors::InvalidArgument(
   2903             "The outer ", outer_dims,
   2904             " dimensions of indices.shape=", c->DebugString(indices_shape),
   2905             " must match the outer ", outer_dims,
   2906             " dimensions of updates.shape=", c->DebugString(updates_shape),
   2907             ": ", s.error_message());
   2908       }
   2909 
   2910       ShapeHandle suffix_output;
   2911       TF_RETURN_IF_ERROR(c->Subshape(output_shape, ix, &suffix_output));
   2912       ShapeHandle suffix_updates;
   2913       TF_RETURN_IF_ERROR(
   2914           c->Subshape(updates_shape, outer_dims, &suffix_updates));
   2915       s = c->Merge(suffix_output, suffix_updates, &unused);
   2916       if (!s.ok()) {
   2917         return errors::InvalidArgument(
   2918             "The inner ", c->Rank(output_shape) - ix,
   2919             " dimensions of output.shape=", c->DebugString(output_shape),
   2920             " must match the inner ", c->Rank(updates_shape) - outer_dims,
   2921             " dimensions of updates.shape=", c->DebugString(updates_shape),
   2922             ": ", s.error_message());
   2923       }
   2924     }
   2925   }
   2926 
   2927   c->set_output(0, output_shape);
   2928   return Status::OK();
   2929 }
   2930 
   2931 Status ScatterNdShape(InferenceContext* c) {
   2932   ShapeHandle indices_shape;
   2933   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &indices_shape));
   2934   ShapeHandle updates_shape;
   2935   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &updates_shape));
   2936   ShapeHandle output_shape;
   2937   TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &output_shape));
   2938   return ScatterNdShapeHelper(c, indices_shape, updates_shape, output_shape);
   2939 }
   2940 
   2941 Status ScatterNdTensorShape(InferenceContext* c) {
   2942   ShapeHandle output_shape;
   2943   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &output_shape));
   2944   ShapeHandle indices_shape;
   2945   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices_shape));
   2946   ShapeHandle updates_shape;
   2947   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 1, &updates_shape));
   2948   return ScatterNdShapeHelper(c, indices_shape, updates_shape, output_shape);
   2949 }
   2950 
   2951 }  // namespace
   2952 
   2953 REGISTER_OP("UpperBound")
   2954     .Input("sorted_inputs: T")
   2955     .Input("values: T")
   2956     .Output("output: out_type")
   2957     .Attr("T: type")
   2958     .Attr("out_type: {int32, int64} = DT_INT32")
   2959     .SetShapeFn([](InferenceContext* c) {
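               // sorted_inputs and values are both rank-2; the output has the same
               // shape as values.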
   2960       ShapeHandle unused_shape;
   2961       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused_shape));
   2962       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &unused_shape));
   2963       c->set_output(0, c->input(1));
   2964       return Status::OK();
   2965     });
   2966 
   2967 REGISTER_OP("LowerBound")
   2968     .Input("sorted_inputs: T")
   2969     .Input("values: T")
   2970     .Output("output: out_type")
   2971     .Attr("T: type")
   2972     .Attr("out_type: {int32, int64} = DT_INT32")
   2973     .SetShapeFn([](InferenceContext* c) {
   2974       ShapeHandle unused_shape;
   2975       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused_shape));
   2976       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &unused_shape));
   2977       c->set_output(0, c->input(1));
   2978       return Status::OK();
   2979     });
   2980 
   2981 REGISTER_OP("ScatterNd")
   2982     .Input("indices: Tindices")
   2983     .Input("updates: T")
   2984     .Input("shape: Tindices")
   2985     .Output("output: T")
   2986     .Attr("T: type")
   2987     .Attr("Tindices: {int32, int64}")
   2988     .SetShapeFn(ScatterNdShape);
   2989 
   2990 REGISTER_OP("TensorScatterUpdate")
   2991     .Input("tensor: T")
   2992     .Input("indices: Tindices")
   2993     .Input("updates: T")
   2994     .Output("output: T")
   2995     .Attr("T: type")
   2996     .Attr("Tindices: {int32, int64}")
   2997     .SetShapeFn(ScatterNdTensorShape);
   2998 
   2999 REGISTER_OP("TensorScatterAdd")
   3000     .Input("tensor: T")
   3001     .Input("indices: Tindices")
   3002     .Input("updates: T")
   3003     .Output("output: T")
   3004     .Attr("T: type")
   3005     .Attr("Tindices: {int32, int64}")
   3006     .SetShapeFn(ScatterNdTensorShape);
   3007 
   3008 REGISTER_OP("TensorScatterSub")
   3009     .Input("tensor: T")
   3010     .Input("indices: Tindices")
   3011     .Input("updates: T")
   3012     .Output("output: T")
   3013     .Attr("T: type")
   3014     .Attr("Tindices: {int32, int64}")
   3015     .SetShapeFn(ScatterNdTensorShape);
   3016 
   3017 REGISTER_OP("ScatterNdNonAliasingAdd")
   3018     .Input("input: T")
   3019     .Input("indices: Tindices")
   3020     .Input("updates: T")
   3021     .Output("output: T")
   3022     .Attr("T: {numbertype, bool}")
   3023     .Attr("Tindices: {int32, int64}")
   3024     .SetShapeFn(shape_inference::ScatterNdUpdateShape);
   3025 
   3026 REGISTER_OP("FakeQuantWithMinMaxArgs")
   3027     .Attr("min: float = -6.0")
   3028     .Attr("max: float = 6.0")
   3029     .Attr("num_bits: int = 8")
   3030     .Attr("narrow_range: bool = false")
   3031     .Input("inputs: float")
   3032     .Output("outputs: float")
   3033     .SetShapeFn(shape_inference::UnchangedShape);
   3034 
   3035 REGISTER_OP("FakeQuantWithMinMaxArgsGradient")
   3036     .Attr("min: float = -6.0")
   3037     .Attr("max: float = 6.0")
   3038     .Attr("num_bits: int = 8")
   3039     .Attr("narrow_range: bool = false")
   3040     .Input("gradients: float")
   3041     .Input("inputs: float")
   3042     .Output("backprops: float")
   3043     .SetShapeFn(shape_inference::UnchangedShape);
   3044 
   3045 REGISTER_OP("FakeQuantWithMinMaxVars")
   3046     .Attr("num_bits: int = 8")
   3047     .Attr("narrow_range: bool = false")
   3048     .Input("inputs: float")
   3049     .Input("min: float")
   3050     .Input("max: float")
   3051     .Output("outputs: float")
   3052     .SetShapeFn([](InferenceContext* c) {
   3053       TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
   3054       ShapeHandle unused;
   3055       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
   3056       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
   3057       return Status::OK();
   3058     });
   3059 
   3060 REGISTER_OP("FakeQuantWithMinMaxVarsGradient")
   3061     .Attr("num_bits: int = 8")
   3062     .Attr("narrow_range: bool = false")
   3063     .Input("gradients: float")
   3064     .Input("inputs: float")
   3065     .Input("min: float")
   3066     .Input("max: float")
   3067     .Output("backprops_wrt_input: float")
   3068     .Output("backprop_wrt_min: float")
   3069     .Output("backprop_wrt_max: float")
   3070     .SetShapeFn([](InferenceContext* c) {
    3071       // gradients and inputs are the same size.
   3072       ShapeHandle inputs;
   3073       TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(1), &inputs));
   3074 
   3075       // min and max are scalars
   3076       ShapeHandle min_max;
   3077       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &min_max));
   3078       TF_RETURN_IF_ERROR(c->Merge(min_max, c->input(3), &min_max));
   3079 
   3080       c->set_output(0, inputs);
   3081       c->set_output(1, min_max);
   3082       c->set_output(2, min_max);
   3083       return Status::OK();
   3084     });
   3085 
   3086 REGISTER_OP("FakeQuantWithMinMaxVarsPerChannel")
   3087     .Attr("num_bits: int = 8")
   3088     .Attr("narrow_range: bool = false")
   3089     .Input("inputs: float")
   3090     .Input("min: float")
   3091     .Input("max: float")
   3092     .Output("outputs: float")
   3093     .SetShapeFn([](InferenceContext* c) {
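               // min and max are rank-1 tensors whose size must match the last
               // (channel) dimension of inputs.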
   3094       ShapeHandle input, min, max;
   3095       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input));
   3096       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &min));
   3097       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &max));
   3098 
   3099       DimensionHandle unused;
   3100       TF_RETURN_IF_ERROR(c->Merge(c->Dim(input, -1), c->Dim(min, 0), &unused));
   3101       TF_RETURN_IF_ERROR(c->Merge(c->Dim(input, -1), c->Dim(max, 0), &unused));
   3102       TF_RETURN_IF_ERROR(c->Merge(c->Dim(min, 0), c->Dim(max, 0), &unused));
   3103 
   3104       c->set_output(0, input);
   3105       return Status::OK();
   3106     });
   3107 
   3108 REGISTER_OP("FakeQuantWithMinMaxVarsPerChannelGradient")
   3109     .Attr("num_bits: int = 8")
   3110     .Attr("narrow_range: bool = false")
   3111     .Input("gradients: float")
   3112     .Input("inputs: float")
   3113     .Input("min: float")
   3114     .Input("max: float")
   3115     .Output("backprops_wrt_input: float")
   3116     .Output("backprop_wrt_min: float")
   3117     .Output("backprop_wrt_max: float")
   3118     .SetShapeFn([](InferenceContext* c) {
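               // gradients and inputs have the same shape with rank between 1 and 4;
               // min and max are vectors matching the last (channel) dimension.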
   3119       ShapeHandle inputs;
   3120       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &inputs));
   3121       TF_RETURN_IF_ERROR(c->WithRankAtMost(inputs, 4, &inputs));
   3122       TF_RETURN_IF_ERROR(c->Merge(inputs, c->input(1), &inputs));
   3123 
   3124       ShapeHandle last_dim = c->Vector(c->Dim(inputs, -1));
   3125 
   3126       ShapeHandle min_max;
   3127       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &min_max));
   3128       TF_RETURN_IF_ERROR(c->Merge(min_max, last_dim, &min_max));
   3129       TF_RETURN_IF_ERROR(c->Merge(c->input(3), min_max, &min_max));
   3130 
   3131       c->set_output(0, inputs);
   3132       c->set_output(1, min_max);
   3133       c->set_output(2, min_max);
   3134       return Status::OK();
   3135     });
   3136 
   3137 #ifdef INTEL_MKL
   3138 REGISTER_OP("_MklConcat")
   3139     .Input("concat_dim: int32")
   3140     .Input("values: N * T")
   3141     .Input("mkl_concat_dim: uint8")
   3142     .Input("mkl_values: N * uint8")
   3143     .Output("output: T")
   3144     .Output("mkl_output: uint8")
   3145     .Attr("N: int >= 2")
   3146     .Attr("T: type")
   3147     .SetShapeFn([](InferenceContext* c) {
   3148       return shape_inference::ConcatShape(c, c->num_inputs() - 3);
   3149     })
   3150     .Doc(R"doc(
   3151 MKL version of Concat operator. Uses MKL DNN APIs to perform concatenation.
   3152 
    3153 NOTE: Do not invoke this operator directly in Python. The graph rewrite pass
    3154 is expected to invoke these operators.
   3155 )doc");
   3156 #endif
   3157 
   3158 // Deprecated op registrations:
   3159 
   3160 // The following can be deleted after 10mar2017.
   3161 REGISTER_OP("BatchMatrixDiag")
   3162     .Input("diagonal: T")
   3163     .Output("output: T")
   3164     .Attr("T: type")
   3165     .Deprecated(14, "Use MatrixDiag")
   3166     .SetShapeFn(shape_inference::UnknownShape);
   3167 REGISTER_OP("BatchMatrixSetDiag")
   3168     .Input("input: T")
   3169     .Input("diagonal: T")
   3170     .Output("output: T")
   3171     .Attr("T: type")
   3172     .Deprecated(14, "Use MatrixSetDiag")
   3173     .SetShapeFn(shape_inference::UnknownShape);
   3174 REGISTER_OP("BatchMatrixDiagPart")
   3175     .Input("input: T")
   3176     .Output("diagonal: T")
   3177     .Attr("T: type")
   3178     .Deprecated(14, "Use MatrixDiagPart")
   3179     .SetShapeFn(shape_inference::UnknownShape);
   3180 REGISTER_OP("BatchMatrixBandPart")
   3181     .Input("input: T")
   3182     .Input("num_lower: int64")
   3183     .Input("num_upper: int64")
   3184     .Output("band: T")
   3185     .Attr("T: type")
   3186     .Deprecated(14, "Use MatrixBandPart")
   3187     .SetShapeFn(shape_inference::UnknownShape);
   3188 
   3189 }  // namespace tensorflow
   3190