Home | History | Annotate | Download | only in ops
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/framework/common_shape_fns.h"
     17 #include "tensorflow/core/framework/numeric_op.h"
     18 #include "tensorflow/core/framework/op.h"
     19 #include "tensorflow/core/framework/shape_inference.h"
     20 
     21 namespace tensorflow {
     22 
     23 using shape_inference::DimensionHandle;
     24 using shape_inference::InferenceContext;
     25 using shape_inference::ShapeHandle;
     26 
     27 REGISTER_OP("AddN")
     28     .Input("inputs: N * T")
     29     .Output("sum: T")
     30     .Attr("N: int >= 1")
     31     .Attr("T: {numbertype, variant}")
     32     .SetIsCommutative()
     33     .SetIsAggregate()
     34     .SetShapeFn([](InferenceContext* c) {
     35       ShapeHandle cur = c->input(c->num_inputs() - 1);
     36       for (int i = c->num_inputs() - 2; i >= 0; --i) {
     37         TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur),
     38                                         "From merging shape ", i,
     39                                         " with other shapes.");
     40       }
     41       c->set_output(0, cur);
     42 
     43       DataType dtype;
     44       TF_RETURN_IF_ERROR(c->GetAttr("T", &dtype));
     45 
     46       if (dtype != DT_VARIANT) {
     47         // Exit early if not DT_VARIANT.
     48         return Status::OK();
     49       } else {
     50         // DT_VARIANT shape handle shape inference.  All sizes and dtypes must
     51         // be the same; all shapes must be compatible via Merge.
     52         std::vector<shape_inference::ShapeAndType> cur_shapes_and_types;
     53         auto* shapes_and_types =
     54             c->input_handle_shapes_and_types(c->num_inputs() - 1);
     55         if (shapes_and_types) {
     56           cur_shapes_and_types = *shapes_and_types;
     57         }
     58 
     59         for (int i = c->num_inputs() - 2; i >= 0; --i) {
     60           auto shapes_and_types_i = c->input_handle_shapes_and_types(i);
     61           if (!shapes_and_types && shapes_and_types_i) {
     62             // TODO(ebrevdo): Find cases where this happens and fix their shape
     63             // inference.  If we are calling AddN on variant types, they should
     64             // all have consistent shape_and_type info.
     65             shapes_and_types = shapes_and_types_i;
     66           } else if (shapes_and_types && shapes_and_types_i) {
     67             if (shapes_and_types_i->size() != shapes_and_types->size()) {
     68               return errors::InvalidArgument(
     69                   "shapes_and_types[", i,
     70                   "].size() == ", shapes_and_types_i->size(),
     71                   " != shapes_and_types[0].size() == ",
     72                   shapes_and_types->size());
     73             }
     74             for (int j = 0; j < shapes_and_types->size(); ++j) {
     75               if (shapes_and_types->at(j).dtype !=
     76                   shapes_and_types_i->at(j).dtype) {
     77                 return errors::InvalidArgument(
     78                     "shapes_and_types[", i, "][", j, "].dtype() == ",
     79                     DataTypeString(shapes_and_types_i->at(j).dtype),
     80                     " != shapes_and_types[0][", j, "].dtype == ",
     81                     DataTypeString(shapes_and_types->at(j).dtype));
     82               }
     83               TF_RETURN_WITH_CONTEXT_IF_ERROR(
     84                   c->Merge(shapes_and_types_i->at(j).shape,
     85                            cur_shapes_and_types.at(j).shape,
     86                            &cur_shapes_and_types.at(j).shape),
     87                   "From merging shapes_and_types[", i, "][", j, "].shape with ",
     88                   "shapes_and_types[0][", j, "].shape");
     89             }
     90           }
     91         }
     92         if (shapes_and_types) {
     93           c->set_output_handle_shapes_and_types(0, cur_shapes_and_types);
     94         }
     95         return Status::OK();
     96       }
     97     });
     98 
     99 // --------------------------------------------------------------------------
    100 
    101 // Note that the following operator is just a placeholder and has no
    102 // associated kernel. The code in accumulate_n_optimizer.cc replaces
    103 // this placeholder with a graph of operators that do have kernels.
    104 // The Python code that generates instances of this op is currently in
    105 // contrib/framework/python/ops/accumulate_n_v2.py
    106 REGISTER_OP("AccumulateNV2")
    107     .Input("inputs: N * T")
    108     .Output("sum: T")
    109     .Attr("N: int >= 1")
    110     .Attr("T: numbertype")
    111     .Attr("shape: shape")
    112     .SetIsCommutative()
    113     .SetIsAggregate()
    114     .SetShapeFn(shape_inference::ExplicitShape);
    115 
    116 // --------------------------------------------------------------------------
    117 
    118 REGISTER_OP("BatchMatMul")
    119     .Input("x: T")
    120     .Input("y: T")
    121     .Output("output: T")
    122     .Attr(
    123         "T: {bfloat16, half, float, double, int32, int64, complex64, "
    124         "complex128}")
    125     .Attr("adj_x: bool = false")
    126     .Attr("adj_y: bool = false")
    127     .SetShapeFn([](InferenceContext* c) {
    128       ShapeHandle a_shape;
    129       ShapeHandle b_shape;
    130       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &a_shape));
    131       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &b_shape));
    132 
    133       // Determine output rows and cols.
    134       bool adj_x;
    135       bool adj_y;
    136       TF_RETURN_IF_ERROR(c->GetAttr("adj_x", &adj_x));
    137       TF_RETURN_IF_ERROR(c->GetAttr("adj_y", &adj_y));
    138       DimensionHandle output_rows = c->Dim(a_shape, adj_x ? -1 : -2);
    139       DimensionHandle output_cols = c->Dim(b_shape, adj_y ? -2 : -1);
    140 
    141       // Batch dims match between inputs.
    142       ShapeHandle a_batch_dims;
    143       ShapeHandle b_batch_dims;
    144       ShapeHandle batch_dims;
    145       TF_RETURN_IF_ERROR(c->Subshape(a_shape, 0, -2, &a_batch_dims));
    146       TF_RETURN_IF_ERROR(c->Subshape(b_shape, 0, -2, &b_batch_dims));
    147       TF_RETURN_IF_ERROR(c->Merge(a_batch_dims, b_batch_dims, &batch_dims));
    148 
    149       // Assert inner dims match.
    150       DimensionHandle unused;
    151       TF_RETURN_IF_ERROR(c->Merge(c->Dim(a_shape, adj_x ? -2 : -1),
    152                                   c->Dim(b_shape, adj_y ? -1 : -2), &unused));
    153 
    154       ShapeHandle out;
    155       TF_RETURN_IF_ERROR(c->Concatenate(
    156           batch_dims, c->Matrix(output_rows, output_cols), &out));
    157       c->set_output(0, out);
    158       return Status::OK();
    159     });
    160 
    161 // --------------------------------------------------------------------------
    162 // Casting Ops
    163 //
