/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

namespace tensorflow {

using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;

REGISTER_OP("AddN")
    .Input("inputs: N * T")
    .Output("sum: T")
    .Attr("N: int >= 1")
    .Attr("T: {numbertype, variant}")
    .SetIsCommutative()
    .SetIsAggregate()
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle cur = c->input(c->num_inputs() - 1);
      for (int i = c->num_inputs() - 2; i >= 0; --i) {
        TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur),
                                        "From merging shape ", i,
                                        " with other shapes.");
      }
      c->set_output(0, cur);
      return Status::OK();
    });
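
// Shape-inference sketch for AddN (illustrative values, not from this file):
// merging the inputs pairwise fills in any dimension known by at least one
// input, e.g.
//   inputs[0]: [2, ?], inputs[1]: [?, 3]  =>  sum: [2, 3]
// while conflicting known dimensions (e.g. [2] vs. [3]) produce an error.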

// --------------------------------------------------------------------------

// Note that the following operator is just a placeholder and has no
// associated kernel. The code in accumulate_n_optimizer.cc replaces
// this placeholder with a graph of operators that do have kernels.
// The Python code that generates instances of this op is currently in
// contrib/framework/python/ops/accumulate_n_v2.py
REGISTER_OP("AccumulateNV2")
    .Input("inputs: N * T")
    .Output("sum: T")
    .Attr("N: int >= 1")
    .Attr("T: numbertype")
    .Attr("shape: shape")
    .SetIsCommutative()
    .SetIsAggregate()
    .SetShapeFn(shape_inference::ExplicitShape);

// --------------------------------------------------------------------------

REGISTER_OP("BatchMatMul")
    .Input("x: T")
    .Input("y: T")
    .Output("output: T")
    .Attr("T: {half, bfloat16, float, double, int32, complex64, complex128}")
    .Attr("adj_x: bool = false")
    .Attr("adj_y: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle a_shape;
      ShapeHandle b_shape;
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &a_shape));
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &b_shape));

      // Determine output rows and cols.
      bool adj_x;
      bool adj_y;
      TF_RETURN_IF_ERROR(c->GetAttr("adj_x", &adj_x));
      TF_RETURN_IF_ERROR(c->GetAttr("adj_y", &adj_y));
      DimensionHandle output_rows = c->Dim(a_shape, adj_x ? -1 : -2);
      DimensionHandle output_cols = c->Dim(b_shape, adj_y ? -2 : -1);

      // Batch dims match between inputs.
      ShapeHandle a_batch_dims;
      ShapeHandle b_batch_dims;
      ShapeHandle batch_dims;
      TF_RETURN_IF_ERROR(c->Subshape(a_shape, 0, -2, &a_batch_dims));
      TF_RETURN_IF_ERROR(c->Subshape(b_shape, 0, -2, &b_batch_dims));
      TF_RETURN_IF_ERROR(c->Merge(a_batch_dims, b_batch_dims, &batch_dims));

      // Assert inner dims match.
      DimensionHandle unused;
      TF_RETURN_IF_ERROR(c->Merge(c->Dim(a_shape, adj_x ? -2 : -1),
                                  c->Dim(b_shape, adj_y ? -1 : -2), &unused));

      ShapeHandle out;
      TF_RETURN_IF_ERROR(c->Concatenate(
          batch_dims, c->Matrix(output_rows, output_cols), &out));
      c->set_output(0, out);
      return Status::OK();
    });
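
// Worked example for the BatchMatMul shape function (illustrative shapes):
// with adj_x = adj_y = false,
//   x: [B, M, K] and y: [B, K, N]  =>  output: [B, M, N]
// i.e. batch_dims = [B], output_rows = M, output_cols = N, and the inner
// dimensions K must merge. Setting adj_x or adj_y swaps the last two
// dimensions of the corresponding input before the multiply.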

// --------------------------------------------------------------------------
// Casting Ops
//
// NOTE: Cast supports only a subset of types. The exact casting
// rule is TBD. The current implementation uses C++ static_cast
// rules for numeric types; this may change in the future.
REGISTER_OP("Cast")
    .Input("x: SrcT")
    .Output("y: DstT")
    .Attr("SrcT: type")
    .Attr("DstT: type")
    .SetShapeFn(shape_inference::UnchangedShape);
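
// A sketch of the static_cast semantics mentioned above (an assumption, since
// the note says the exact rule is TBD): casting float 2.7 to int32 truncates
// toward zero and yields 2, and casting -2.7 yields -2.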

REGISTER_OP("_HostCast")
    .Input("x: SrcT")
    .Output("y: DstT")
    .Attr("SrcT: type")
    .Attr("DstT: type")
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
Cast x of type SrcT to y of DstT.

_HostCast requires its input and produces its output in host memory.
)doc");

// --------------------------------------------------------------------------

REGISTER_OP("Abs")
    .Input("x: T")
    .Output("y: T")
    .Attr("T: {half, bfloat16, float, double, int32, int64}")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("ComplexAbs")
    .Input("x: T")
    .Output("y: Tout")
    .Attr("T: {complex64, complex128} = DT_COMPLEX64")
    .Attr("Tout: {float, double} = DT_FLOAT")
    .SetShapeFn(shape_inference::UnchangedShape);
// Declares cwise unary operations signature: 't -> 't
#define UNARY()                                                          \
  Input("x: T")                                                          \
      .Output("y: T")                                                    \
      .Attr(                                                             \
          "T: {half, bfloat16, float, double, int32, int64, complex64, " \
          "complex128}")                                                 \
      .SetShapeFn(shape_inference::UnchangedShape)

#define UNARY_REAL()                              \
  Input("x: T")                                   \
      .Output("y: T")                             \
      .Attr("T: {half, bfloat16, float, double}") \
      .SetShapeFn(shape_inference::UnchangedShape)

#define UNARY_COMPLEX()                                                  \
  Input("x: T")                                                          \
      .Output("y: T")                                                    \
      .Attr("T: {half, bfloat16, float, double, complex64, complex128}") \
      .SetShapeFn(shape_inference::UnchangedShape)

#define UNARY_GRADIENT_COMPLEX()                                         \
  Input("y: T")                                                          \
      .Input("dy: T")                                                    \
      .Output("z: T")                                                    \
      .Attr("T: {half, bfloat16, float, double, complex64, complex128}") \
      .SetShapeFn(shape_inference::UnchangedShape)

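// For reference, REGISTER_OP("Neg").UNARY() below expands to the equivalent
// of the following hand-written registration (expansion shown for clarity):
//
//   REGISTER_OP("Neg")
//       .Input("x: T")
//       .Output("y: T")
//       .Attr("T: {half, bfloat16, float, double, int32, int64, complex64, "
//             "complex128}")
//       .SetShapeFn(shape_inference::UnchangedShape);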
REGISTER_OP("Neg").UNARY();

REGISTER_OP("Inv").UNARY();

REGISTER_OP("InvGrad").UNARY_GRADIENT_COMPLEX();

REGISTER_OP("Reciprocal").UNARY();

REGISTER_OP("ReciprocalGrad").UNARY_GRADIENT_COMPLEX();

REGISTER_OP("Square").UNARY();

REGISTER_OP("Sqrt").UNARY_COMPLEX();

REGISTER_OP("SqrtGrad").UNARY_GRADIENT_COMPLEX();

REGISTER_OP("Rsqrt").UNARY_COMPLEX();

REGISTER_OP("Round").UNARY();

REGISTER_OP("RsqrtGrad").UNARY_GRADIENT_COMPLEX();

REGISTER_OP("Exp").UNARY_COMPLEX();

REGISTER_OP("Expm1").UNARY_COMPLEX();

REGISTER_OP("Log").UNARY_COMPLEX();

REGISTER_OP("Log1p").UNARY_COMPLEX();

REGISTER_OP("Sinh").UNARY_COMPLEX();

REGISTER_OP("Cosh").UNARY_COMPLEX();

REGISTER_OP("Tanh").UNARY_COMPLEX();

REGISTER_OP("Asinh").UNARY_COMPLEX();

REGISTER_OP("Acosh").UNARY_COMPLEX();

REGISTER_OP("Atanh").UNARY_COMPLEX();

REGISTER_OP("TanhGrad").UNARY_GRADIENT_COMPLEX();

REGISTER_OP("Lgamma").UNARY_REAL();

REGISTER_OP("Digamma").UNARY_REAL();

REGISTER_OP("Erf").UNARY_REAL();

REGISTER_OP("Erfc").UNARY_REAL();

REGISTER_OP("Sigmoid").UNARY_COMPLEX();

REGISTER_OP("SigmoidGrad").UNARY_GRADIENT_COMPLEX();

REGISTER_OP("Sin").UNARY_COMPLEX();

REGISTER_OP("Cos").UNARY_COMPLEX();

REGISTER_OP("Tan").UNARY();

REGISTER_OP("Asin").UNARY();

REGISTER_OP("Acos").UNARY();

REGISTER_OP("Atan").UNARY();

#undef UNARY
#undef UNARY_REAL
#undef UNARY_COMPLEX
#undef UNARY_GRADIENT_COMPLEX

REGISTER_OP("IsNan")
    .Input("x: T")
    .Output("y: bool")
    .Attr("T: {half, bfloat16, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("IsInf")
    .Input("x: T")
    .Output("y: bool")
    .Attr("T: {half, bfloat16, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("IsFinite")
    .Input("x: T")
    .Output("y: bool")
    .Attr("T: {half, bfloat16, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("Sign")
    .Input("x: T")
    .Output("y: T")
    .Attr(
        "T: {half, bfloat16, float, double, int32, int64, complex64, "
        "complex128}")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("Floor")
    .Input("x: T")
    .Output("y: T")
    .Attr("T: {half, bfloat16, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("Ceil")
    .Input("x: T")
    .Output("y: T")
    .Attr("T: {half, bfloat16, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("Rint")
    .Input("x: T")
    .Output("y: T")
    .Attr("T: {bfloat16, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);

// Declares cwise binary operations signature: 't, 't -> 't.

#define BINARY_MORE()                                                          \
  Input("x: T").Input("y: T").Output("z: T").Attr(                             \
      "T: {half, bfloat16, float, double, uint8, int8, uint16, int16, int32, " \
      "int64, complex64, complex128}")

#define BINARY_FEWER()                                               \
  Input("x: T").Input("y: T").Output("z: T").Attr(                   \
      "T: {half, bfloat16, float, double, int32, int64, complex64, " \
      "complex128}")

REGISTER_OP("Add")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr(
        "T: {half, bfloat16, float, double, uint8, int8, int16, int32, int64, "
        "complex64, complex128, string}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

// TODO(rmlarsen): Add a Python wrapper that switches non-string instances to
// use AddV2 (b/68646025).
REGISTER_OP("AddV2")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr(
        "T: {half, bfloat16, float, double, uint8, int8, int16, int32, int64, "
        "complex64, complex128}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
    .SetIsAggregate()
    .SetIsCommutative();

REGISTER_OP("_MklAdd")
    .Input("x: T")
    .Input("y: T")
    .Input("mkl_x: uint8")
    .Input("mkl_y: uint8")
    .Output("z: T")
    .Output("mkl_z: uint8")
    .Attr(
        "T: {half, float, double, uint8, int8, int16, int32, int64, complex64, "
        "complex128, string}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
    .Doc(R"doc(
Returns x + y element-wise.

*NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
)doc");

REGISTER_OP("Sub").BINARY_MORE().SetShapeFn(
    shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("_MklSub")
    .BINARY_FEWER()
    .Input("mkl_x: uint8")
    .Input("mkl_y: uint8")
    .Output("mkl_z: uint8")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
    .Doc(R"doc(
Returns x - y element-wise.

*NOTE*: `Sub` supports broadcasting. More about broadcasting
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
)doc");

REGISTER_OP("Mul").BINARY_MORE().SetIsCommutative().SetShapeFn(
    shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("_MklMul")
    .BINARY_MORE()
    .Input("mkl_x: uint8")
    .Input("mkl_y: uint8")
    .Output("mkl_z: uint8")
    .SetIsCommutative()
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
    .Doc(R"doc(
Returns x * y element-wise.

*NOTE*: `Mul` supports broadcasting. More about broadcasting
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
)doc");

REGISTER_OP("Div").BINARY_MORE().SetShapeFn(
    shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("FloorDiv")
    .BINARY_MORE()
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("TruncateDiv")
    .BINARY_MORE()
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("RealDiv").BINARY_MORE().SetShapeFn(
    shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("SquaredDifference")
    .BINARY_FEWER()
    .SetIsCommutative()
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("_MklSquaredDifference")
    .BINARY_FEWER()
    .Input("mkl_x: uint8")
    .Input("mkl_y: uint8")
    .Output("mkl_z: uint8")
    .SetIsCommutative()
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
    .Doc(R"doc(
Returns (x - y)(x - y) element-wise.

*NOTE*: `SquaredDifference` supports broadcasting. More about broadcasting
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
)doc");

#undef BINARY_FEWER
#undef BINARY_MORE

REGISTER_OP("Maximum")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr("T: {half, bfloat16, float, double, int32, int64}")
    .SetIsCommutative()
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("_MklMaximum")
    .Input("x: T")
    .Input("y: T")
    .Input("mkl_x: uint8")
    .Input("mkl_y: uint8")
    .Output("z: T")
    .Output("mkl_z: uint8")
    .Attr("T: {half, float, double, int32, int64}")
    .SetIsCommutative()
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
    .Doc(R"doc(
Returns the max of x and y (i.e. x > y ? x : y) element-wise.

*NOTE*: `Maximum` supports broadcasting. More about broadcasting
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
)doc");

REGISTER_OP("Minimum")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr("T: {half, bfloat16, float, double, int32, int64}")
    .SetIsCommutative()
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("Mod")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr("T: {int32, int64, bfloat16, float, double}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("FloorMod")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr("T: {int32, int64, bfloat16, float, double}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("TruncateMod")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr("T: {int32, int64, bfloat16, float, double}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
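
// The two mod flavors above differ on negative operands (standard floored
// vs. truncated division semantics; the values below are worked examples,
// not taken from this file):
//   FloorMod(-7, 3)    ==  2   // result takes the sign of the divisor
//   TruncateMod(-7, 3) == -1   // result takes the sign of the dividend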

REGISTER_OP("Pow")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr(
        "T: {half, bfloat16, float, double, int32, int64, complex64, "
        "complex128}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("Igammac")
    .Input("a: T")
    .Input("x: T")
    .Output("z: T")
    .Attr("T: {float, double}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("Igamma")
    .Input("a: T")
    .Input("x: T")
    .Output("z: T")
    .Attr("T: {float, double}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("Zeta")
    .Input("x: T")
    .Input("q: T")
    .Output("z: T")
    .Attr("T: {float, double}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("Polygamma")
    .Input("a: T")
    .Input("x: T")
    .Output("z: T")
    .Attr("T: {float, double}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("Atan2")
    .Input("y: T")
    .Input("x: T")
    .Output("z: T")
    .Attr("T: {bfloat16, float, double}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("Betainc")
    .Input("a: T")
    .Input("b: T")
    .Input("x: T")
    .Output("z: T")
    .Attr("T: {float, double}")
    .SetShapeFn([](InferenceContext* c) {
      const int num_inputs = 3;
      ShapeHandle output = c->UnknownShape();
      int num_scalars = 0;
      ShapeHandle some_non_scalar;
      for (int i = 0; i < num_inputs; ++i) {
        ShapeHandle in = c->input(i);
        if (!c->RankKnown(in)) {
          some_non_scalar = in;
          // An input with unknown rank could be either a scalar (to be
          // broadcast) or some other shape.
        } else if (c->Rank(in) == 0) {
          // Input is a scalar, it will be broadcast to the output shape.
          ++num_scalars;
        } else {
          TF_RETURN_IF_ERROR(c->Merge(output, in, &output));
          some_non_scalar = output;
        }
      }

      if (num_scalars == num_inputs - 1) {
        // If all but one input is known to be a scalar, then output is the
        // remaining input.
        output = some_non_scalar;
      } else if (num_scalars == num_inputs) {
        // If all are scalars, output is scalar; pick the first one arbitrarily.
        output = c->input(0);
      }

      c->set_output(0, output);
      return Status::OK();
    });
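
// Worked example for the Betainc shape function (illustrative shapes): with
// a: [] (scalar), b: [] (scalar), and x: [2, 3], two of the three inputs are
// scalars, so the output takes the remaining input's shape, [2, 3]. If
// instead a: [2, 3] and b: [2, 3] as well, the non-scalar shapes are merged
// and the output is again [2, 3].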

// --------------------------------------------------------------------------

// Declares cwise binary comparison operations signature: 't, 't -> bool,
// where 't has a natural total order.
#define COMPARISON()             \
  Input("x: T")                  \
      .Input("y: T")             \
      .Output("z: bool")         \
      .Attr("T: realnumbertype") \
      .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)

REGISTER_OP("Less").COMPARISON();

REGISTER_OP("LessEqual").COMPARISON();

REGISTER_OP("Greater").COMPARISON();

REGISTER_OP("GreaterEqual").COMPARISON();

#undef COMPARISON

// --------------------------------------------------------------------------

#define EQUALITY_COMPARISON()                                              \
  Input("x: T")                                                            \
      .Input("y: T")                                                       \
      .Output("z: bool")                                                   \
      .SetIsCommutative()                                                  \
      .Attr(                                                               \
          "T: {half, bfloat16, float, double, uint8, int8, int16, int32, " \
          "int64, complex64, quint8, qint8, qint32, string, bool, "        \
          "complex128}")                                                   \
      .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)

REGISTER_OP("Equal").EQUALITY_COMPARISON();

REGISTER_OP("NotEqual").EQUALITY_COMPARISON();

#undef EQUALITY_COMPARISON

REGISTER_OP("ApproximateEqual")
    .Input("x: T")
    .Input("y: T")
    .Output("z: bool")
    .SetIsCommutative()
    .Attr("T: numbertype")
    .Attr("tolerance: float = 0.00001")
    .SetShapeFn(shape_inference::UnchangedShape);

// --------------------------------------------------------------------------

REGISTER_OP("LogicalNot")
    .Input("x: bool")
    .Output("y: bool")
    .SetShapeFn(shape_inference::UnchangedShape);

#define BINARY_LOGICAL()  \
  Input("x: bool")        \
      .Input("y: bool")   \
      .Output("z: bool")  \
      .SetIsCommutative() \
      .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)

REGISTER_OP("LogicalAnd").BINARY_LOGICAL();

REGISTER_OP("LogicalOr").BINARY_LOGICAL();

#undef BINARY_LOGICAL

// --------------------------------------------------------------------------

REGISTER_OP("Select")
    .Input("condition: bool")
    .Input("t: T")
    .Input("e: T")
    .Output("output: T")
    .Attr("T: type")
    .SetShapeFn([](InferenceContext* c) {
      auto* handle_data_1 = c->input_handle_shapes_and_types(1);
      auto* handle_data_2 = c->input_handle_shapes_and_types(2);
      // Merge handle shape and dtype if applicable.
      if (handle_data_1 != nullptr && handle_data_2 != nullptr) {
        const auto size = handle_data_1->size();
        std::vector<shape_inference::ShapeAndType> merged_handle_data(size);
        if (size != handle_data_2->size()) {
          return errors::InvalidArgument(
              "Trying to merge handles pointing to different numbers of "
              "tensors.");
        }

        for (int i = 0; i < size; ++i) {
          const shape_inference::ShapeAndType& s1 = (*handle_data_1)[i];
          const shape_inference::ShapeAndType& s2 = (*handle_data_2)[i];
          if (s1.dtype != s2.dtype) {
            // TODO(apassos) resolve this in the manner of b/32476923
            return errors::InvalidArgument(
                "Trying to merge handles pointing to different dtypes.");
          }
          merged_handle_data[i].dtype = s1.dtype;
          TF_RETURN_IF_ERROR(
              c->Merge(s1.shape, s2.shape, &merged_handle_data[i].shape));
        }

        c->set_output_handle_shapes_and_types(0, merged_handle_data);
      }

      // The inputs 'then' and 'else' must have the same shape.
      ShapeHandle data = c->input(1);
      ShapeHandle other = c->input(2);
      TF_RETURN_IF_ERROR(c->Merge(data, other, &data));

      // The input 'cond' must either have the same shape as 'then' and
      // 'else', or be a vector if 'then' and 'else' are at least vectors.
      ShapeHandle cond = c->input(0);

      if (!c->RankKnown(cond) || !c->RankKnown(data)) {
        c->set_output(0, data);
        return Status::OK();
      }

      // The ranks of 'cond' and 'data' are known.

      const int32 cond_rank = c->Rank(cond);
      const int32 data_rank = c->Rank(data);

      if (cond_rank == 0) {
        // 'cond' is a scalar: 't' and 'e' can have any shape.
        c->set_output(0, data);
        return Status::OK();
      }

      if (cond_rank != 1) {
        // If 'cond' is neither a vector nor a scalar, its shape must match
        // that of 'then' and 'else'.
        TF_RETURN_IF_ERROR(c->Merge(data, cond, &data));
        c->set_output(0, data);
        return Status::OK();
      }

      if (data_rank == 0) {
        // If 'then' and 'else' are scalars, 'cond' must be a scalar too.
        TF_RETURN_IF_ERROR(c->Merge(data, cond, &data));
        c->set_output(0, data);
        return Status::OK();
      }

      if (cond_rank == 1) {
        // If 'cond' is a vector and 'then' is not a scalar, 'cond' must match
        // the first dimension of 'then' and 'else'.
        TF_RETURN_IF_ERROR(c->Merge(cond, c->Vector(c->Dim(data, 0)), &cond));
        c->set_output(0, data);
        return Status::OK();
      }

      c->set_output(0, data);

      return Status::OK();
    });
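
// Worked example for the Select shape function (illustrative shapes): with
// condition: [2] (a vector), t: [2, 3], and e: [2, 3], 'condition' merges
// against the first dimension of 't'/'e' and the output is [2, 3]. A scalar
// condition selects one of the two inputs whole, so t: [4, 5] and e: [4, 5]
// give an output of [4, 5] regardless of the condition's value.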

// --------------------------------------------------------------------------

REGISTER_OP("MatMul")
    .Input("a: T")
    .Input("b: T")
    .Output("product: T")
    .Attr("transpose_a: bool = false")
    .Attr("transpose_b: bool = false")
    .Attr("T: {half, bfloat16, float, double, int32, complex64, complex128}")
    .SetShapeFn(shape_inference::MatMulShape);

REGISTER_OP("SparseMatMul")
    .Input("a: Ta")
    .Input("b: Tb")
    .Output("product: float")
    .Attr("transpose_a: bool = false")
    .Attr("transpose_b: bool = false")
    .Attr("a_is_sparse: bool = false")
    .Attr("b_is_sparse: bool = false")
    .Attr("Ta: {float, bfloat16} = DT_FLOAT")
    .Attr("Tb: {float, bfloat16} = DT_FLOAT")
    .SetShapeFn(shape_inference::MatMulShape);

// --------------------------------------------------------------------------

// For operations where the output is a reduction function along some
// dimensions of the input.
REGISTER_OP("Sum")
    .Input("input: T")
    .Input("reduction_indices: Tidx")
    .Output("output: T")
    .Attr("keep_dims: bool = false")
    .Attr("T: numbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::ReductionShape);

REGISTER_OP("Mean")
    .Input("input: T")
    .Input("reduction_indices: Tidx")
    .Output("output: T")
    .Attr("keep_dims: bool = false")
    .Attr("T: numbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::ReductionShape);

REGISTER_OP("Prod")
    .Input("input: T")
    .Input("reduction_indices: Tidx")
    .Output("output: T")
    .Attr("keep_dims: bool = false")
    .Attr("T: numbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::ReductionShape);

REGISTER_OP("Min")
    .Input("input: T")
    .Input("reduction_indices: Tidx")
    .Output("output: T")
    .Attr("keep_dims: bool = false")
    .Attr("T: numbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::ReductionShape);

REGISTER_OP("Max")
    .Input("input: T")
    .Input("reduction_indices: Tidx")
    .Output("output: T")
    .Attr("keep_dims: bool = false")
    .Attr("T: numbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::ReductionShape);
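
// Reduction shape sketch (illustrative values, assuming reduction_indices is
// a known constant): for input: [2, 3, 4] and reduction_indices = [1],
//   keep_dims = false  =>  output: [2, 4]
//   keep_dims = true   =>  output: [2, 1, 4]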

namespace {

Status ArgOpShape(shape_inference::InferenceContext* c) {
  ShapeHandle dimension_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &dimension_shape));

  ShapeHandle input_shape = c->input(0);
  if (!c->RankKnown(input_shape)) {
    return shape_inference::UnknownShape(c);
  }

  const int32 input_rank = c->Rank(input_shape);
  if (input_rank <= 1) {
    // Reducing a scalar/vector must return a scalar.
    return shape_inference::ScalarShape(c);
  }

  const Tensor* dim_t = c->input_tensor(1);
  if (dim_t == nullptr) {
    // We don't know the value of the dimension, but we
    // know the rank of the input, so return the correct
    // rank with unknown dimensions.
    std::vector<DimensionHandle> dims(input_rank - 1);
    for (int i = 0; i < dims.size(); ++i) {
      dims[i] = c->UnknownDim();
    }

    c->set_output(0, c->MakeShape(dims));
    return Status::OK();
  }

  int64 dimension_val;
  if (dim_t->dtype() == DT_INT32) {
    dimension_val = dim_t->scalar<int32>()();
  } else {
    dimension_val = dim_t->scalar<int64>()();
  }

  int64 axis = dimension_val < 0 ? dimension_val + input_rank : dimension_val;
  if (axis < 0 || axis >= input_rank) {
    return errors::InvalidArgument(
        "Dimension (", dimension_val, ") must be in the range [", -input_rank,
        ", ", input_rank, "), where ", input_rank,
        " is the number of dimensions in the input.");
  }

  // Return the input shape without the dimension being reduced.
  std::vector<DimensionHandle> dims;
  for (int i = 0; i < input_rank; ++i) {
    if (axis != i) {
      dims.emplace_back(c->Dim(input_shape, i));
    }
  }
  c->set_output(0, c->MakeShape(dims));
  return Status::OK();
}
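
// Worked example for ArgOpShape (illustrative shapes): for input: [2, 3, 4]
// and dimension = 1, the reduced axis is dropped and the output is [2, 4];
// dimension = -1 is canonicalized to axis 2, giving [2, 3]. When the
// dimension tensor's value is unavailable, the result is [?, ?].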

}  // namespace

REGISTER_OP("ArgMax")
    .Input("input: T")
    .Input("dimension: Tidx")
    .Output("output: output_type")
    .Attr("T: numbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .Attr("output_type: {int32, int64} = DT_INT64")
    .SetShapeFn(ArgOpShape);

REGISTER_OP("ArgMin")
    .Input("input: T")
    .Input("dimension: Tidx")
    .Output("output: output_type")
    .Attr("T: numbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .Attr("output_type: {int32, int64} = DT_INT64")
    .SetShapeFn(ArgOpShape);

namespace {

Status SegmentReductionShapeFn(InferenceContext* c) {
  ShapeHandle data_shape;
  ShapeHandle segment_ids_shape;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape));
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &segment_ids_shape));

  ShapeHandle subshape;
  TF_RETURN_IF_ERROR(c->Subshape(data_shape, 1, &subshape));

  ShapeHandle out;
  TF_RETURN_IF_ERROR(
      c->Concatenate(c->Vector(InferenceContext::kUnknownDim), subshape, &out));
  c->set_output(0, out);
  return Status::OK();
}
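
// Segment reduction shape sketch (illustrative shapes): for data: [N, 4] and
// segment_ids: [N], the output is [?, 4]; the leading dimension is unknown
// at graph-construction time because it equals the number of segments, which
// depends on the runtime values of segment_ids.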

Status SparseSegmentReductionShapeFn(InferenceContext* c) {
  ShapeHandle data_shape;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape));

  ShapeHandle indices_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices_shape));

  ShapeHandle segment_ids_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &segment_ids_shape));

  // indices and segment_ids should merge cleanly.
  ShapeHandle unused;
  TF_RETURN_IF_ERROR(c->Merge(indices_shape, segment_ids_shape, &unused));

  ShapeHandle subshape;
  TF_RETURN_IF_ERROR(c->Subshape(data_shape, 1, &subshape));

  ShapeHandle out;
  TF_RETURN_IF_ERROR(
      c->Concatenate(c->Vector(InferenceContext::kUnknownDim), subshape, &out));
  c->set_output(0, out);
  return Status::OK();
}

Status SparseSegmentReductionGradShapeFn(InferenceContext* c) {
  ShapeHandle data_shape;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape));

  ShapeHandle indices_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices_shape));

  // indices and segment_ids should merge cleanly.
  ShapeHandle unused;
  TF_RETURN_IF_ERROR(c->Merge(c->input(2), indices_shape, &unused));

  // output_dim0 should be a scalar
  TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));

  ShapeHandle subshape;
  TF_RETURN_IF_ERROR(c->Subshape(data_shape, 1, &subshape));

  const Tensor* dim0 = c->input_tensor(3);
  ShapeHandle dim0_shape;
  if (dim0 == nullptr) {
    // We don't have the value at inference time, so the output
    // shape is unknown.
    dim0_shape = c->Vector(InferenceContext::kUnknownDim);
  } else {
    auto dim0_value = dim0->scalar<int32>()();
    if (dim0_value < 0) {
      return errors::InvalidArgument(
          "Cannot specify a negative value for output_dim0");
    }
    dim0_shape = c->Vector(dim0_value);
  }

  ShapeHandle out;
  TF_RETURN_IF_ERROR(c->Concatenate(dim0_shape, subshape, &out));
  c->set_output(0, out);
  return Status::OK();
}

Status SparseSegmentReductionWithNumSegmentsShapeFn(InferenceContext* c) {
  ShapeHandle data_shape;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape));

  ShapeHandle indices_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices_shape));

  ShapeHandle segment_ids_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &segment_ids_shape));

  ShapeHandle num_segments_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &num_segments_shape));

  // indices and segment_ids should merge cleanly.
  ShapeHandle unused;
  TF_RETURN_IF_ERROR(c->Merge(indices_shape, segment_ids_shape, &unused));

  ShapeHandle subshape;
  TF_RETURN_IF_ERROR(c->Subshape(data_shape, 1, &subshape));

  ShapeHandle out;
  const Tensor* dim0 = c->input_tensor(3);
  if (dim0 == nullptr) {
    // We don't have the value at inference time, so the output
    // shape is unknown.
    TF_RETURN_IF_ERROR(c->Concatenate(c->Vector(InferenceContext::kUnknownDim),
                                      subshape, &out));
  } else {
    auto dim0_value = dim0->scalar<int32>()();
    if (dim0_value < 0) {
      return errors::InvalidArgument(
          "Cannot specify a negative value for num_segments");
    }
    TF_RETURN_IF_ERROR(c->Concatenate(c->Vector(dim0_value), subshape, &out));
  }
  c->set_output(0, out);
  return Status::OK();
}

Status UnsortedSegmentReductionShapeFn(InferenceContext* c) {
  ShapeHandle s_data = c->input(0);
  ShapeHandle s_segment_ids = c->input(1);
  ShapeHandle s_num_segments = c->input(2);
  TF_RETURN_IF_ERROR(c->WithRank(s_num_segments, 0, &s_num_segments));

  ShapeHandle out;

  // Leading dimensions of data must be compatible with dimensions of
  // <s_segment_ids>.
  if (c->RankKnown(s_segment_ids)) {
    TF_RETURN_IF_ERROR(
        c->MergePrefix(s_data, s_segment_ids, &s_data, &s_segment_ids));

    // Get the value of the num_segments input tensor.
    DimensionHandle num_segments_dim;
    TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &num_segments_dim));

    // Output is {segment_id_rank} + s_data[segment_id_rank:].
    ShapeHandle s_data_suffix;
    TF_RETURN_IF_ERROR(
        c->Subshape(s_data, c->Rank(s_segment_ids), &s_data_suffix));
    TF_RETURN_IF_ERROR(
        c->Concatenate(c->Vector(num_segments_dim), s_data_suffix, &out));
  } else {
    out = c->UnknownShape();
  }
  c->set_output(0, out);
  return Status::OK();
}
}  // namespace

REGISTER_OP("SegmentSum")
    .Input("data: T")
    .Input("segment_ids: Tindices")
    .Output("output: T")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32,int64}")
    .SetShapeFn(SegmentReductionShapeFn);

REGISTER_OP("SegmentMean")
    .Input("data: T")
    .Input("segment_ids: Tindices")
    .Output("output: T")
    .Attr("T: realnumbertype")
    .Attr("Tindices: {int32,int64}")
    .SetShapeFn(SegmentReductionShapeFn);

REGISTER_OP("SegmentProd")
    .Input("data: T")
    .Input("segment_ids: Tindices")
    .Output("output: T")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32,int64}")
    .SetShapeFn(SegmentReductionShapeFn);

REGISTER_OP("SegmentMin")
    .Input("data: T")
    .Input("segment_ids: Tindices")
    .Output("output: T")
    .Attr("T: realnumbertype")
    .Attr("Tindices: {int32,int64}")
    .SetShapeFn(SegmentReductionShapeFn);

REGISTER_OP("SegmentMax")
    .Input("data: T")
    .Input("segment_ids: Tindices")
    .Output("output: T")
    .Attr("T: realnumbertype")
    .Attr("Tindices: {int32,int64}")
    .SetShapeFn(SegmentReductionShapeFn);

REGISTER_OP("UnsortedSegmentSum")
    .Input("data: T")
    .Input("segment_ids: Tindices")
    .Input("num_segments: Tnumsegments")
    .Output("output: T")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32,int64}")
    .Attr("Tnumsegments: {int32,int64} = DT_INT32")
    .SetShapeFn(UnsortedSegmentReductionShapeFn);

REGISTER_OP("UnsortedSegmentMax")
    .Input("data: T")
    .Input("segment_ids: Tindices")
    .Input("num_segments: Tnumsegments")
    .Output("output: T")
    .Attr("T: realnumbertype")
    .Attr("Tindices: {int32,int64}")
    .Attr("Tnumsegments: {int32,int64} = DT_INT32")
    .SetShapeFn(UnsortedSegmentReductionShapeFn);

REGISTER_OP("UnsortedSegmentMin")
    .Input("data: T")
    .Input("segment_ids: Tindices")
    .Input("num_segments: Tnumsegments")
    .Output("output: T")
    .Attr("T: realnumbertype")
    .Attr("Tindices: {int32,int64}")
    .Attr("Tnumsegments: {int32,int64} = DT_INT32")
    .SetShapeFn(UnsortedSegmentReductionShapeFn);

REGISTER_OP("UnsortedSegmentProd")
    .Input("data: T")
    .Input("segment_ids: Tindices")
    .Input("num_segments: Tnumsegments")
    .Output("output: T")
    .Attr("T: realnumbertype")
    .Attr("Tindices: {int32,int64}")
    .Attr("Tnumsegments: {int32,int64} = DT_INT32")
    .SetShapeFn(UnsortedSegmentReductionShapeFn);

REGISTER_OP("SparseSegmentSum")
    .Input("data: T")
    .Input("indices: Tidx")
    .Input("segment_ids: int32")
    .Output("output: T")
    .Attr("T: realnumbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(SparseSegmentReductionShapeFn);

REGISTER_OP("SparseSegmentSumWithNumSegments")
    .Input("data: T")
    .Input("indices: Tidx")
    .Input("segment_ids: int32")
    .Input("num_segments: Tnumsegments")
    .Output("output: T")
    .Attr("T: realnumbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .Attr("Tnumsegments: {int32,int64} = DT_INT32")
    .SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn);

REGISTER_OP("SparseSegmentMean")
    .Input("data: T")
    .Input("indices: Tidx")
    .Input("segment_ids: int32")
    .Output("output: T")
    .Attr("T: {float, double}")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(SparseSegmentReductionShapeFn);

REGISTER_OP("SparseSegmentMeanWithNumSegments")
    .Input("data: T")
    .Input("indices: Tidx")
    .Input("segment_ids: int32")
    .Input("num_segments: Tnumsegments")
    .Output("output: T")
    .Attr("T: {float, double}")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .Attr("Tnumsegments: {int32,int64} = DT_INT32")
    .SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn);

REGISTER_OP("SparseSegmentMeanGrad")
    .Input("grad: T")
    .Input("indices: Tidx")
    .Input("segment_ids: int32")
    .Input("output_dim0: int32")
    .Output("output: T")
    .Attr("T: {float, double}")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(SparseSegmentReductionGradShapeFn);

REGISTER_OP("SparseSegmentSqrtN")
    .Input("data: T")
    .Input("indices: Tidx")
    .Input("segment_ids: int32")
    .Output("output: T")
    .Attr("T: {float, double}")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(SparseSegmentReductionShapeFn);

REGISTER_OP("SparseSegmentSqrtNWithNumSegments")
    .Input("data: T")
    .Input("indices: Tidx")
    .Input("segment_ids: int32")
    .Input("num_segments: Tnumsegments")
    .Output("output: T")
    .Attr("T: {float, double}")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .Attr("Tnumsegments: {int32,int64} = DT_INT32")
    .SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn);

REGISTER_OP("SparseSegmentSqrtNGrad")
    .Input("grad: T")
    .Input("indices: Tidx")
    .Input("segment_ids: int32")
    .Input("output_dim0: int32")
    .Output("output: T")
    .Attr("T: {float, double}")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(SparseSegmentReductionGradShapeFn);

REGISTER_OP("All")
    .Input("input: bool")
    .Input("reduction_indices: Tidx")
    .Output("output: bool")
    .Attr("keep_dims: bool = false")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::ReductionShape);

REGISTER_OP("Any")
    .Input("input: bool")
    .Input("reduction_indices: Tidx")
    .Attr("keep_dims: bool = false")
    .Output("output: bool")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::ReductionShape);

// --------------------------------------------------------------------------

namespace {

template <typename T>
Status RangeSize(const Tensor* start_t, const Tensor* limit_t,
                 const Tensor* delta_t, InferenceContext* const c) {
  T start = start_t->scalar<T>()();
  T limit = limit_t->scalar<T>()();
  T delta = delta_t->scalar<T>()();
  if (start > limit && delta > 0) {
    return errors::InvalidArgument(
        "Requires start <= limit when delta > 0: ", start, "/", limit);
  }
  if (start < limit && delta < 0) {
    return errors::InvalidArgument(
        "Requires start >= limit when delta < 0: ", start, "/", limit);
  }
  if (delta == 0) {
    return errors::InvalidArgument("Requires delta != 0");
  }

  int64 size =
      (std::is_integral<T>::value
           ? ((std::abs(limit - start) + std::abs(delta) - 1) / std::abs(delta))
           : std::ceil(std::abs((limit - start) / delta)));
  c->set_output(0, c->Vector(size));
  return Status::OK();
}
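
// Worked examples for RangeSize (illustrative values): integral types use
// ceiling division, so start = 3, limit = 18, delta = 3 gives
// (|18 - 3| + |3| - 1) / |3| = 17 / 3 = 5 elements: [3, 6, 9, 12, 15].
// Floating-point types use std::ceil instead, so start = 0.0, limit = 1.0,
// delta = 0.3 gives ceil(1.0 / 0.3) = 4 elements: [0.0, 0.3, 0.6, 0.9].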

}  // namespace

REGISTER_OP("Range")
    .Input("start: Tidx")
    .Input("limit: Tidx")
    .Input("delta: Tidx")
    .Output("output: Tidx")
    .Attr("Tidx: {bfloat16, float, double, int32, int64} = DT_INT32")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle unused;
      TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(0), 0, &unused),
                                      " for 'start'");
      TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(1), 0, &unused),
                                      " for 'limit'");
      TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(2), 0, &unused),
                                      " for 'delta'");
      const Tensor* start_t = c->input_tensor(0);
      const Tensor* limit_t = c->input_tensor(1);
      const Tensor* delta_t = c->input_tensor(2);
      DataType dtype;
      TF_RETURN_IF_ERROR(c->GetAttr("Tidx", &dtype));
      if (start_t == nullptr || limit_t == nullptr || delta_t == nullptr) {
        c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
        return Status::OK();
      }
      if (dtype == DT_INT32) {
        return RangeSize<int32>(start_t, limit_t, delta_t, c);
      } else if (dtype == DT_INT64) {
        return RangeSize<int64>(start_t, limit_t, delta_t, c);
      } else if (dtype == DT_FLOAT) {
        return RangeSize<float>(start_t, limit_t, delta_t, c);
      } else {
        return RangeSize<double>(start_t, limit_t, delta_t, c);
      }
      return Status::OK();
    });

REGISTER_OP("LinSpace")
    .Input("start: T")
    .Input("stop: T")
    .Input("num: Tidx")
    .Output("output: T")
    .Attr("T: {bfloat16, float, double}")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle unused;
      TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(0), 0, &unused),
                                      " for 'start'");
      TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(1), 0, &unused),
                                      " for 'stop'");
      TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(2), 0, &unused),
                                      " for 'num'");
      const Tensor* num_t = c->input_tensor(2);
      if (num_t == nullptr) {
        c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
        return Status::OK();
      }

      int64 num;
      if (num_t->dtype() == DT_INT32) {
        num = num_t->scalar<int32>()();
      } else {
        num = num_t->scalar<int64>()();
      }
      if (num <= 0) return errors::InvalidArgument("Requires num > 0: ", num);
      c->set_output(0, c->Vector(num));
      return Status::OK();
    });

REGISTER_OP("Complex")
    .Input("real: T")
    .Input("imag: T")
    .Output("out: Tout")
    .Attr("T: {float, double} = DT_FLOAT")
    .Attr("Tout: {complex64, complex128} = DT_COMPLEX64")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("Real")
    .Input("input: T")
    .Output("output: Tout")
    .Attr("T: {complex64, complex128} = DT_COMPLEX64")
    .Attr("Tout: {float, double} = DT_FLOAT")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("Imag")
    .Input("input: T")
    .Output("output: Tout")
    .Attr("T: {complex64, complex128} = DT_COMPLEX64")
    .Attr("Tout: {float, double} = DT_FLOAT")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("Angle")
    .Input("input: T")
    .Output("output: Tout")
    .Attr("T: {complex64, complex128} = DT_COMPLEX64")
    .Attr("Tout: {float, double} = DT_FLOAT")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("Conj")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {complex64, complex128, variant} = DT_COMPLEX64")
    .SetShapeFn(shape_inference::UnchangedShape);

// --------------------------------------------------------------------------

REGISTER_OP("Cross")
    .Input("a: T")
    .Input("b: T")
    .Output("product: T")
    .Attr("T: realnumbertype")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle a_shape;
      ShapeHandle b_shape;
      // * Input rank >= 1.
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &a_shape));
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &b_shape));

      // * Both inputs have the same shape.
      TF_RETURN_IF_ERROR(c->Merge(a_shape, b_shape, &a_shape));

      // * input_shape[-1] == 3.
      if (c->RankKnown(a_shape)) {
        int rank = c->Rank(a_shape);
        auto dim = c->Dim(a_shape, rank - 1);
        TF_RETURN_IF_ERROR(c->WithValue(dim, 3, &dim));
      }
      c->set_output(0, a_shape);
      return Status::OK();
    });

// --------------------------------------------------------------------------

REGISTER_OP("HistogramFixedWidth")
    .Input("values: T")
    .Input("value_range: T")
    .Input("nbins: int32")
    .Output("out: dtype")
    .Attr("T: {int32, int64, float32, float64}")
    .Attr("dtype: {int32, int64} = DT_INT32")
    .SetShapeFn([](InferenceContext* c) {
      const Tensor* nbins_input = c->input_tensor(2);
      if (nbins_input != nullptr) {
        int64 nbins;
        TF_RETURN_IF_ERROR(c->GetScalarFromTensor(nbins_input, &nbins));
        c->set_output(0, c->Vector(nbins));
      } else {
        c->set_output(0, c->UnknownShapeOfRank(1));
      }
      return Status::OK();
    });

REGISTER_OP("Bincount")
    .Input("arr: int32")
    .Input("size: int32")
    .Input("weights: T")
    .Attr("T: {int32, int64, float32, float64}")
    .Output("bins: T")
    .SetShapeFn([](InferenceContext* c) {
      c->set_output(0, c->UnknownShapeOfRank(1));
      return Status::OK();
    });

REGISTER_OP("Cumsum")
    .Input("x: T")
    .Input("axis: Tidx")
    .Attr("exclusive: bool = false")
    .Attr("reverse: bool = false")
    .Output("out: T")
    .Attr("T: numbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::UnchangedShape);
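
// Worked example for the cumulative ops (illustrative values, matching the
// semantics of tf.cumsum): for x = [a, b, c],
//   default            =>  [a, a + b, a + b + c]
//   exclusive = true   =>  [0, a, a + b]
//   reverse = true     =>  [a + b + c, b + c, c]
// Cumprod below is analogous, with products in place of sums.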

REGISTER_OP("Cumprod")
    .Input("x: T")
    .Input("axis: Tidx")
    .Attr("exclusive: bool = false")
    .Attr("reverse: bool = false")
    .Output("out: T")
    .Attr("T: numbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("QuantizedMatMul")
    .Input("a: T1")
    .Input("b: T2")
    .Input("min_a: float")
    .Input("max_a: float")
    .Input("min_b: float")
    .Input("max_b: float")
    .Output("out: Toutput")
    .Output("min_out: float")
    .Output("max_out: float")
    .Attr("T1: quantizedtype")
    .Attr("T2: quantizedtype")
    .Attr("Toutput: quantizedtype = DT_QINT32")
    .Attr("transpose_a: bool = false")
    .Attr("transpose_b: bool = false")
    .Attr("Tactivation: quantizedtype = DT_QUINT8")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::MatMulShape(c));
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));

      c->set_output(1, c->Scalar());
      c->set_output(2, c->Scalar());
      return Status::OK();
    });

REGISTER_OP("QuantizedMul")
    .Input("x: T1")
    .Input("y: T2")
    .Input("min_x: float")
    .Input("max_x: float")
    .Input("min_y: float")
    .Input("max_y: float")
    .Output("z: Toutput")
    .Output("min_z: float")
    .Output("max_z: float")
    .Attr("T1: quantizedtype")
    .Attr("T2: quantizedtype")
    .Attr("Toutput: quantizedtype = DT_QINT32")
    .SetIsCommutative()
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::BroadcastBinaryOpShapeFn(c));
      c->set_output(1, c->Scalar());
      c->set_output(2, c->Scalar());
      return Status::OK();
    });

REGISTER_OP("QuantizedAdd")
    .Input("x: T1")
    .Input("y: T2")
    .Input("min_x: float")
    .Input("max_x: float")
    .Input("min_y: float")
    .Input("max_y: float")
    .Output("z: Toutput")
    .Output("min_z: float")
    .Output("max_z: float")
    .Attr("T1: quantizedtype")
    .Attr("T2: quantizedtype")
    .Attr("Toutput: quantizedtype = DT_QINT32")
    .SetIsCommutative()
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::BroadcastBinaryOpShapeFn(c));
      c->set_output(1, c->Scalar());
      c->set_output(2, c->Scalar());
      return Status::OK();
    });

REGISTER_OP("QuantizeDownAndShrinkRange")
    .Input("input: Tinput")
    .Input("input_min: float")
    .Input("input_max: float")
    .Output("output: out_type")
    .Output("output_min: float")
    .Output("output_max: float")
    .Attr("Tinput: quantizedtype")
    .Attr("out_type: quantizedtype")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      c->set_output(1, c->Scalar());
      c->set_output(2, c->Scalar());
      return Status::OK();
    });

REGISTER_OP("Requantize")
    .Input("input: Tinput")
    .Input("input_min: float")
    .Input("input_max: float")
    .Input("requested_output_min: float")
    .Input("requested_output_max: float")
    .Output("output: out_type")
    .Output("output_min: float")
    .Output("output_max: float")
    .Attr("Tinput: quantizedtype")
    .Attr("out_type: quantizedtype")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
      c->set_output(1, c->Scalar());
      c->set_output(2, c->Scalar());
      return Status::OK();
    });

REGISTER_OP("CompareAndBitpack")
    .Input("input: T")
    .Input("threshold: T")
    .Output("output: uint8")
    .Attr("T: {bool, float16, float32, float64, int8, int16, int32, int64}")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle input;
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input));
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      ShapeHandle output = input;
      if (c->RankKnown(input)) {
        int rank = c->Rank(input);
        auto inner_dim = c->Dim(input, rank - 1);
        DimensionHandle inferred_dim;
        TF_RETURN_IF_ERROR(c->Divide(inner_dim, 8,
                                     /* evenly_divisible */ true,
                                     &inferred_dim));
        TF_RETURN_IF_ERROR(
            c->ReplaceDim(output, rank - 1, inferred_dim, &output));
      }
      c->set_output(0, output);

      return Status::OK();
    });
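
// Shape sketch for CompareAndBitpack (illustrative shapes): each group of 8
// comparison results along the innermost dimension is packed into one uint8,
// so input: [2, 32] yields output: [2, 4]; an innermost dimension that is
// not a multiple of 8 fails the evenly-divisible check above.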

REGISTER_OP("RequantizationRange")
    .Input("input: Tinput")
    .Input("input_min: float")
    .Input("input_max: float")
    .Output("output_min: float")
    .Output("output_max: float")
    .Attr("Tinput: quantizedtype")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      c->set_output(0, c->Scalar());
      c->set_output(1, c->Scalar());
      return Status::OK();
    });

// --------------------------------------------------------------------------

REGISTER_OP("Bucketize")
    .Input("input: T")
    .Output("output: int32")
    .Attr("T: {int32, int64, float, double}")
    .Attr("boundaries: list(float)")
    .SetShapeFn(shape_inference::UnchangedShape);

#ifdef INTEL_MKL
REGISTER_OP("_MklAddN")
    .Input("inputs: N * T")
    .Input("mkl_input: N * uint8")
    .Output("sum: T")
    .Output("mkl_sum: uint8")
    .Attr("N: int >= 1")
    .Attr("T: numbertype")
    .SetIsCommutative()
    .SetIsAggregate()
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle cur = c->input(c->num_inputs() - 1);
      for (int i = c->num_inputs() - 2; i >= 0; --i) {
        TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur),
                                        "From merging shape ", i,
                                        " with other shapes.");
      }
      c->set_output(0, cur);
      return Status::OK();
    })
    .Doc(R"doc(
Adds all input tensors element-wise using the MKL sum kernel.
inputs: Must all be the same size and shape.
)doc");

#endif  // INTEL_MKL

}  // namespace tensorflow