// NOTE: Only a small number of types are supported by
// Cast. The exact casting rule is TBD. The current
// implementation uses C++ static cast rules for numeric
// types, which may be changed in the future.
    168 REGISTER_OP("Cast")
    169     .Input("x: SrcT")
    170     .Output("y: DstT")
    171     .Attr("SrcT: type")
    172     .Attr("DstT: type")
    173     .Attr("Truncate: bool = false")
    174     .SetShapeFn(shape_inference::UnchangedShape);
    175 
    176 REGISTER_OP("_HostCast")
    177     .Input("x: SrcT")
    178     .Output("y: DstT")
    179     .Attr("SrcT: type")
    180     .Attr("DstT: type")
    181     .Attr("Truncate: bool = false")
    182     .SetShapeFn(shape_inference::UnchangedShape)
    183     .Doc(R"doc(
    184 Cast x of type SrcT to y of DstT.
    185 
    186 _HostCast requires its input and produces its output in host memory.
    187 )doc");
    188 
    189 // --------------------------------------------------------------------------
    190 
    191 REGISTER_OP("Abs")
    192     .Input("x: T")
    193     .Output("y: T")
    194     .Attr("T: {bfloat16, half, float, double, int32, int64}")
    195     .SetShapeFn(shape_inference::UnchangedShape);
    196 
    197 REGISTER_OP("ComplexAbs")
    198     .Input("x: T")
    199     .Output("y: Tout")
    200     .Attr("T: {complex64, complex128} = DT_COMPLEX64")
    201     .Attr("Tout: {float, double} = DT_FLOAT")
    202     .SetShapeFn(shape_inference::UnchangedShape);
    203 
    204 // Declares cwise unary operations signature: 't -> 't
    205 #define UNARY()                                                          \
    206   Input("x: T")                                                          \
    207       .Output("y: T")                                                    \
    208       .Attr(                                                             \
    209           "T: {bfloat16, half, float, double, int32, int64, complex64, " \
    210           "complex128}")                                                 \
    211       .SetShapeFn(shape_inference::UnchangedShape)
    212 
    213 #define UNARY_REAL()                              \
    214   Input("x: T")                                   \
    215       .Output("y: T")                             \
    216       .Attr("T: {bfloat16, half, float, double}") \
    217       .SetShapeFn(shape_inference::UnchangedShape)
    218 
    219 #define UNARY_COMPLEX()                                                  \
    220   Input("x: T")                                                          \
    221       .Output("y: T")                                                    \
    222       .Attr("T: {bfloat16, half, float, double, complex64, complex128}") \
    223       .SetShapeFn(shape_inference::UnchangedShape)
    224 
    225 #define UNARY_GRADIENT_COMPLEX()                                         \
    226   Input("y: T")                                                          \
    227       .Input("dy: T")                                                    \
    228       .Output("z: T")                                                    \
    229       .Attr("T: {bfloat16, half, float, double, complex64, complex128}") \
    230       .SetShapeFn(shape_inference::UnchangedShape)
    231 
    232 REGISTER_OP("Neg").UNARY();
    233 
    234 REGISTER_OP("Inv").UNARY();
    235 
    236 REGISTER_OP("InvGrad").UNARY_GRADIENT_COMPLEX();
    237 
    238 REGISTER_OP("Reciprocal").UNARY();
    239 
    240 REGISTER_OP("ReciprocalGrad").UNARY_GRADIENT_COMPLEX();
    241 
    242 REGISTER_OP("Square").UNARY();
    243 
    244 REGISTER_OP("Sqrt").UNARY_COMPLEX();
    245 
    246 REGISTER_OP("SqrtGrad").UNARY_GRADIENT_COMPLEX();
    247 
    248 REGISTER_OP("Rsqrt").UNARY_COMPLEX();
    249 
    250 REGISTER_OP("Round").UNARY();
    251 
    252 REGISTER_OP("RsqrtGrad").UNARY_GRADIENT_COMPLEX();
    253 
    254 REGISTER_OP("Exp").UNARY_COMPLEX();
    255 
    256 REGISTER_OP("Expm1").UNARY_COMPLEX();
    257 
    258 REGISTER_OP("Log").UNARY_COMPLEX();
    259 
    260 REGISTER_OP("Log1p").UNARY_COMPLEX();
    261 
    262 REGISTER_OP("Sinh").UNARY_COMPLEX();
    263 
    264 REGISTER_OP("Cosh").UNARY_COMPLEX();
    265 
    266 REGISTER_OP("Tanh").UNARY_COMPLEX();
    267 
    268 REGISTER_OP("Asinh").UNARY_COMPLEX();
    269 
    270 REGISTER_OP("Acosh").UNARY_COMPLEX();
    271 
    272 REGISTER_OP("Atanh").UNARY_COMPLEX();
    273 
    274 REGISTER_OP("TanhGrad").UNARY_GRADIENT_COMPLEX();
    275 
    276 REGISTER_OP("Lgamma").UNARY_REAL();
    277 
    278 REGISTER_OP("Digamma").UNARY_REAL();
    279 
    280 REGISTER_OP("Erf").UNARY_REAL();
    281 
    282 REGISTER_OP("Erfc").UNARY_REAL();
    283 
    284 REGISTER_OP("Sigmoid").UNARY_COMPLEX();
    285 
    286 REGISTER_OP("SigmoidGrad").UNARY_GRADIENT_COMPLEX();
    287 
    288 REGISTER_OP("Sin").UNARY_COMPLEX();
    289 
    290 REGISTER_OP("Cos").UNARY_COMPLEX();
    291 
    292 REGISTER_OP("Tan").UNARY();
    293 
    294 REGISTER_OP("Asin").UNARY();
    295 
    296 REGISTER_OP("Acos").UNARY();
    297 
    298 REGISTER_OP("Atan").UNARY();
    299 
    300 REGISTER_OP("BesselI0e").UNARY_REAL();
    301 
    302 REGISTER_OP("BesselI1e").UNARY_REAL();
    303 
    304 REGISTER_OP("_UnaryOpsComposition")
    305     .Input("x: T")
    306     .Output("y: T")
    307     .Attr("T: {float, half, double}")
    308     .Attr("op_names: list(string)")
    309     .SetShapeFn(shape_inference::UnchangedShape)
    310     .Doc(R"doc(
    311 *NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
    312 expected to create these operators.
    313 )doc");
    314 
    315 #undef UNARY
    316 #undef UNARY_REAL
    317 #undef UNARY_COMPLEX
    318 
    319 REGISTER_OP("IsNan")
    320     .Input("x: T")
    321     .Output("y: bool")
    322     .Attr("T: {bfloat16, half, float, double}")
    323     .SetShapeFn(shape_inference::UnchangedShape);
    324 
    325 REGISTER_OP("IsInf")
    326     .Input("x: T")
    327     .Output("y: bool")
    328     .Attr("T: {bfloat16, half, float, double}")
    329     .SetShapeFn(shape_inference::UnchangedShape);
    330 
    331 REGISTER_OP("IsFinite")
    332     .Input("x: T")
    333     .Output("y: bool")
    334     .Attr("T: {bfloat16, half, float, double}")
    335     .SetShapeFn(shape_inference::UnchangedShape);
    336 
    337 REGISTER_OP("Sign")
    338     .Input("x: T")
    339     .Output("y: T")
    340     .Attr(
    341         "T: {bfloat16, half, float, double, int32, int64, complex64, "
    342         "complex128}")
    343     .SetShapeFn(shape_inference::UnchangedShape);
    344 
    345 REGISTER_OP("Floor")
    346     .Input("x: T")
    347     .Output("y: T")
    348     .Attr("T: {bfloat16, half, float, double}")
    349     .SetShapeFn(shape_inference::UnchangedShape);
    350 
    351 REGISTER_OP("Ceil")
    352     .Input("x: T")
    353     .Output("y: T")
    354     .Attr("T: {bfloat16, half, float, double}")
    355     .SetShapeFn(shape_inference::UnchangedShape);
    356 
    357 REGISTER_OP("Rint")
    358     .Input("x: T")
    359     .Output("y: T")
    360     .Attr("T: {bfloat16, half, float, double}")
    361     .SetShapeFn(shape_inference::UnchangedShape);
    362 
    363 // Declares cwise binary operations signature: 't, 't -> 't.
    364 
    365 #define BINARY_MORE()                                                          \
    366   Input("x: T").Input("y: T").Output("z: T").Attr(                             \
    367       "T: {bfloat16, half, float, double, uint8, int8, uint16, int16, int32, " \
    368       "int64, complex64, complex128}")
    369 
    370 #define BINARY_FEWER()                                               \
    371   Input("x: T").Input("y: T").Output("z: T").Attr(                   \
    372       "T: {bfloat16, half, float, double, int32, int64, complex64, " \
    373       "complex128}")
    374 
    375 REGISTER_OP("Add")
    376     .Input("x: T")
    377     .Input("y: T")
    378     .Output("z: T")
    379     .Attr(
    380         "T: {bfloat16, half, float, double, uint8, int8, int16, int32, int64, "
    381         "complex64, complex128, string}")
    382     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
    383 
// TODO(rmlarsen): Add a Python wrapper that switches non-string instances to
// use AddV2 (b/68646025).
    386 REGISTER_OP("AddV2")
    387     .Input("x: T")
    388     .Input("y: T")
    389     .Output("z: T")
    390     .Attr(
    391         "T: {bfloat16, half, float, double, uint8, int8, int16, int32, int64, "
    392         "complex64, complex128}")
    393     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
    394     .SetIsAggregate()
    395     .SetIsCommutative();
    396 
    397 REGISTER_OP("_MklAdd")
    398     .Input("x: T")
    399     .Input("y: T")
    400     .Input("mkl_x: uint8")
    401     .Input("mkl_y: uint8")
    402     .Output("z: T")
    403     .Output("mkl_z: uint8")
    404     .Attr(
    405         "T: {half, float, double, uint8, int8, int16, int32, int64, complex64, "
    406         "complex128, string}")
    407     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
    408     .Doc(R"doc(
    409 Returns `x` + `y` element-wise.
    410 
    411 *NOTE*: `tf.math.add` supports broadcasting. `tf.math.add_n` does not. More about broadcasting
    412 [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).
    413 )doc");
    414 
    415 REGISTER_OP("Sub").BINARY_MORE().SetShapeFn(
    416     shape_inference::BroadcastBinaryOpShapeFn);
    417 
    418 REGISTER_OP("_MklSub")
    419     .BINARY_FEWER()
    420     .Input("mkl_x: uint8")
    421     .Input("mkl_y: uint8")
    422     .Output("mkl_z: uint8")
    423     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
    424     .Doc(R"doc(
    425 Returns x - y element-wise.
    426 
    427 *NOTE*: `Sub` supports broadcasting. More about broadcasting
    428 [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
    429 )doc");
    430 
    431 REGISTER_OP("Mul").BINARY_MORE().SetIsCommutative().SetShapeFn(
    432     shape_inference::BroadcastBinaryOpShapeFn);
    433 
    434 REGISTER_OP("MulNoNan")
    435     .Input("x: T")
    436     .Input("y: T")
    437     .Output("z: T")
    438     .Attr("T: {half, float, double, complex64, complex128}")
    439     .SetIsCommutative()
    440     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
    441 
    442 REGISTER_OP("_MklMul")
    443     .BINARY_MORE()
    444     .Input("mkl_x: uint8")
    445     .Input("mkl_y: uint8")
    446     .Output("mkl_z: uint8")
    447     .SetIsCommutative()
    448     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
    449     .Doc(R"doc(
    450 Returns x * y element-wise.
    451 
    452 *NOTE*: `Mul` supports broadcasting. More about broadcasting
    453 [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
    454 )doc");
    455 
    456 REGISTER_OP("Div").BINARY_MORE().SetShapeFn(
    457     shape_inference::BroadcastBinaryOpShapeFn);
    458 
    459 REGISTER_OP("DivNoNan")
    460     .Input("x: T")
    461     .Input("y: T")
    462     .Output("z: T")
    463     .Attr("T: {half, float, double, complex64, complex128}")
    464     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
    465 
    466 REGISTER_OP("FloorDiv")
    467     .BINARY_MORE()
    468     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
    469 
    470 REGISTER_OP("TruncateDiv")
    471     .BINARY_MORE()
    472     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
    473 
    474 REGISTER_OP("RealDiv").BINARY_MORE().SetShapeFn(
    475     shape_inference::BroadcastBinaryOpShapeFn);
    476 
    477 REGISTER_OP("SquaredDifference")
    478     .BINARY_FEWER()
    479     .SetIsCommutative()
    480     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
    481 
    482 REGISTER_OP("_MklSquaredDifference")
    483     .BINARY_FEWER()
    484     .Input("mkl_x: uint8")
    485     .Input("mkl_y: uint8")
    486     .Output("mkl_z: uint8")
    487     .SetIsCommutative()
    488     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
    489     .Doc(R"doc(
    490 Returns (x - y)(x - y) element-wise.
    491 
    492 *NOTE*: `SquaredDifference` supports broadcasting. More about broadcasting
    493 [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
    494 )doc");
    495 
    496 REGISTER_OP("Xlogy")
    497     .Input("x: T")
    498     .Input("y: T")
    499     .Output("z: T")
    500     .Attr("T: {half, float, double, complex64, complex128}")
    501     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
    502 
    503 REGISTER_OP("Xdivy")
    504     .Input("x: T")
    505     .Input("y: T")
    506     .Output("z: T")
    507     .Attr("T: {half, float, double, complex64, complex128}")
    508     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
    509 
    510 #undef BINARY_FEWER
    511 #undef BINARY_MORE
    512 
    513 REGISTER_OP("Maximum")
    514     .Input("x: T")
    515     .Input("y: T")
    516     .Output("z: T")
    517     .Attr("T: {bfloat16, half, float, double, int32, int64}")
    518     .SetIsCommutative()
    519     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
    520 
    521 REGISTER_OP("_MklMaximum")
    522     .Input("x: T")
    523     .Input("y: T")
    524     .Input("mkl_x: uint8")
    525     .Input("mkl_y: uint8")
    526     .Output("z: T")
    527     .Output("mkl_z: uint8")
    528     .Attr("T: {half, float, double, int32, int64}")
    529     .SetIsCommutative()
    530     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
    531     .Doc(R"doc(
    532 Returns the max of x and y (i.e. x > y ? x : y) element-wise.
    533 
    534 *NOTE*: `Maximum` supports broadcasting. More about broadcasting
    535 [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
    536 )doc");
    537 
    538 REGISTER_OP("Minimum")
    539     .Input("x: T")
    540     .Input("y: T")
    541     .Output("z: T")
    542     .Attr("T: {bfloat16, half, float, double, int32, int64}")
    543     .SetIsCommutative()
    544     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
    545 
    546 REGISTER_OP("Mod")
    547     .Input("x: T")
    548     .Input("y: T")
    549     .Output("z: T")
    550     .Attr("T: {int32, int64, float16, half, bfloat16, float, double}")
    551     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
    552 
    553 REGISTER_OP("FloorMod")
    554     .Input("x: T")
    555     .Input("y: T")
    556     .Output("z: T")
    557     .Attr("T: {int32, int64, bfloat16, half, float, double}")
    558     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
    559 
    560 REGISTER_OP("TruncateMod")
    561     .Input("x: T")
    562     .Input("y: T")
    563     .Output("z: T")
    564     .Attr("T: {int32, int64, bfloat16, half, float, double}")
    565     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
    566 
    567 REGISTER_OP("Pow")
    568     .Input("x: T")
    569     .Input("y: T")
    570     .Output("z: T")
    571     .Attr(
    572         "T: {bfloat16, float, half, double, int32, int64, complex64, "
    573         "complex128}")
    574     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
    575 
    576 REGISTER_OP("Igammac")
    577     .Input("a: T")
    578     .Input("x: T")
    579     .Output("z: T")
    580     .Attr("T: {float, double}")
    581     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
    582 
    583 REGISTER_OP("Igamma")
    584     .Input("a: T")
    585     .Input("x: T")
    586     .Output("z: T")
    587     .Attr("T: {float, double}")
    588     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
    589 
    590 REGISTER_OP("IgammaGradA")
    591     .Input("a: T")
    592     .Input("x: T")
    593     .Output("z: T")
    594     .Attr("T: {float, double}")
    595     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
    596 
    597 REGISTER_OP("Zeta")
    598     .Input("x: T")
    599     .Input("q: T")
    600     .Output("z: T")
    601     .Attr("T: {float, double}")
    602     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
    603 
    604 REGISTER_OP("Polygamma")
    605     .Input("a: T")
    606     .Input("x: T")
    607     .Output("z: T")
    608     .Attr("T: {float, double}")
    609     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
    610 
    611 REGISTER_OP("Atan2")
    612     .Input("y: T")
    613     .Input("x: T")
    614     .Output("z: T")
    615     .Attr("T: {bfloat16, half, float, double}")
    616     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
    617 
    618 REGISTER_OP("Betainc")
    619     .Input("a: T")
    620     .Input("b: T")
    621     .Input("x: T")
    622     .Output("z: T")
    623     .Attr("T: {float, double}")
    624     .SetShapeFn([](InferenceContext* c) {
    625       const int num_inputs = 3;
    626       ShapeHandle output = c->UnknownShape();
    627       int num_scalars = 0;
    628       ShapeHandle some_non_scalar;
    629       for (int i = 0; i < num_inputs; ++i) {
    630         ShapeHandle in = c->input(i);
    631         if (!c->RankKnown(in)) {
    632           some_non_scalar = in;
    633           // An input with unknown rank could be either a scalar (to be
    634           // broadcast) or some other shape.
    635         } else if (c->Rank(in) == 0) {
    636           // Input is a scalar, it will be broadcast to the output shape.
    637           ++num_scalars;
    638         } else {
    639           TF_RETURN_IF_ERROR(c->Merge(output, in, &output));
    640           some_non_scalar = output;
    641         }
    642       }
    643 
    644       if (num_scalars == num_inputs - 1) {
    645         // If all but one input is known to be a scalar, then output is the
    646         // remaining input.
    647         output = some_non_scalar;
    648       } else if (num_scalars == num_inputs) {
    649         // If all are scalars, output is scalar; pick the first one arbitrarily.
    650         output = c->input(0);
    651       }
    652 
    653       c->set_output(0, output);
    654       return Status::OK();
    655     });
    656 
    657 // --------------------------------------------------------------------------
    658 
    659 // Declares cwise binary comparison operations signature: 't, 't -> bool,
    660 // where 't has a natural total order.
    661 #define COMPARISON()             \
    662   Input("x: T")                  \
    663       .Input("y: T")             \
    664       .Output("z: bool")         \
    665       .Attr("T: realnumbertype") \
    666       .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
    667 
    668 REGISTER_OP("Less").COMPARISON();
    669 
    670 REGISTER_OP("LessEqual").COMPARISON();
    671 
    672 REGISTER_OP("Greater").COMPARISON();
    673 
    674 REGISTER_OP("GreaterEqual").COMPARISON();
    675 
    676 #undef COMPARISON
    677 
    678 // --------------------------------------------------------------------------
    679 
    680 #define EQUALITY_COMPARISON()                                              \
    681   Input("x: T")                                                            \
    682       .Input("y: T")                                                       \
    683       .Output("z: bool")                                                   \
    684       .SetIsCommutative()                                                  \
    685       .Attr(                                                               \
    686           "T: {bfloat16, half, float, double, uint8, int8, int16, int32, " \
    687           "int64, complex64, quint8, qint8, qint32, string, bool, "        \
    688           "complex128}")                                                   \
    689       .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
    690 
    691 REGISTER_OP("Equal").EQUALITY_COMPARISON();
    692 
    693 REGISTER_OP("NotEqual").EQUALITY_COMPARISON();
    694 
    695 #undef EQUALITY_COMPARISON
    696 
    697 REGISTER_OP("ApproximateEqual")
    698     .Input("x: T")
    699     .Input("y: T")
    700     .Output("z: bool")
    701     .SetIsCommutative()
    702     .Attr("T: numbertype")
    703     .Attr("tolerance: float = 0.00001")
    704     .SetShapeFn([](InferenceContext* c) {
    705       // The inputs 'x' and 'y' must have the same shape.
    706       ShapeHandle data_x = c->input(0);
    707       ShapeHandle data_y = c->input(1);
    708       TF_RETURN_IF_ERROR(c->Merge(data_x, data_y, &data_x));
    709       return shape_inference::UnchangedShape(c);
    710     });
    711 
    712 // --------------------------------------------------------------------------
    713 
    714 REGISTER_OP("LogicalNot")
    715     .Input("x: bool")
    716     .Output("y: bool")
    717     .SetShapeFn(shape_inference::UnchangedShape);
    718 
    719 #define BINARY_LOGICAL()  \
    720   Input("x: bool")        \
    721       .Input("y: bool")   \
    722       .Output("z: bool")  \
    723       .SetIsCommutative() \
    724       .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
    725 
    726 REGISTER_OP("LogicalAnd").BINARY_LOGICAL();
    727 
    728 REGISTER_OP("LogicalOr").BINARY_LOGICAL();
    729 
    730 #undef BINARY_LOGICAL
    731 
    732 // --------------------------------------------------------------------------
    733 
    734 REGISTER_OP("Select")
    735     .Input("condition: bool")
    736     .Input("t: T")
    737     .Input("e: T")
    738     .Output("output: T")
    739     .Attr("T: type")
    740     .SetShapeFn([](InferenceContext* c) {
    741       auto* handle_data_1 = c->input_handle_shapes_and_types(1);
    742       auto* handle_data_2 = c->input_handle_shapes_and_types(2);
    743       // Merge handle shape and dtype if applicable.
    744       if (handle_data_1 != nullptr && handle_data_2 != nullptr) {
    745         const auto size = handle_data_1->size();
    746         std::vector<shape_inference::ShapeAndType> merged_handle_data(size);
    747         if (size != handle_data_2->size()) {
    748           return errors::InvalidArgument(
    749               "Trying to merge handles pointing to different numbers of "
    750               "tensors.");
    751         }
    752 
    753         for (int i = 0; i < size; ++i) {
    754           const shape_inference::ShapeAndType& s1 = (*handle_data_1)[i];
    755           const shape_inference::ShapeAndType& s2 = (*handle_data_2)[i];
    756           if (s1.dtype != s2.dtype) {
    757             // TODO(apassos) resolve this in the manner of b/32476923
    758             return errors::InvalidArgument(
    759                 "Trying to merge handles pointing to different dtypes.");
    760           }
    761           merged_handle_data[i].dtype = s1.dtype;
    762           TF_RETURN_IF_ERROR(
    763               c->Merge(s1.shape, s2.shape, &merged_handle_data[i].shape));
    764         }
    765 
    766         c->set_output_handle_shapes_and_types(0, merged_handle_data);
    767       }
    768 
    769       // The inputs 'then' and 'else' must have the same shape.
    770       ShapeHandle data = c->input(1);
    771       ShapeHandle other = c->input(2);
    772       TF_RETURN_IF_ERROR(c->Merge(data, other, &data));
    773 
    774       // The input 'cond' must either have the same shape as 'then' and
    775       // 'else', or be a vector if 'then' and 'else' are at least vectors.
    776       ShapeHandle cond = c->input(0);
    777 
    778       if (!c->RankKnown(cond) || !c->RankKnown(data)) {
    779         c->set_output(0, data);
    780         return Status::OK();
    781       }
    782 
    783       // rank of shape and data is known.
    784 
    785       const int32 cond_rank = c->Rank(cond);
    786       const int32 data_rank = c->Rank(data);
    787 
    788       if (cond_rank == 0) {
    789         // The rank of 'cond' is a scalar.
    790         // t and e can have any shape.
    791         c->set_output(0, data);
    792         return Status::OK();
    793       }
    794 
    795       if (cond_rank != 1) {
    796         // If 'cond' is not a vector, and not a scalar,
    797         // then shape must match 'then' and 'else'
    798         TF_RETURN_IF_ERROR(c->Merge(data, cond, &data));
    799         c->set_output(0, data);
    800         return Status::OK();
    801       }
    802 
    803       if (data_rank == 0) {
    804         // if 'then' and 'else' are scalar also the cond must be
    805         TF_RETURN_IF_ERROR(c->Merge(data, cond, &data));
    806         c->set_output(0, data);
    807         return Status::OK();
    808       }
    809 
    810       if (cond_rank == 1) {
    811         // if the cond is a vector and the 'then' is not a scalar,
    812         // the first dimension of 'then' and 'else'
    813         TF_RETURN_IF_ERROR(c->Merge(cond, c->Vector(c->Dim(data, 0)), &cond));
    814         c->set_output(0, data);
    815         return Status::OK();
    816       }
    817 
    818       c->set_output(0, data);
    819 
    820       return Status::OK();
    821     });
    822 
    823 // --------------------------------------------------------------------------
    824 
// Matrix product of `a` and `b`. The transpose_a / transpose_b attrs are part
// of the op's interface; output shape inference is delegated to
// shape_inference::MatMulShape.
REGISTER_OP("MatMul")
    .Input("a: T")
    .Input("b: T")
    .Output("product: T")
    .Attr("transpose_a: bool = false")
    .Attr("transpose_b: bool = false")
    .Attr(
        "T: {bfloat16, half, float, double, int32, int64, complex64, "
        "complex128}")
    .SetShapeFn(shape_inference::MatMulShape);

// Matrix product whose two inputs may independently be float or bfloat16
// (Ta, Tb, both defaulting to float) but whose output is always float.
// Shape inference reuses the same shape_inference::MatMulShape as MatMul.
REGISTER_OP("SparseMatMul")
    .Input("a: Ta")
    .Input("b: Tb")
    .Output("product: float")
    .Attr("transpose_a: bool = false")
    .Attr("transpose_b: bool = false")
    .Attr("a_is_sparse: bool = false")
    .Attr("b_is_sparse: bool = false")
    .Attr("Ta: {float, bfloat16} = DT_FLOAT")
    .Attr("Tb: {float, bfloat16} = DT_FLOAT")
    .SetShapeFn(shape_inference::MatMulShape);
    847 
    848 // --------------------------------------------------------------------------
    849 
// For operations where the output is a reduction function along some
// dimensions of the input.
//
// Every op in this group has the same signature: a numeric `input` (T), a
// `reduction_indices` input of type Tidx (int32 or int64, default int32) and
// a `keep_dims` attr (default false). All of them share
// shape_inference::ReductionShape for output shape inference; presumably it
// interprets reduction_indices as axes and honours keep_dims -- confirm in
// common_shape_fns.
REGISTER_OP("Sum")
    .Input("input: T")
    .Input("reduction_indices: Tidx")
    .Output("output: T")
    .Attr("keep_dims: bool = false")
    .Attr("T: numbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::ReductionShape);

REGISTER_OP("EuclideanNorm")
    .Input("input: T")
    .Input("reduction_indices: Tidx")
    .Output("output: T")
    .Attr("keep_dims: bool = false")
    .Attr("T: numbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::ReductionShape);

REGISTER_OP("Mean")
    .Input("input: T")
    .Input("reduction_indices: Tidx")
    .Output("output: T")
    .Attr("keep_dims: bool = false")
    .Attr("T: numbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::ReductionShape);

REGISTER_OP("Prod")
    .Input("input: T")
    .Input("reduction_indices: Tidx")
    .Output("output: T")
    .Attr("keep_dims: bool = false")
    .Attr("T: numbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::ReductionShape);

REGISTER_OP("Min")
    .Input("input: T")
    .Input("reduction_indices: Tidx")
    .Output("output: T")
    .Attr("keep_dims: bool = false")
    .Attr("T: numbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::ReductionShape);

REGISTER_OP("Max")
    .Input("input: T")
    .Input("reduction_indices: Tidx")
    .Output("output: T")
    .Attr("keep_dims: bool = false")
    .Attr("T: numbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::ReductionShape);
    905 
    906 namespace {
    907 
    908 Status ArgOpShape(shape_inference::InferenceContext* c) {
    909   ShapeHandle dimension_shape;
    910   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &dimension_shape));
    911 
    912   ShapeHandle input_shape = c->input(0);
    913   if (!c->RankKnown(input_shape)) {
    914     return shape_inference::UnknownShape(c);
    915   }
    916 
    917   const int32 input_rank = c->Rank(input_shape);
    918   if (input_rank <= 1) {
    919     // Reducing a scalar/vector must return a scalar.
    920     return shape_inference::ScalarShape(c);
    921   }
    922 
    923   const Tensor* dim_t = c->input_tensor(1);
    924   if (dim_t == nullptr) {
    925     // We don't know the value of the dimension, but we
    926     // know the rank of the input, so return the correct
    927     // rank with unknown dimensions.
    928     std::vector<DimensionHandle> dims(input_rank - 1);
    929     for (int i = 0; i < dims.size(); ++i) {
    930       dims[i] = c->UnknownDim();
    931     }
    932 
    933     c->set_output(0, c->MakeShape(dims));
    934     return Status::OK();
    935   }
    936 
    937   int64 dimension_val;
    938   if (dim_t->dtype() == DT_INT32) {
    939     dimension_val = dim_t->scalar<int32>()();
    940   } else {
    941     dimension_val = dim_t->scalar<int64>()();
    942   }
    943 
    944   int64 axis = dimension_val < 0 ? dimension_val + input_rank : dimension_val;
    945   if (axis < 0 || axis >= input_rank) {
    946     return errors::InvalidArgument(
    947         "Dimension (", dimension_val, ") must be in the range [", -input_rank,
    948         ", ", input_rank, "), where ", input_rank,
    949         " is the number of dimensions in the input.");
    950   }
    951 
    952   // Return the input shape without the dimension being reduced.
    953   std::vector<DimensionHandle> dims;
    954   for (int i = 0; i < input_rank; ++i) {
    955     if (axis != i) {
    956       dims.emplace_back(c->Dim(input_shape, i));
    957     }
    958   }
    959   c->set_output(0, c->MakeShape(dims));
    960   return Status::OK();
    961 }
    962 
    963 }  // namespace
    964 
// ArgMax / ArgMin share ArgOpShape: the output shape is `input`'s shape with
// the `dimension` axis removed. `dimension` has type Tidx (int32 or int64,
// default int32); `output_type` selects the output dtype (int32 or int64,
// default int64).
REGISTER_OP("ArgMax")
    .Input("input: T")
    .Input("dimension: Tidx")
    .Output("output: output_type")
    .Attr("T: numbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .Attr("output_type: {int32, int64} = DT_INT64")
    .SetShapeFn(ArgOpShape);

REGISTER_OP("ArgMin")
    .Input("input: T")
    .Input("dimension: Tidx")
    .Output("output: output_type")
    .Attr("T: numbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .Attr("output_type: {int32, int64} = DT_INT64")
    .SetShapeFn(ArgOpShape);
    982 
    983 namespace {
    984 
    985 Status SegmentReductionShapeFn(InferenceContext* c) {
    986   ShapeHandle data_shape;
    987   ShapeHandle segment_ids_shape;
    988   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape));
    989   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &segment_ids_shape));
    990 
    991   ShapeHandle subshape;
    992   TF_RETURN_IF_ERROR(c->Subshape(data_shape, 1, &subshape));
    993 
    994   ShapeHandle out;
    995   TF_RETURN_IF_ERROR(
    996       c->Concatenate(c->Vector(InferenceContext::kUnknownDim), subshape, &out));
    997   c->set_output(0, out);
    998   return Status::OK();
    999 }
   1000 
   1001 Status SparseSegmentReductionShapeFn(InferenceContext* c) {
   1002   ShapeHandle data_shape;
   1003   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape));
   1004 
   1005   ShapeHandle indices_shape;
   1006   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices_shape));
   1007 
   1008   ShapeHandle segment_ids_shape;
   1009   TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &segment_ids_shape));
   1010 
   1011   // indices and segment_ids should merge cleanly.
   1012   ShapeHandle unused;
   1013   TF_RETURN_IF_ERROR(c->Merge(indices_shape, segment_ids_shape, &unused));
   1014 
   1015   ShapeHandle subshape;
   1016   TF_RETURN_IF_ERROR(c->Subshape(data_shape, 1, &subshape));
   1017 
   1018   ShapeHandle out;
   1019   TF_RETURN_IF_ERROR(
   1020       c->Concatenate(c->Vector(InferenceContext::kUnknownDim), subshape, &out));
   1021   c->set_output(0, out);
   1022   return Status::OK();
   1023 }
   1024 
   1025 Status SparseSegmentReductionGradShapeFn(InferenceContext* c) {
   1026   ShapeHandle data_shape;
   1027   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape));
   1028 
   1029   ShapeHandle indices_shape;
   1030   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices_shape));
   1031 
   1032   // indices and segment_ids should merge cleanly.
   1033   ShapeHandle unused;
   1034   TF_RETURN_IF_ERROR(c->Merge(c->input(2), indices_shape, &unused));
   1035 
   1036   // output_dim0 should be a scalar
   1037   TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
   1038 
   1039   ShapeHandle subshape;
   1040   TF_RETURN_IF_ERROR(c->Subshape(data_shape, 1, &subshape));
   1041 
   1042   const Tensor* dim0 = c->input_tensor(3);
   1043   ShapeHandle dim0_shape;
   1044   if (dim0 == nullptr) {
   1045     // We don't have the value at inference time, so the output
   1046     // shape is unknown.
   1047     dim0_shape = c->Vector(InferenceContext::kUnknownDim);
   1048   } else {
   1049     auto dim0_value = dim0->scalar<int32>()();
   1050     if (dim0_value < 0) {
   1051       return errors::InvalidArgument(
   1052           "Cannot specify a negative value for output_dim0");
   1053     }
   1054     dim0_shape = c->Vector(dim0_value);
   1055   }
   1056 
   1057   ShapeHandle out;
   1058   TF_RETURN_IF_ERROR(c->Concatenate(dim0_shape, subshape, &out));
   1059   c->set_output(0, out);
   1060   return Status::OK();
   1061 }
   1062 
   1063 Status SparseSegmentReductionWithNumSegmentsShapeFn(InferenceContext* c) {
   1064   ShapeHandle data_shape;
   1065   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape));
   1066 
   1067   ShapeHandle indices_shape;
   1068   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices_shape));
   1069 
   1070   ShapeHandle segment_ids_shape;
   1071   TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &segment_ids_shape));
   1072 
   1073   ShapeHandle num_segments_shape;
   1074   TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &num_segments_shape));
   1075 
   1076   // indices and segment_ids should merge cleanly.
   1077   ShapeHandle unused;
   1078   TF_RETURN_IF_ERROR(c->Merge(indices_shape, segment_ids_shape, &unused));
   1079 
   1080   ShapeHandle subshape;
   1081   TF_RETURN_IF_ERROR(c->Subshape(data_shape, 1, &subshape));
   1082 
   1083   ShapeHandle out;
   1084   const Tensor* dim0 = c->input_tensor(3);
   1085   if (dim0 == nullptr) {
   1086     // We don't have the value at inference time, so the output
   1087     // shape is unknown.
   1088     TF_RETURN_IF_ERROR(c->Concatenate(c->Vector(InferenceContext::kUnknownDim),
   1089                                       subshape, &out));
   1090   } else {
   1091     auto dim0_value = dim0->scalar<int32>()();
   1092     if (dim0_value < 0) {
   1093       return errors::InvalidArgument(
   1094           "Cannot specify a negative value for num_segments");
   1095     }
   1096     TF_RETURN_IF_ERROR(c->Concatenate(c->Vector(dim0_value), subshape, &out));
   1097   }
   1098   c->set_output(0, out);
   1099   return Status::OK();
   1100 }
   1101 
   1102 Status UnsortedSegmentReductionShapeFn(InferenceContext* c) {
   1103   ShapeHandle s_data = c->input(0);
   1104   ShapeHandle s_segment_ids = c->input(1);
   1105   ShapeHandle s_num_segments = c->input(2);
   1106   TF_RETURN_IF_ERROR(c->WithRank(s_num_segments, 0, &s_num_segments));
   1107 
   1108   ShapeHandle out;
   1109 
   1110   // Leading dimensions of data must be compatible with dimensions of
   1111   // <s_segment_ids>.
   1112   if (c->RankKnown(s_segment_ids)) {
   1113     TF_RETURN_IF_ERROR(
   1114         c->MergePrefix(s_data, s_segment_ids, &s_data, &s_segment_ids));
   1115 
   1116     // Get the value of the num_segments input tensor.
   1117     DimensionHandle num_segments_dim;
   1118     TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &num_segments_dim));
   1119 
   1120     // Output is {segment_id_rank} + s_data[segment_id_rank:].
   1121     ShapeHandle s_data_suffix;
   1122     TF_RETURN_IF_ERROR(
   1123         c->Subshape(s_data, c->Rank(s_segment_ids), &s_data_suffix));
   1124     TF_RETURN_IF_ERROR(
   1125         c->Concatenate(c->Vector(num_segments_dim), s_data_suffix, &out));
   1126   } else {
   1127     out = c->UnknownShape();
   1128   }
   1129   c->set_output(0, out);
   1130   return Status::OK();
   1131 }
   1132 }  // namespace
   1133 
// Segment reductions over `data` keyed by `segment_ids` (int32 or int64).
// All share SegmentReductionShapeFn: `data` must have rank >= 1,
// `segment_ids` must be a vector, and the output is [?] + data.shape[1:].
// SegmentMin / SegmentMax restrict T to realnumbertype; the others accept any
// numbertype.
REGISTER_OP("SegmentSum")
    .Input("data: T")
    .Input("segment_ids: Tindices")
    .Output("output: T")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32,int64}")
    .SetShapeFn(SegmentReductionShapeFn);

REGISTER_OP("SegmentMean")
    .Input("data: T")
    .Input("segment_ids: Tindices")
    .Output("output: T")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32,int64}")
    .SetShapeFn(SegmentReductionShapeFn);

REGISTER_OP("SegmentProd")
    .Input("data: T")
    .Input("segment_ids: Tindices")
    .Output("output: T")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32,int64}")
    .SetShapeFn(SegmentReductionShapeFn);

REGISTER_OP("SegmentMin")
    .Input("data: T")
    .Input("segment_ids: Tindices")
    .Output("output: T")
    .Attr("T: realnumbertype")
    .Attr("Tindices: {int32,int64}")
    .SetShapeFn(SegmentReductionShapeFn);

REGISTER_OP("SegmentMax")
    .Input("data: T")
    .Input("segment_ids: Tindices")
    .Output("output: T")
    .Attr("T: realnumbertype")
    .Attr("Tindices: {int32,int64}")
    .SetShapeFn(SegmentReductionShapeFn);
   1173 
// Unsorted segment reductions. All share UnsortedSegmentReductionShapeFn:
// `num_segments` (int32 or int64, default int32) must be a scalar; when the
// rank of `segment_ids` is known the output is
// [num_segments] + data.shape[rank(segment_ids):], otherwise it is unknown.
// Max / Min restrict T to realnumbertype; Sum / Prod accept any numbertype.
REGISTER_OP("UnsortedSegmentSum")
    .Input("data: T")
    .Input("segment_ids: Tindices")
    .Input("num_segments: Tnumsegments")
    .Output("output: T")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32,int64}")
    .Attr("Tnumsegments: {int32,int64} = DT_INT32")
    .SetShapeFn(UnsortedSegmentReductionShapeFn);

REGISTER_OP("UnsortedSegmentMax")
    .Input("data: T")
    .Input("segment_ids: Tindices")
    .Input("num_segments: Tnumsegments")
    .Output("output: T")
    .Attr("T: realnumbertype")
    .Attr("Tindices: {int32,int64}")
    .Attr("Tnumsegments: {int32,int64} = DT_INT32")
    .SetShapeFn(UnsortedSegmentReductionShapeFn);

REGISTER_OP("UnsortedSegmentMin")
    .Input("data: T")
    .Input("segment_ids: Tindices")
    .Input("num_segments: Tnumsegments")
    .Output("output: T")
    .Attr("T: realnumbertype")
    .Attr("Tindices: {int32,int64}")
    .Attr("Tnumsegments: {int32,int64} = DT_INT32")
    .SetShapeFn(UnsortedSegmentReductionShapeFn);

REGISTER_OP("UnsortedSegmentProd")
    .Input("data: T")
    .Input("segment_ids: Tindices")
    .Input("num_segments: Tnumsegments")
    .Output("output: T")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32,int64}")
    .Attr("Tnumsegments: {int32,int64} = DT_INT32")
    .SetShapeFn(UnsortedSegmentReductionShapeFn);
   1213 
// Sparse segment reductions. Three signatures appear below:
//   * SparseSegment{Sum,Mean,SqrtN}: data, indices (Tidx), segment_ids
//     (int32); shape via SparseSegmentReductionShapeFn, output
//     [?] + data.shape[1:].
//   * ...WithNumSegments: adds a scalar `num_segments` (int32 or int64);
//     shape via SparseSegmentReductionWithNumSegmentsShapeFn, output
//     [num_segments] + data.shape[1:].
//   * ...Grad: takes `grad` plus an int32 scalar `output_dim0`; shape via
//     SparseSegmentReductionGradShapeFn, output
//     [output_dim0] + grad.shape[1:].
// In every variant segment_ids is fixed to int32 while indices uses Tidx.
REGISTER_OP("SparseSegmentSum")
    .Input("data: T")
    .Input("indices: Tidx")
    .Input("segment_ids: int32")
    .Output("output: T")
    .Attr("T: realnumbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(SparseSegmentReductionShapeFn);

REGISTER_OP("SparseSegmentSumWithNumSegments")
    .Input("data: T")
    .Input("indices: Tidx")
    .Input("segment_ids: int32")
    .Input("num_segments: Tnumsegments")
    .Output("output: T")
    .Attr("T: realnumbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .Attr("Tnumsegments: {int32,int64} = DT_INT32")
    .SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn);

REGISTER_OP("SparseSegmentMean")
    .Input("data: T")
    .Input("indices: Tidx")
    .Input("segment_ids: int32")
    .Output("output: T")
    .Attr("T: {float, double}")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(SparseSegmentReductionShapeFn);

REGISTER_OP("SparseSegmentMeanWithNumSegments")
    .Input("data: T")
    .Input("indices: Tidx")
    .Input("segment_ids: int32")
    .Input("num_segments: Tnumsegments")
    .Output("output: T")
    .Attr("T: {float, double}")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .Attr("Tnumsegments: {int32,int64} = DT_INT32")
    .SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn);

REGISTER_OP("SparseSegmentMeanGrad")
    .Input("grad: T")
    .Input("indices: Tidx")
    .Input("segment_ids: int32")
    .Input("output_dim0: int32")
    .Output("output: T")
    .Attr("T: {float, double}")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(SparseSegmentReductionGradShapeFn);

REGISTER_OP("SparseSegmentSqrtN")
    .Input("data: T")
    .Input("indices: Tidx")
    .Input("segment_ids: int32")
    .Output("output: T")
    .Attr("T: {float, double}")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(SparseSegmentReductionShapeFn);

REGISTER_OP("SparseSegmentSqrtNWithNumSegments")
    .Input("data: T")
    .Input("indices: Tidx")
    .Input("segment_ids: int32")
    .Input("num_segments: Tnumsegments")
    .Output("output: T")
    .Attr("T: {float, double}")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .Attr("Tnumsegments: {int32,int64} = DT_INT32")
    .SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn);

REGISTER_OP("SparseSegmentSqrtNGrad")
    .Input("grad: T")
    .Input("indices: Tidx")
    .Input("segment_ids: int32")
    .Input("output_dim0: int32")
    .Output("output: T")
    .Attr("T: {float, double}")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(SparseSegmentReductionGradShapeFn);
   1293 
// Boolean reductions. Same input/attr layout as the numeric reductions
// (reduction_indices of type Tidx, keep_dims attr) and the same
// shape_inference::ReductionShape, but input and output are fixed to bool.
REGISTER_OP("All")
    .Input("input: bool")
    .Input("reduction_indices: Tidx")
    .Output("output: bool")
    .Attr("keep_dims: bool = false")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::ReductionShape);

// NOTE(review): keep_dims is declared before the output here (unlike "All");
// the attr order (keep_dims, Tidx) is the same in both registrations.
REGISTER_OP("Any")
    .Input("input: bool")
    .Input("reduction_indices: Tidx")
    .Attr("keep_dims: bool = false")
    .Output("output: bool")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::ReductionShape);
   1309 
   1310 // --------------------------------------------------------------------------
   1311 
   1312 namespace {
   1313 
   1314 template <typename T>
   1315 Status RangeSize(const Tensor* start_t, const Tensor* limit_t,
   1316                  const Tensor* delta_t, InferenceContext* const c) {
   1317   T start = start_t->scalar<T>()();
   1318   T limit = limit_t->scalar<T>()();
   1319   T delta = delta_t->scalar<T>()();
   1320   if (start > limit && delta > 0) {
   1321     return errors::InvalidArgument(
   1322         "Requires start <= limit when delta > 0: ", start, "/", limit);
   1323   }
   1324   if (start < limit && delta < 0) {
   1325     return errors::InvalidArgument(
   1326         "Requires start >= limit when delta < 0: ", start, "/", limit);
   1327   }
   1328   if (delta == 0) {
   1329     return errors::InvalidArgument("Requires delta != 0");
   1330   }
   1331 
   1332   int64 size =
   1333       (std::is_integral<T>::value
   1334            ? ((std::abs(limit - start) + std::abs(delta) - 1) / std::abs(delta))
   1335            : std::ceil(std::abs((limit - start) / delta)));
   1336   c->set_output(0, c->Vector(size));
   1337   return Status::OK();
   1338 }
   1339 
   1340 }  // namespace
   1341 
   1342 REGISTER_OP("Range")
   1343     .Input("start: Tidx")
   1344     .Input("limit: Tidx")
   1345     .Input("delta: Tidx")
   1346     .Output("output: Tidx")
   1347     .Attr("Tidx: {bfloat16, float, double, int32, int64} = DT_INT32")
   1348     .SetShapeFn([](InferenceContext* c) {
   1349       ShapeHandle unused;
   1350       TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(0), 0, &unused),
   1351                                       " for 'start'");
   1352       TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(1), 0, &unused),
   1353                                       " for 'limit'");
   1354       TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(2), 0, &unused),
   1355                                       " for 'delta'");
   1356       const Tensor* start_t = c->input_tensor(0);
   1357       const Tensor* limit_t = c->input_tensor(1);
   1358       const Tensor* delta_t = c->input_tensor(2);
   1359       DataType dtype;
   1360       TF_RETURN_IF_ERROR(c->GetAttr("Tidx", &dtype));
   1361       if (start_t == nullptr || limit_t == nullptr || delta_t == nullptr) {
   1362         c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
   1363         return Status::OK();
   1364       }
   1365       if (dtype == DT_INT32) {
   1366         return RangeSize<int32>(start_t, limit_t, delta_t, c);
   1367       } else if (dtype == DT_INT64) {
   1368         return RangeSize<int64>(start_t, limit_t, delta_t, c);
   1369       } else if (dtype == DT_FLOAT) {
   1370         return RangeSize<float>(start_t, limit_t, delta_t, c);
   1371       } else {
   1372         return RangeSize<double>(start_t, limit_t, delta_t, c);
   1373       }
   1374       return Status::OK();
   1375     });
   1376 
   1377 REGISTER_OP("LinSpace")
   1378     .Input("start: T")
   1379     .Input("stop: T")
   1380     .Input("num: Tidx")
   1381     .Output("output: T")
   1382     .Attr("T: {bfloat16, float, double}")
   1383     .Attr("Tidx: {int32, int64} = DT_INT32")
   1384     .SetShapeFn([](InferenceContext* c) {
   1385       ShapeHandle unused;
   1386       TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(0), 0, &unused),
   1387                                       " for 'start'");
   1388       TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(1), 0, &unused),
   1389                                       " for 'stop'");
   1390       TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(2), 0, &unused),
   1391                                       " for 'num'");
   1392       const Tensor* num_t = c->input_tensor(2);
   1393       if (num_t == nullptr) {
   1394         c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
   1395         return Status::OK();
   1396       }
   1397 
   1398       int64 num;
   1399       if (num_t->dtype() == DT_INT32) {
   1400         num = num_t->scalar<int32>()();
   1401       } else {
   1402         num = num_t->scalar<int64>()();
   1403       }
   1404       if (num <= 0) return errors::InvalidArgument("Requires num > 0: ", num);
   1405       c->set_output(0, c->Vector(num));
   1406       return Status::OK();
   1407     });
   1408 
// Conversions between real and complex dtypes.
// Complex takes two real inputs of type T and produces Tout; its output shape
// follows shape_inference::BroadcastBinaryOpShapeFn.
REGISTER_OP("Complex")
    .Input("real: T")
    .Input("imag: T")
    .Output("out: Tout")
    .Attr("T: {float, double} = DT_FLOAT")
    .Attr("Tout: {complex64, complex128} = DT_COMPLEX64")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

// Real, Imag and Angle map a complex T input to a real Tout output with the
// same shape as the input (shape_inference::UnchangedShape).
REGISTER_OP("Real")
    .Input("input: T")
    .Output("output: Tout")
    .Attr("T: {complex64, complex128} = DT_COMPLEX64")
    .Attr("Tout: {float, double} = DT_FLOAT")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("Imag")
    .Input("input: T")
    .Output("output: Tout")
    .Attr("T: {complex64, complex128} = DT_COMPLEX64")
    .Attr("Tout: {float, double} = DT_FLOAT")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("Angle")
    .Input("input: T")
    .Output("output: Tout")
    .Attr("T: {complex64, complex128} = DT_COMPLEX64")
    .Attr("Tout: {float, double} = DT_FLOAT")
    .SetShapeFn(shape_inference::UnchangedShape);
   1437 
   1438 REGISTER_OP("Conj")
   1439     .Input("input: T")
   1440     .Output("output: T")
   1441     .Attr("T: {complex64, complex128, variant} = DT_COMPLEX64")
   1442     .SetShapeFn([](InferenceContext* c) {
   1443       c->set_output(0, c->input(0));
   1444       auto* handle_data = c->input_handle_shapes_and_types(0);
   1445       if (handle_data != nullptr) {
   1446         c->set_output_handle_shapes_and_types(0, *handle_data);
   1447       }
   1448       return Status::OK();
   1449     });
   1450 
   1451 // --------------------------------------------------------------------------
   1452 
   1453 REGISTER_OP("Cross")
   1454     .Input("a: T")
   1455     .Input("b: T")
   1456     .Output("product: T")
   1457     .Attr("T: realnumbertype")
   1458     .SetShapeFn([](InferenceContext* c) {
   1459       ShapeHandle a_shape;
   1460       ShapeHandle b_shape;
   1461       // * Input rank >= 1.
   1462       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &a_shape));
   1463       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &b_shape));
   1464 
   1465       // * Both inputs have the same shape.
   1466       TF_RETURN_IF_ERROR(c->Merge(a_shape, b_shape, &a_shape));
   1467 
   1468       // * input_shape[-1] == 3.
   1469       if (c->RankKnown(a_shape)) {
   1470         int rank = c->Rank(a_shape);
   1471         auto dim = c->Dim(a_shape, rank - 1);
   1472         TF_RETURN_IF_ERROR(c->WithValue(dim, 3, &dim));
   1473       }
   1474       c->set_output(0, a_shape);
   1475       return Status::OK();
   1476     });
   1477 
   1478 // --------------------------------------------------------------------------
   1479 
   1480 REGISTER_OP("HistogramFixedWidth")
   1481     .Input("values: T")
   1482     .Input("value_range: T")
   1483     .Input("nbins: int32")
   1484     .Output("out: dtype")
   1485     .Attr("T: {int32, int64, float32, float64}")
   1486     .Attr("dtype: {int32, int64} = DT_INT32")
   1487     .SetShapeFn([](InferenceContext* c) {
   1488       // value_range should be a vector.
   1489       ShapeHandle value_range_shape;
   1490       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &value_range_shape));
   1491       // value_range should have two elements.
   1492       DimensionHandle unused;
   1493       TF_RETURN_IF_ERROR(
   1494           c->WithValue(c->Dim(value_range_shape, 0), 2, &unused));
   1495       // nbins should be a scalar.
   1496       ShapeHandle nbins_shape;
   1497       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &nbins_shape));
   1498 
   1499       // If nbins is available, set the shape from nbins.
   1500       const Tensor* nbins_input = c->input_tensor(2);
   1501       if (nbins_input != nullptr) {
   1502         int64 nbins;
   1503         TF_RETURN_IF_ERROR(c->GetScalarFromTensor(nbins_input, &nbins));
   1504         // nbins has to be positive.
   1505         if (nbins <= 0) {
   1506           return errors::InvalidArgument("Requires nbins > 0: ", nbins);
   1507         }
   1508         c->set_output(0, c->Vector(nbins));
   1509       } else {
   1510         c->set_output(0, c->UnknownShapeOfRank(1));
   1511       }
   1512       return Status::OK();
   1513     });
   1514 
   1515 REGISTER_OP("Bincount")
   1516     .Input("arr: int32")
   1517     .Input("size: int32")
   1518     .Input("weights: T")
   1519     .Attr("T: {int32, int64, float32, float64}")
   1520     .Output("bins: T")
   1521     .SetShapeFn([](InferenceContext* c) {
   1522       ShapeHandle unused;
   1523       // The input `size` must be a scalar.
   1524       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
   1525 
   1526       const Tensor* size_tensor = c->input_tensor(1);
   1527       if (size_tensor == nullptr) {
   1528         // Return unknown shape if size is not known.
   1529         c->set_output(0, c->UnknownShapeOfRank(1));
   1530         return Status::OK();
   1531       }
   1532 
   1533       // Return `[size]` shape if size is known.
   1534       int32 size_val = size_tensor->scalar<int32>()();
   1535       if (size_val < 0) {
   1536         return errors::InvalidArgument("size (", size_val,
   1537                                        ") must be non-negative");
   1538       }
   1539       c->set_output(0, c->MakeShape({size_val}));
   1540       return Status::OK();
   1541     });
   1542 
// Cumsum / Cumprod: the output has the same shape as `x`
// (shape_inference::UnchangedShape); the `axis` input (Tidx, int32 or int64)
// and the exclusive / reverse attrs do not affect the shape function.
REGISTER_OP("Cumsum")
    .Input("x: T")
    .Input("axis: Tidx")
    .Attr("exclusive: bool = false")
    .Attr("reverse: bool = false")
    .Output("out: T")
    .Attr("T: numbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("Cumprod")
    .Input("x: T")
    .Input("axis: Tidx")
    .Attr("exclusive: bool = false")
    .Attr("reverse: bool = false")
    .Output("out: T")
    .Attr("T: numbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::UnchangedShape);
   1562 
   1563 REGISTER_OP("QuantizedMatMul")
   1564     .Input("a: T1")
   1565     .Input("b: T2")
   1566     .Input("min_a: float")
   1567     .Input("max_a: float")
   1568     .Input("min_b: float")
   1569     .Input("max_b: float")
   1570     .Output("out: Toutput")
   1571     .Output("min_out: float")
   1572     .Output("max_out: float")
   1573     .Attr("T1: quantizedtype")
   1574     .Attr("T2: quantizedtype")
   1575     .Attr("Toutput: quantizedtype = DT_QINT32")
   1576     .Attr("transpose_a: bool = false")
   1577     .Attr("transpose_b: bool = false")
   1578     .Attr("Tactivation: quantizedtype = DT_QUINT8")
   1579     .SetShapeFn([](InferenceContext* c) {
   1580       TF_RETURN_IF_ERROR(shape_inference::MatMulShape(c));
   1581       ShapeHandle unused;
   1582       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
   1583       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
   1584       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
   1585       TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
   1586 
   1587       c->set_output(1, c->Scalar());
   1588       c->set_output(2, c->Scalar());
   1589       return Status::OK();
   1590     });
   1591 
   1592 REGISTER_OP("QuantizedMul")
   1593     .Input("x: T1")
   1594     .Input("y: T2")
   1595     .Input("min_x: float")
   1596     .Input("max_x: float")
   1597     .Input("min_y: float")
   1598     .Input("max_y: float")
   1599     .Output("z: Toutput")
   1600     .Output("min_z: float")
   1601     .Output("max_z: float")
   1602     .Attr("T1: quantizedtype")
   1603     .Attr("T2: quantizedtype")
   1604     .Attr("Toutput: quantizedtype = DT_QINT32")
   1605     .SetIsCommutative()
   1606     .SetShapeFn([](InferenceContext* c) {
   1607       TF_RETURN_IF_ERROR(shape_inference::BroadcastBinaryOpShapeFn(c));
   1608       c->set_output(1, c->Scalar());
   1609       c->set_output(2, c->Scalar());
   1610       return Status::OK();
   1611     });
   1612 
   1613 REGISTER_OP("QuantizedAdd")
   1614     .Input("x: T1")
   1615     .Input("y: T2")
   1616     .Input("min_x: float")
   1617     .Input("max_x: float")
   1618     .Input("min_y: float")
   1619     .Input("max_y: float")
   1620     .Output("z: Toutput")
   1621     .Output("min_z: float")
   1622     .Output("max_z: float")
   1623     .Attr("T1: quantizedtype")
   1624     .Attr("T2: quantizedtype")
   1625     .Attr("Toutput: quantizedtype = DT_QINT32")
   1626     .SetIsCommutative()
   1627     .SetShapeFn([](InferenceContext* c) {
   1628       TF_RETURN_IF_ERROR(shape_inference::BroadcastBinaryOpShapeFn(c));
   1629       // min_x, max_x, min_y, max_y should be scalar.
   1630       ShapeHandle unused;
   1631       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
   1632       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
   1633       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
   1634       TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
   1635 
   1636       c->set_output(1, c->Scalar());
   1637       c->set_output(2, c->Scalar());
   1638       return Status::OK();
   1639     });
   1640 
   1641 REGISTER_OP("QuantizeDownAndShrinkRange")
   1642     .Input("input: Tinput")
   1643     .Input("input_min: float")
   1644     .Input("input_max: float")
   1645     .Output("output: out_type")
   1646     .Output("output_min: float")
   1647     .Output("output_max: float")
   1648     .Attr("Tinput: quantizedtype")
   1649     .Attr("out_type: quantizedtype")
   1650     .SetShapeFn([](InferenceContext* c) {
   1651       TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
   1652       ShapeHandle unused;
   1653       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
   1654       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
   1655       c->set_output(1, c->Scalar());
   1656       c->set_output(2, c->Scalar());
   1657       return Status::OK();
   1658     });
   1659 
   1660 REGISTER_OP("Requantize")
   1661     .Input("input: Tinput")
   1662     .Input("input_min: float")
   1663     .Input("input_max: float")
   1664     .Input("requested_output_min: float")
   1665     .Input("requested_output_max: float")
   1666     .Output("output: out_type")
   1667     .Output("output_min: float")
   1668     .Output("output_max: float")
   1669     .Attr("Tinput: quantizedtype")
   1670     .Attr("out_type: quantizedtype")
   1671     .SetShapeFn([](InferenceContext* c) {
   1672       TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
   1673       ShapeHandle unused;
   1674       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
   1675       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
   1676       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
   1677       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
   1678       c->set_output(1, c->Scalar());
   1679       c->set_output(2, c->Scalar());
   1680       return Status::OK();
   1681     });
   1682 
   1683 REGISTER_OP("CompareAndBitpack")
   1684     .Input("input: T")
   1685     .Input("threshold: T")
   1686     .Output("output: uint8")
   1687     .Attr("T: {bool, float16, float32, float64, int8, int16, int32, int64}")
   1688     .SetShapeFn([](InferenceContext* c) {
   1689       ShapeHandle input;
   1690       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input));
   1691       ShapeHandle unused;
   1692       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
   1693       ShapeHandle output = input;
   1694       if (c->RankKnown(input)) {
   1695         int rank = c->Rank(input);
   1696         auto inner_dim = c->Dim(input, rank - 1);
   1697         DimensionHandle inferred_dim;
   1698         TF_RETURN_IF_ERROR(c->Divide(inner_dim, 8,
   1699                                      /* evenly_divisible */ true,
   1700                                      &inferred_dim));
   1701         TF_RETURN_IF_ERROR(
   1702             c->ReplaceDim(output, rank - 1, inferred_dim, &output));
   1703       }
   1704       c->set_output(0, output);
   1705 
   1706       return Status::OK();
   1707     });
   1708 
   1709 REGISTER_OP("RequantizationRange")
   1710     .Input("input: Tinput")
   1711     .Input("input_min: float")
   1712     .Input("input_max: float")
   1713     .Output("output_min: float")
   1714     .Output("output_max: float")
   1715     .Attr("Tinput: quantizedtype")
   1716     .SetShapeFn([](InferenceContext* c) {
   1717       ShapeHandle unused;
   1718       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
   1719       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
   1720       c->set_output(0, c->Scalar());
   1721       c->set_output(1, c->Scalar());
   1722       return Status::OK();
   1723     });
   1724 
   1725 // --------------------------------------------------------------------------
   1726 
   1727 REGISTER_OP("Bucketize")
   1728     .Input("input: T")
   1729     .Output("output: int32")
   1730     .Attr("T: {int32, int64, float, double}")
   1731     .Attr("boundaries: list(float)")
   1732     .SetShapeFn(shape_inference::UnchangedShape);
   1733 
   1734 REGISTER_OP("ClipByValue")
   1735     .Input("t: T")
   1736     .Input("clip_value_min: T")
   1737     .Input("clip_value_max: T")
   1738     .Output("output: T")
   1739     .Attr("T: numbertype")
   1740     .SetShapeFn(shape_inference::UnchangedShape);
   1741 
   1742 #ifdef INTEL_MKL
   1743 REGISTER_OP("_MklAddN")
   1744     .Input("inputs: N * T")
   1745     .Input("mkl_input: N * uint8")
   1746     .Output("sum: T")
   1747     .Output("mkl_sum: uint8")
   1748     .Attr("N: int >= 1")
   1749     .Attr("T: numbertype")
   1750     .SetIsCommutative()
   1751     .SetIsAggregate()
   1752     .SetShapeFn([](InferenceContext* c) {
   1753       ShapeHandle cur = c->input(c->num_inputs() - 1);
   1754       for (int i = c->num_inputs() - 2; i >= 0; --i) {
   1755         TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur),
   1756                                         "From merging shape ", i,
   1757                                         " with other shapes.");
   1758       }
   1759       c->set_output(0, cur);
   1760       return Status::OK();
   1761     })
   1762     .Doc(R"doc(
   1763 Add two input tensors element wise using mkl kernel sum.
   1764 inputs: Must all be the same size and shape.
   1765 )doc");
   1766 
   1767 #endif  // INTEL_MKL
   1768 
   1769 REGISTER_OP("RequantizePerChannel")
   1770     .Input("input: T")
   1771     .Input("input_min: float")
   1772     .Input("input_max: float")
   1773     .Input("requested_output_min: float")
   1774     .Input("requested_output_max: float")
   1775     .Output("output: out_type")
   1776     .Output("output_min: float")
   1777     .Output("output_max: float")
   1778     .Attr("T: quantizedtype = DT_QINT32")
   1779     .Attr("out_type: quantizedtype = DT_QUINT8")
   1780     .SetShapeFn([](InferenceContext* c) {
   1781       TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
   1782       ShapeHandle unused;
   1783       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
   1784       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
   1785       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
   1786       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
   1787       c->set_output(1, c->Scalar());
   1788       c->set_output(2, c->Scalar());
   1789       return Status::OK();
   1790     });
   1791 REGISTER_OP("RequantizationRangePerChannel")
   1792     .Input("input: T")
   1793     .Input("input_min: float")
   1794     .Input("input_max: float")
   1795     .Output("output_min: float")
   1796     .Output("output_max: float")
   1797     .Attr("T: quantizedtype = DT_QINT32")
   1798     .Attr("clip_value_max: float")
   1799     .SetShapeFn([](InferenceContext* c) {
   1800       ShapeHandle unused;
   1801       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
   1802       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
   1803       c->set_output(0, c->Scalar());
   1804       c->set_output(1, c->Scalar());
   1805       return Status::OK();
   1806     });
   1807 
   1808 REGISTER_OP("NextAfter")
   1809     .Attr("T: {float64, float32} = DT_FLOAT")
   1810     .Input("x1: T")
   1811     .Input("x2: T")
   1812     .Output("output: T")
   1813     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
   1814 
   1815 }  // namespace tensorflow
   1816