      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/framework/common_shape_fns.h"
     17 #include "tensorflow/core/framework/numeric_op.h"
     18 #include "tensorflow/core/framework/op.h"
     19 #include "tensorflow/core/framework/shape_inference.h"
     20 #include "tensorflow/core/util/mirror_pad_mode.h"
     21 #include "tensorflow/core/util/padding.h"
     22 #include "tensorflow/core/util/tensor_format.h"
     23 
     24 namespace tensorflow {
     25 
     26 using shape_inference::DimensionHandle;
     27 using shape_inference::InferenceContext;
     28 using shape_inference::ShapeHandle;
     29 
     30 namespace {
     31 
     32 Status FractionalPoolShapeFn(InferenceContext* c) {
     33   ShapeHandle input;
     34   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
     35 
     36   std::vector<float> pooling_ratio;
     37   TF_RETURN_IF_ERROR(c->GetAttr("pooling_ratio", &pooling_ratio));
     38   if (pooling_ratio.size() != 4) {
     39     return errors::InvalidArgument(
     40         "pooling_ratio field must specify 4 dimensions");
     41   }
     42   std::vector<DimensionHandle> output_dims;
     43   for (int i = 0; i < 4; ++i) {
     44     DimensionHandle d = c->Dim(input, i);
     45     if (c->ValueKnown(d)) {
     46       // This must match the same logic in the kernel function in
     47       // core/kernels/fractional_max_pool_op.cc.
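               // Illustrative example (not part of the original source): an
               // input dimension of 10 with a pooling_ratio of 1.44 yields
               // floor(10 / 1.44) = 6 output elements along that dimension.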
     48       auto val = static_cast<int64>(floor(c->Value(d) / pooling_ratio[i]));
     49       if (val < 0) {
     50         return errors::InvalidArgument("Size computed for dim ", i,
     51                                        " is negative: ", val);
     52       }
     53       output_dims.push_back(c->MakeDim(val));
     54     } else {
     55       output_dims.push_back(c->UnknownDim());
     56     }
     57   }
     58 
     59   c->set_output(0, c->MakeShape(output_dims));
     60   c->set_output(1, c->Vector(output_dims[1]));
     61   c->set_output(2, c->Vector(output_dims[2]));
     62   return Status::OK();
     63 }
     64 
     65 }  // namespace
     66 
     67 // --------------------------------------------------------------------------
     68 
     69 REGISTER_OP("AvgPool")
     70     .Input("value: T")
     71     .Output("output: T")
     72     .Attr("ksize: list(int) >= 4")
     73     .Attr("strides: list(int) >= 4")
     74     .Attr(GetPaddingAttrString())
     75     .Attr(GetConvnetDataFormatAttrString())
     76     .Attr("T: {half, bfloat16, float, double}")
     77     .SetShapeFn(shape_inference::AvgPoolShape);
     78 
     79 REGISTER_OP("AvgPoolGrad")
     80     .Input("orig_input_shape: int32")
     81     .Input("grad: T")
     82     .Output("output: T")
     83     .Attr("ksize: list(int) >= 4")
     84     .Attr("strides: list(int) >= 4")
     85     .Attr(GetPaddingAttrString())
     86     .Attr(GetConvnetDataFormatAttrString())
     87     .Attr("T: {half, bfloat16, float, double}")
     88     .SetShapeFn([](InferenceContext* c) {
     89       ShapeHandle s;
     90       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
     91       TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
     92       c->set_output(0, s);
     93       return Status::OK();
     94     });
     95 
     96 // --------------------------------------------------------------------------
     97 
     98 REGISTER_OP("BatchNormWithGlobalNormalization")
     99     .Input("t: T")
    100     .Input("m: T")
    101     .Input("v: T")
    102     .Input("beta: T")
    103     .Input("gamma: T")
    104     .Output("result: T")
    105     .Attr("T: numbertype")
    106     .Attr("variance_epsilon: float")
    107     .Attr("scale_after_normalization: bool")
    108     .Deprecated(9, "Use tf.nn.batch_normalization()")
    109     .SetShapeFn([](InferenceContext* c) {
    110       ShapeHandle input;
    111       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
    112 
    113       DimensionHandle last_dim = c->Dim(input, 3);
    114       for (int i = 1; i < 5; ++i) {  // covers m, v, beta, gamma
    115         ShapeHandle vec;
    116         TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
    117         TF_RETURN_IF_ERROR(c->Merge(last_dim, c->Dim(vec, 0), &last_dim));
    118       }
    119 
    120       ShapeHandle out;
    121       TF_RETURN_IF_ERROR(c->ReplaceDim(input, 3, last_dim, &out));
    122       c->set_output(0, out);
    123       return Status::OK();
    124     });
    125 
    126 REGISTER_OP("BatchNormWithGlobalNormalizationGrad")
    127     .Input("t: T")
    128     .Input("m: T")
    129     .Input("v: T")
    130     .Input("gamma: T")
    131     .Input("backprop: T")
    132     .Output("dx: T")
    133     .Output("dm: T")
    134     .Output("dv: T")
    135     .Output("db: T")
    136     .Output("dg: T")
    137     .Attr("T: numbertype")
    138     .Attr("variance_epsilon: float")
    139     .Attr("scale_after_normalization: bool")
    140     .Deprecated(9, "Use tf.nn.batch_normalization()")
    141     .SetShapeFn([](InferenceContext* c) {
    142       ShapeHandle input;
    143       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
    144       TF_RETURN_IF_ERROR(
    145           c->Merge(input, c->input(4), &input));  // with backprop
    146 
    147       DimensionHandle last_dim = c->Dim(input, 3);
    148       for (int i = 1; i < 4; ++i) {  // covers m, v, gamma
    149         ShapeHandle vec;
    150         TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
    151         TF_RETURN_IF_ERROR(c->Merge(last_dim, c->Dim(vec, 0), &last_dim));
    152       }
    153 
    154       ShapeHandle dx;
    155       TF_RETURN_IF_ERROR(c->ReplaceDim(input, 3, last_dim, &dx));
    156       c->set_output(0, dx);
    157 
    158       ShapeHandle vector_shape = c->Vector(last_dim);
    159       c->set_output(1, vector_shape);
    160       c->set_output(2, vector_shape);
    161       c->set_output(3, vector_shape);
    162       c->set_output(4, vector_shape);
    163       return Status::OK();
    164     });
    165 
    166 // --------------------------------------------------------------------------
    167 
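         // Note (editorial): shape_inference::FusedBatchNormShape gives y the
         // shape of x, while batch_mean, batch_variance, and the reserve_space
         // outputs are 1-D vectors sized to the channel dimension.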
    168 REGISTER_OP("FusedBatchNorm")
    169     .Input("x: T")
    170     .Input("scale: T")
    171     .Input("offset: T")
    172     .Input("mean: T")
    173     .Input("variance: T")
    174     .Output("y: T")
    175     .Output("batch_mean: T")
    176     .Output("batch_variance: T")
    177     .Output("reserve_space_1: T")
    178     .Output("reserve_space_2: T")
    179     .Attr("T: {float}")
    180     .Attr("epsilon: float = 0.0001")
    181     .Attr("data_format: string = 'NHWC'")
    182     .Attr("is_training: bool = true")
    183     .SetShapeFn(shape_inference::FusedBatchNormShape);
    184 
    185 REGISTER_OP("FusedBatchNormV2")
    186     .Input("x: T")
    187     .Input("scale: U")
    188     .Input("offset: U")
    189     .Input("mean: U")
    190     .Input("variance: U")
    191     .Output("y: T")
    192     .Output("batch_mean: U")
    193     .Output("batch_variance: U")
    194     .Output("reserve_space_1: U")
    195     .Output("reserve_space_2: U")
    196     .Attr("T: {half, bfloat16, float}")
    197     .Attr("U: {float}")
    198     .Attr("epsilon: float = 0.0001")
    199     .Attr("data_format: string = 'NHWC'")
    200     .Attr("is_training: bool = true")
    201     .SetShapeFn(shape_inference::FusedBatchNormShape);
    202 
    203 REGISTER_OP("FusedBatchNormGrad")
    204     .Input("y_backprop: T")
    205     .Input("x: T")
    206     .Input("scale: T")
    207     .Input("reserve_space_1: T")
    208     .Input("reserve_space_2: T")
    209     .Output("x_backprop: T")
    210     .Output("scale_backprop: T")
    211     .Output("offset_backprop: T")
    212     .Output("reserve_space_3: T")
    213     .Output("reserve_space_4: T")
    214     .Attr("T: {float}")
    215     .Attr("epsilon: float = 0.0001")
    216     .Attr("data_format: string = 'NHWC'")
    217     .Attr("is_training: bool = true")
    218     .SetShapeFn(shape_inference::FusedBatchNormGradShape);
    219 
    220 REGISTER_OP("FusedBatchNormGradV2")
    221     .Input("y_backprop: T")
    222     .Input("x: T")
    223     .Input("scale: float")
    224     .Input("reserve_space_1: U")
    225     .Input("reserve_space_2: U")
    226     .Output("x_backprop: T")
    227     .Output("scale_backprop: U")
    228     .Output("offset_backprop: U")
    229     .Output("reserve_space_3: U")
    230     .Output("reserve_space_4: U")
    231     .Attr("T: {half, bfloat16, float}")
    232     .Attr("U: {float}")
    233     .Attr("epsilon: float = 0.0001")
    234     .Attr("data_format: string = 'NHWC'")
    235     .Attr("is_training: bool = true")
    236     .SetShapeFn(shape_inference::FusedBatchNormGradShape);
    237 
    238 // --------------------------------------------------------------------------
    239 
    240 REGISTER_OP("BiasAdd")
    241     .Attr("T: numbertype")
    242     .Input("value: T")
    243     .Input("bias: T")
    244     .Attr(GetConvnetDataFormatAttrString())
    245     .Output("output: T")
    246     .SetShapeFn(shape_inference::BiasAddShape);
    247 // --------------------------------------------------------------------------
    248 
    249 REGISTER_OP("BiasAddGrad")
    250     .Attr("T: numbertype")
    251     .Input("out_backprop: T")
    252     .Attr(GetConvnetDataFormatAttrString())
    253     .Output("output: T")
    254     .SetShapeFn(shape_inference::BiasAddGradShape);
    255 // --------------------------------------------------------------------------
    256 
    257 REGISTER_OP("BiasAddV1")
    258     .Attr("T: numbertype")
    259     .Input("value: T")
    260     .Input("bias: T")
    261     .Output("output: T")
    262     .SetShapeFn(shape_inference::BiasAddShape);
    263 // --------------------------------------------------------------------------
    264 
    265 REGISTER_OP("Conv2D")
    266     .Input("input: T")
    267     .Input("filter: T")
    268     .Output("output: T")
    269     .Attr("T: {half, bfloat16, float}")
    270     .Attr("strides: list(int)")
    271     .Attr("use_cudnn_on_gpu: bool = true")
    272     .Attr(GetPaddingAttrString())
    273     .Attr(GetConvnetDataFormatAttrString())
    274     .Attr("dilations: list(int) = [1, 1, 1, 1]")
    275     .SetShapeFn(shape_inference::Conv2DShape);
    276 
    277 REGISTER_OP("Conv2DBackpropInput")
    278     .Input("input_sizes: int32")
    279     .Input("filter: T")
    280     .Input("out_backprop: T")
    281     .Output("output: T")
    282     .Attr("T: {half, bfloat16, float}")
    283     .Attr("strides: list(int)")
    284     .Attr("use_cudnn_on_gpu: bool = true")
    285     .Attr(GetPaddingAttrString())
    286     .Attr(GetConvnetDataFormatAttrString())
    287     .Attr("dilations: list(int) = [1, 1, 1, 1]")
    288     .SetShapeFn([](InferenceContext* c) {
    289       ShapeHandle s;
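               // input_sizes holds the 4-D shape of the original input (e.g.
               // [batch, in_height, in_width, in_channels] for NHWC); the
               // gradient output takes exactly that shape.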
    290       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
    291       TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
    292       c->set_output(0, s);
    293       return Status::OK();
    294     });
    295 
     296 // TODO(jeff): Instead of 'use_cudnn_on_gpu', maybe we should have a
    297 // more general string attribute ('kernel_impl'?) that can be used to
    298 // select among several possible implementations.
    299 REGISTER_OP("Conv2DBackpropFilter")
    300     .Input("input: T")
    301     .Input("filter_sizes: int32")
    302     .Input("out_backprop: T")
    303     .Output("output: T")
    304     .Attr("T: {half, bfloat16, float}")
    305     .Attr("strides: list(int)")
    306     .Attr("use_cudnn_on_gpu: bool = true")
    307     .Attr(GetPaddingAttrString())
    308     .Attr(GetConvnetDataFormatAttrString())
    309     .Attr("dilations: list(int) = [1, 1, 1, 1]")
    310     .SetShapeFn([](InferenceContext* c) {
    311       ShapeHandle s;
    312       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
    313       TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
    314       c->set_output(0, s);
    315       return Status::OK();
    316     });
    317 
    318 namespace {
    319 
    320 Status CommonFusedConvCalculations(InferenceContext* c, bool has_resize) {
    321   ShapeHandle input;
    322   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
    323 
    324   ShapeHandle resized = input;
    325   int paddings_index = 1;
    326   int filter_index = 2;
    327   if (has_resize) {
    328     paddings_index = 2;
    329     filter_index = 3;
    330 
    331     ShapeHandle unused_size;
    332     TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->Vector(2), &unused_size));
    333 
    334     const Tensor* size = c->input_tensor(1);
    335     DimensionHandle new_height = c->UnknownDim();
    336     DimensionHandle new_width = c->UnknownDim();
    337     if (size != nullptr) {
    338       new_height = c->MakeDim(size->flat<int32>()(0));
    339       new_width = c->MakeDim(size->flat<int32>()(1));
    340     }
    341     TF_RETURN_IF_ERROR(c->ReplaceDim(resized, 1, new_height, &resized));
    342     TF_RETURN_IF_ERROR(c->ReplaceDim(resized, 2, new_width, &resized));
    343   }
    344 
    345   ShapeHandle paddings;
    346   TF_RETURN_IF_ERROR(c->WithRank(c->input(paddings_index), 2, &paddings));
    347   TF_RETURN_IF_ERROR(
    348       c->WithRank(resized, c->Value(c->Dim(paddings, 0)), &resized));
    349   TF_RETURN_IF_ERROR(
    350       c->Merge(paddings, c->Matrix(c->Rank(resized), 2), &paddings));
    351 
    352   const Tensor* paddings_t = c->input_tensor(paddings_index);
    353   ShapeHandle padded;
    354   if (paddings_t != nullptr) {
    355     std::vector<DimensionHandle> output_dims;
    356     for (int i = 0; i < 4; ++i) {
    357       DimensionHandle dim = c->Dim(resized, i);
    358       int64 p0 = static_cast<int64>(paddings_t->matrix<int32>()(i, 0));
    359       int64 p1 = static_cast<int64>(paddings_t->matrix<int32>()(i, 1));
    360       if (p0 < 0 || p1 < 0) {
    361         return errors::InvalidArgument("Paddings must be non-negative");
    362       }
    363 
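               // Each padded dimension is dim + p0 + p1; e.g. a height of 28
               // with paddings (2, 2) becomes 32 (illustrative values).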
    364       TF_RETURN_IF_ERROR(c->Add(dim, p0 + p1, &dim));
    365       output_dims.push_back(dim);
    366     }
    367     padded = c->MakeShape(output_dims);
    368   } else {
    369     padded = c->UnknownShapeOfRank(4);
    370   }
    371 
    372   // Work out the convolution's effect with 'padded' as the input.
    373   ShapeHandle filter;
    374   TF_RETURN_IF_ERROR(c->WithRank(c->input(filter_index), 4, &filter));
    375   std::vector<int32> strides;
    376   TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
    377   if (strides.size() != 4) {
    378     return errors::InvalidArgument(
    379         "Operation requires the stride attribute to contain 4 values, but ",
    380         "got: ", strides.size());
    381   }
    382 
    383   int32 stride_rows = strides[1];
    384   int32 stride_cols = strides[2];
    385 
    386   DimensionHandle batch_size_dim = c->Dim(padded, 0);
    387   DimensionHandle in_rows_dim = c->Dim(padded, 1);
    388   DimensionHandle in_cols_dim = c->Dim(padded, 2);
    389   DimensionHandle filter_rows_dim = c->Dim(filter, 0);
    390   DimensionHandle filter_cols_dim = c->Dim(filter, 1);
    391   DimensionHandle output_depth_dim = c->Dim(filter, 3);
    392 
    393   DimensionHandle unused;
    394   TF_RETURN_IF_ERROR(c->Merge(c->Dim(padded, 3), c->Dim(filter, 2), &unused));
    395 
    396   Padding padding;
    397   TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
    398 
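           // For reference (standard TensorFlow window arithmetic, noted here
           // for context): SAME padding yields ceil(in / stride) elements and
           // VALID yields ceil((in - filter + 1) / stride).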
    399   DimensionHandle output_rows, output_cols;
    400   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
    401       c, in_rows_dim, filter_rows_dim, stride_rows, padding, &output_rows));
    402   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
    403       c, in_cols_dim, filter_cols_dim, stride_cols, padding, &output_cols));
    404 
    405   ShapeHandle output_shape = c->MakeShape(
    406       {batch_size_dim, output_rows, output_cols, output_depth_dim});
    407   c->set_output(0, output_shape);
    408   return Status::OK();
    409 }
    410 
    411 }  // namespace
    412 
    413 REGISTER_OP("DataFormatDimMap")
    414     .Input("x: T")
    415     .Output("y: T")
    416     .Attr("T: {int32, int64} = DT_INT32")
    417     .Attr("src_format: string = 'NHWC'")
    418     .Attr("dst_format: string = 'NCHW'")
    419     .SetShapeFn(shape_inference::UnchangedShape);
    420 
    421 REGISTER_OP("DataFormatVecPermute")
    422     .Input("x: T")
    423     .Output("y: T")
    424     .Attr("T: {int32, int64} = DT_INT32")
    425     .Attr("src_format: string = 'NHWC'")
    426     .Attr("dst_format: string = 'NCHW'")
    427     .SetShapeFn(shape_inference::UnchangedShape);
    428 
    429 REGISTER_OP("FusedResizeAndPadConv2D")
    430     .Input("input: T")
    431     .Input("size: int32")
    432     .Input("paddings: int32")
    433     .Input("filter: T")
    434     .Output("output: T")
    435     .Attr("T: {float}")
    436     .Attr("resize_align_corners: bool = false")
    437     .Attr(GetMirrorPadModeAttrString())
    438     .Attr("strides: list(int)")
    439     .Attr(GetPaddingAttrString())
    440     .SetShapeFn([](InferenceContext* c) {
    441       return CommonFusedConvCalculations(c, true /* has_resize */);
    442     });
    443 
    444 REGISTER_OP("FusedPadConv2D")
    445     .Input("input: T")
    446     .Input("paddings: int32")
    447     .Input("filter: T")
    448     .Output("output: T")
    449     .Attr("T: {float}")
    450     .Attr(GetMirrorPadModeAttrString())
    451     .Attr("strides: list(int)")
    452     .Attr(GetPaddingAttrString())
    453     .SetShapeFn([](InferenceContext* c) {
    454       return CommonFusedConvCalculations(c, false /* has_resize */);
    455     });
    456 
    457 // --------------------------------------------------------------------------
    458 
    459 REGISTER_OP("DepthwiseConv2dNative")
    460     .Input("input: T")
    461     .Input("filter: T")
    462     .Output("output: T")
    463     .Attr("T: {half, bfloat16, float, double}")
    464     .Attr("strides: list(int)")
    465     .Attr(GetPaddingAttrString())
    466     .Attr(GetConvnetDataFormatAttrString())
    467     .Attr("dilations: list(int) = [1, 1, 1, 1]")
    468     .SetShapeFn(shape_inference::DepthwiseConv2DNativeShape);
    469 
    470 REGISTER_OP("DepthwiseConv2dNativeBackpropInput")
    471     .Input("input_sizes: int32")
    472     .Input("filter: T")
    473     .Input("out_backprop: T")
    474     .Output("output: T")
    475     .Attr("T: {bfloat16, float, double}")
    476     .Attr("strides: list(int)")
    477     .Attr(GetPaddingAttrString())
    478     .Attr(GetConvnetDataFormatAttrString())
    479     .Attr("dilations: list(int) = [1, 1, 1, 1]")
    480     .SetShapeFn([](InferenceContext* c) {
    481       ShapeHandle s;
    482       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
    483       TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
    484       c->set_output(0, s);
    485       return Status::OK();
    486     });
    487 
    488 REGISTER_OP("DepthwiseConv2dNativeBackpropFilter")
    489     .Input("input: T")
    490     .Input("filter_sizes: int32")
    491     .Input("out_backprop: T")
    492     .Output("output: T")
    493     .Attr("T: {bfloat16, float, double}")
    494     .Attr("strides: list(int)")
    495     .Attr(GetPaddingAttrString())
    496     .Attr(GetConvnetDataFormatAttrString())
    497     .Attr("dilations: list(int) = [1, 1, 1, 1]")
    498     .SetShapeFn([](InferenceContext* c) {
    499       ShapeHandle s;
    500       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
    501       TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
    502       c->set_output(0, s);
    503       return Status::OK();
    504     });
    505 
    506 // --------------------------------------------------------------------------
    507 REGISTER_OP("Conv3D")
    508     .Input("input: T")
    509     .Input("filter: T")
    510     .Output("output: T")
    511     .Attr("T: {half, bfloat16, float, double}")
    512     .Attr("strides: list(int) >= 5")
    513     .Attr(GetPaddingAttrString())
    514     .Attr(GetConvnet3dDataFormatAttrString())
    515     .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
    516     .SetShapeFn(shape_inference::Conv3DShape);
    517 
    518 REGISTER_OP("Conv3DBackpropInput")
    519     .Input("input: T")
    520     .Input("filter: T")
    521     .Input("out_backprop: T")
    522     .Output("output: T")
    523     .Attr("T: {half, float, double}")
    524     .Attr("strides: list(int) >= 5")
    525     .Attr(GetPaddingAttrString())
    526     .Deprecated(10, "Use Conv3DBackpropInputV2")
    527     .SetShapeFn([](InferenceContext* c) {
    528       return UnchangedShapeWithRank(c, 5);
    529     });
    530 
    531 REGISTER_OP("Conv3DBackpropFilter")
    532     .Input("input: T")
    533     .Input("filter: T")
    534     .Input("out_backprop: T")
    535     .Output("output: T")
    536     .Attr("T: {half, float, double}")
    537     .Attr("strides: list(int) >= 5")
    538     .Attr(GetPaddingAttrString())
    539     .Deprecated(10, "Use Conv3DBackpropFilterV2")
    540     .SetShapeFn([](InferenceContext* c) {
    541       ShapeHandle out;
    542       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 5, &out));
    543       c->set_output(0, out);
    544       return Status::OK();
    545     });
    546 
    547 REGISTER_OP("Conv3DBackpropInputV2")
    548     .Input("input_sizes: int32")
    549     .Input("filter: T")
    550     .Input("out_backprop: T")
    551     .Output("output: T")
    552     .Attr("T: {half, bfloat16, float, double}")
    553     .Attr("strides: list(int) >= 5")
    554     .Attr(GetPaddingAttrString())
    555     .Attr(GetConvnet3dDataFormatAttrString())
    556     .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
    557     .SetShapeFn([](InferenceContext* c) {
    558       ShapeHandle s;
    559       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
    560       TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
    561       c->set_output(0, s);
    562       return Status::OK();
    563     });
    564 
    565 REGISTER_OP("Conv3DBackpropFilterV2")
    566     .Input("input: T")
    567     .Input("filter_sizes: int32")
    568     .Input("out_backprop: T")
    569     .Output("output: T")
    570     .Attr("T: {half, bfloat16, float, double}")
    571     .Attr("strides: list(int) >= 5")
    572     .Attr(GetPaddingAttrString())
    573     .Attr(GetConvnet3dDataFormatAttrString())
    574     .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
    575     .SetShapeFn([](InferenceContext* c) {
    576       ShapeHandle s;
    577       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
    578       TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
    579       c->set_output(0, s);
    580       return Status::OK();
    581     });
    582 
    583 // --------------------------------------------------------------------------
    584 
    585 REGISTER_OP("AvgPool3D")
    586     .Input("input: T")
    587     .Output("output: T")
    588     .Attr("ksize: list(int) >= 5")
    589     .Attr("strides: list(int) >= 5")
    590     .Attr(GetPaddingAttrString())
    591     .Attr(GetConvnet3dDataFormatAttrString())
    592     .Attr("T: {bfloat16, float, double}")
    593     .SetShapeFn(shape_inference::Pool3DShape);
    594 
    595 REGISTER_OP("AvgPool3DGrad")
    596     .Input("orig_input_shape: int32")
    597     .Input("grad: T")
    598     .Output("output: T")
    599     .Attr("ksize: list(int) >= 5")
    600     .Attr("strides: list(int) >= 5")
    601     .Attr(GetPaddingAttrString())
    602     .Attr(GetConvnet3dDataFormatAttrString())
    603     .Attr("T: {bfloat16, float, double}")
    604     .SetShapeFn([](InferenceContext* c) {
    605       ShapeHandle s;
    606       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
    607       TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
    608       c->set_output(0, s);
    609       return Status::OK();
    610     });
    611 
    612 // --------------------------------------------------------------------------
    613 
    614 REGISTER_OP("MaxPool3D")
    615     .Input("input: T")
    616     .Output("output: T")
    617     .Attr("ksize: list(int) >= 5")
    618     .Attr("strides: list(int) >= 5")
    619     .Attr(GetPaddingAttrString())
    620     .Attr(GetConvnet3dDataFormatAttrString())
    621     .Attr("T: {bfloat16, float}")
    622     .SetShapeFn(shape_inference::Pool3DShape);
    623 
    624 REGISTER_OP("MaxPool3DGrad")
    625     .Input("orig_input: TInput")
    626     .Input("orig_output: TInput")
    627     .Input("grad: T")
    628     .Output("output: T")
    629     .Attr("ksize: list(int) >= 5")
    630     .Attr("strides: list(int) >= 5")
    631     .Attr(GetPaddingAttrString())
    632     .Attr(GetConvnet3dDataFormatAttrString())
    633     .Attr("T: {bfloat16, float} = DT_FLOAT")
    634     .Attr("TInput: {bfloat16, float} = DT_FLOAT")
    635     .SetShapeFn([](InferenceContext* c) {
    636       return UnchangedShapeWithRank(c, 5);
    637     });
    638 
    639 REGISTER_OP("MaxPool3DGradGrad")
    640     .Input("orig_input: T")
    641     .Input("orig_output: T")
    642     .Input("grad: T")
    643     .Output("output: T")
     644     .Attr("ksize: list(int) >= 5")
    645     .Attr("strides: list(int) >= 5")
    646     .Attr(GetPaddingAttrString())
    647     .Attr(GetConvnet3dDataFormatAttrString())
    648     .Attr("T: {float}")
    649     .SetShapeFn([](InferenceContext* c) {
    650       TF_RETURN_IF_ERROR(shape_inference::Pool3DShape(c));
    651       ShapeHandle unused;
    652       // Validate 'orig_input' is the same shape as 'grad'
    653       TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(2), &unused));
     654       // Validate 'orig_output' is the same shape as 'output'
    655       TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->output(0), &unused));
    656       return Status::OK();
    657     });
    658 
    659 // --------------------------------------------------------------------------
    660 
    661 REGISTER_OP("L2Loss")
    662     .Input("t: T")
    663     .Output("output: T")
    664     .Attr("T: {half, bfloat16, float, double}")
    665     .SetShapeFn(shape_inference::ScalarShape);
    666 
    667 // --------------------------------------------------------------------------
    668 
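         // For context (standard definition, stated here as an editorial note):
         // LRN computes output = input / (bias + alpha * sqr_sum) ** beta,
         // where sqr_sum is the sum of squared inputs over a window of
         // depth_radius on each side along the last dimension.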
    669 REGISTER_OP("LRN")
    670     .Input("input: T")
    671     .Output("output: T")
    672     .Attr("depth_radius: int = 5")
    673     .Attr("bias: float = 1.0")
    674     .Attr("alpha: float = 1.0")
    675     .Attr("beta: float = 0.5")
    676     .Attr("T: {half, bfloat16, float} = DT_FLOAT")
    677     .SetShapeFn([](InferenceContext* c) {
    678       return UnchangedShapeWithRank(c, 4);
    679     });
    680 
    681 REGISTER_OP("LRNGrad")
    682     .Input("input_grads: T")
    683     .Input("input_image: T")
    684     .Input("output_image: T")
    685     .Output("output: T")
    686     .Attr("depth_radius: int = 5")
    687     .Attr("bias: float = 1.0")
    688     .Attr("alpha: float = 1.0")
    689     .Attr("beta: float = 0.5")
    690     .Attr("T: {half, bfloat16, float} = DT_FLOAT")
    691     .SetShapeFn([](InferenceContext* c) {
    692       ShapeHandle s;
    693       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &s));  // input_grads
    694       TF_RETURN_IF_ERROR(c->Merge(s, c->input(1), &s));     // input_image
    695       TF_RETURN_IF_ERROR(c->Merge(s, c->input(2), &s));     // output_image
    696       c->set_output(0, s);
    697       return Status::OK();
    698     });
    699 
    700 // --------------------------------------------------------------------------
    701 
    702 REGISTER_OP("MaxPool")
    703     .Attr(
    704         "T: {half, bfloat16, float, double, int32, int64, uint8, int16, int8, "
    705         "uint16, qint8} = DT_FLOAT")
    706     .Attr("ksize: list(int) >= 4")
    707     .Attr("strides: list(int) >= 4")
    708     .Attr(GetPaddingAttrString())
    709     .Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'")
    710     .Input("input: T")
    711     .Output("output: T")
    712     .SetShapeFn(shape_inference::MaxPoolShape);
    713 
    714 REGISTER_OP("MaxPoolV2")
    715     .Attr(
    716         "T: {half, bfloat16, float, double, int32, int64, uint8, int16, int8, "
    717         "uint16, qint8} = DT_FLOAT")
    718     .Attr(GetPaddingAttrString())
    719     .Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'")
    720     .Input("input: T")
    721     .Input("ksize: int32")
    722     .Input("strides: int32")
    723     .Output("output: T")
    724     .SetShapeFn([](InferenceContext* c) {
    725       TF_RETURN_IF_ERROR(shape_inference::MaxPoolV2Shape(c, 3));
    726       return Status::OK();
    727     });
    728 
    729 REGISTER_OP("MaxPoolGrad")
    730     .Attr("ksize: list(int) >= 4")
    731     .Attr("strides: list(int) >= 4")
    732     .Attr(GetPaddingAttrString())
    733     .Attr(GetConvnetDataFormatAttrString())
    734     .Input("orig_input: T")
    735     .Input("orig_output: T")
    736     .Input("grad: T")
    737     .Output("output: T")
    738     .Attr("T: realnumbertype = DT_FLOAT")
    739     .SetShapeFn([](InferenceContext* c) {
    740       return UnchangedShapeWithRank(c, 4);
    741     });
    742 
    743 REGISTER_OP("MaxPoolGradV2")
    744     .Attr(GetPaddingAttrString())
    745     .Attr(GetConvnetDataFormatAttrString())
    746     .Input("orig_input: T")
    747     .Input("orig_output: T")
    748     .Input("grad: T")
    749     .Input("ksize: int32")
    750     .Input("strides: int32")
    751     .Output("output: T")
    752     .Attr("T: realnumbertype = DT_FLOAT")
    753     .SetShapeFn([](InferenceContext* c) {
    754       return UnchangedShapeWithRank(c, 4);
    755     });
    756 
    757 REGISTER_OP("MaxPoolGradGrad")
    758     .Attr("ksize: list(int) >= 4")
    759     .Attr("strides: list(int) >= 4")
    760     .Attr(GetPaddingAttrString())
    761     .Attr(GetConvnetDataFormatAttrString())
    762     .Input("orig_input: T")
    763     .Input("orig_output: T")
    764     .Input("grad: T")
    765     .Output("output: T")
    766     .Attr("T: realnumbertype")
    767     .SetShapeFn([](InferenceContext* c) {
    768       TF_RETURN_IF_ERROR(shape_inference::MaxPoolShape(c));
    769       ShapeHandle unused;
    770       // Validate 'orig_input' is the same shape as 'grad'
    771       TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(2), &unused));
     772       // Validate 'orig_output' is the same shape as 'output'
    773       TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->output(0), &unused));
    774       return Status::OK();
    775     });
    776 
    777 REGISTER_OP("MaxPoolGradGradV2")
    778     .Attr(GetPaddingAttrString())
    779     .Attr(GetConvnetDataFormatAttrString())
    780     .Input("orig_input: T")
    781     .Input("orig_output: T")
    782     .Input("grad: T")
    783     .Input("ksize: int32")
    784     .Input("strides: int32")
    785     .Output("output: T")
    786     .Attr("T: realnumbertype")
    787     .SetShapeFn([](InferenceContext* c) {
    788       TF_RETURN_IF_ERROR(shape_inference::MaxPoolV2Shape(c, 5));
    789       ShapeHandle unused;
    790       // Validate 'orig_input' is the same shape as 'grad'
    791       TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(2), &unused));
     792       // Validate 'orig_output' is the same shape as 'output'
    793       TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->output(0), &unused));
    794       return Status::OK();
    795     });
    796 
    797 REGISTER_OP("MaxPoolWithArgmax")
    798     .Attr("ksize: list(int) >= 4")
    799     .Attr("strides: list(int) >= 4")
    800     .Attr("Targmax: {int32, int64} = DT_INT64")
    801     .Attr(GetPaddingAttrString())
    802     .Input("input: T")
    803     .Output("output: T")
    804     .Output("argmax: Targmax")
    805     .Attr("T: realnumbertype")
    806     .SetShapeFn([](InferenceContext* c) {
    807       TF_RETURN_IF_ERROR(shape_inference::MaxPoolShape(c));
    808       c->set_output(1, c->output(0));
    809       return Status::OK();
    810     });
    811 
    812 REGISTER_OP("MaxPoolGradWithArgmax")
    813     .Attr("ksize: list(int) >= 4")
    814     .Attr("strides: list(int) >= 4")
    815     .Attr(GetPaddingAttrString())
    816     .Attr("Targmax: {int32, int64}")
    817     .Input("input: T")
    818     .Input("grad: T")
    819     .Input("argmax: Targmax")
    820     .Output("output: T")
    821     .Attr("T: realnumbertype")
    822     .SetShapeFn([](InferenceContext* c) {
    823       return UnchangedShapeWithRank(c, 4);
    824     });
    825 
    826 REGISTER_OP("MaxPoolGradGradWithArgmax")
    827     .Attr("ksize: list(int) >= 4")
    828     .Attr("strides: list(int) >= 4")
    829     .Attr(GetPaddingAttrString())
    830     .Attr("Targmax: {int32, int64}")
    831     .Input("input: T")
    832     .Input("grad: T")
    833     .Input("argmax: Targmax")
    834     .Output("output: T")
    835     .Attr("T: realnumbertype")
    836     .SetShapeFn([](InferenceContext* c) {
    837       TF_RETURN_IF_ERROR(shape_inference::MaxPoolShape(c));
    838       ShapeHandle unused;
    839       // Validate 'orig_input' is the same shape as 'grad'
    840       TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(1), &unused));
     841       // Validate 'argmax' is the same shape as 'output'
    842       TF_RETURN_IF_ERROR(c->Merge(c->input(2), c->output(0), &unused));
    843       return Status::OK();
    844     });
    845 
    846 // --------------------------------------------------------------------------
    847 
    848 REGISTER_OP("Dilation2D")
    849     .Input("input: T")
    850     .Input("filter: T")
    851     .Output("output: T")
    852     .Attr("T: realnumbertype")
    853     .Attr("strides: list(int) >= 4")
    854     .Attr("rates: list(int) >= 4")
    855     .Attr(GetPaddingAttrString())
    856     .SetShapeFn([](InferenceContext* c) {
    857       ShapeHandle input_shape;
    858       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
    859       ShapeHandle filter_shape;
    860       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &filter_shape));
    861 
    862       std::vector<int32> strides;
    863       TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
    864       if (strides.size() != 4) {
    865         return errors::InvalidArgument(
    866             "Dilation2D requires the stride attribute to contain 4 values, but "
    867             "got: ",
    868             strides.size());
    869       }
    870 
    871       std::vector<int32> rates;
    872       TF_RETURN_IF_ERROR(c->GetAttr("rates", &rates));
    873       if (rates.size() != 4) {
    874         return errors::InvalidArgument(
    875             "Dilation2D requires the rates attribute to contain 4 values, but "
    876             "got: ",
    877             rates.size());
    878       }
    879 
    880       int32 stride_rows = strides[1];
    881       int32 stride_cols = strides[2];
    882 
    883       int32 rate_rows = rates[1];
    884       int32 rate_cols = rates[2];
    885 
    886       DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
    887       DimensionHandle in_rows_dim = c->Dim(input_shape, 1);
    888       DimensionHandle in_cols_dim = c->Dim(input_shape, 2);
    889       DimensionHandle filter_rows_dim = c->Dim(filter_shape, 0);
    890       DimensionHandle filter_cols_dim = c->Dim(filter_shape, 1);
    891       DimensionHandle output_depth_dim = c->Dim(filter_shape, 2);
    892 
    893       if (!c->ValueKnown(in_rows_dim) || !c->ValueKnown(in_cols_dim) ||
    894           !c->ValueKnown(filter_rows_dim) || !c->ValueKnown(filter_cols_dim)) {
    895         ShapeHandle output_shape =
    896             c->MakeShape({batch_size_dim, InferenceContext::kUnknownDim,
    897                           InferenceContext::kUnknownDim, output_depth_dim});
    898         c->set_output(0, output_shape);
    899         return Status::OK();
    900       }
    901       DimensionHandle unused;
    902       TF_RETURN_IF_ERROR(
    903           c->Merge(c->Dim(input_shape, 3), output_depth_dim, &unused));
    904 
    905       auto in_rows = c->Value(in_rows_dim);
    906       auto in_cols = c->Value(in_cols_dim);
    907       auto filter_rows = c->Value(filter_rows_dim);
    908       auto filter_cols = c->Value(filter_cols_dim);
    909       auto filter_rows_eff = filter_rows + (filter_rows - 1) * (rate_rows - 1);
    910       auto filter_cols_eff = filter_cols + (filter_cols - 1) * (rate_cols - 1);
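               // Illustrative example: filter_rows = 3 with rate_rows = 2 gives
               // an effective height of 3 + (3 - 1) * (2 - 1) = 5.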
    911 
    912       Padding padding;
    913       TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
    914 
    915       int64 output_rows, output_cols;
    916       int64 padding_before, padding_after;
    917       TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
    918           in_rows, filter_rows_eff, stride_rows, padding, &output_rows,
    919           &padding_before, &padding_after));
    920       TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
    921           in_cols, filter_cols_eff, stride_cols, padding, &output_cols,
    922           &padding_before, &padding_after));
    923 
    924       ShapeHandle output_shape = c->MakeShape(
    925           {batch_size_dim, output_rows, output_cols, output_depth_dim});
    926       c->set_output(0, output_shape);
    927       return Status::OK();
    928     });
    929 
    930 REGISTER_OP("Dilation2DBackpropInput")
    931     .Input("input: T")
    932     .Input("filter: T")
    933     .Input("out_backprop: T")
    934     .Output("in_backprop: T")
    935     .Attr("T: realnumbertype")
    936     .Attr("strides: list(int) >= 4")
    937     .Attr("rates: list(int) >= 4")
    938     .Attr(GetPaddingAttrString())
    939     .SetShapeFn(shape_inference::UnchangedShape);
    940 
    941 REGISTER_OP("Dilation2DBackpropFilter")
    942     .Input("input: T")
    943     .Input("filter: T")
    944     .Input("out_backprop: T")
    945     .Output("filter_backprop: T")
    946     .Attr("T: realnumbertype")
    947     .Attr("strides: list(int) >= 4")
    948     .Attr("rates: list(int) >= 4")
    949     .Attr(GetPaddingAttrString())
    950     .SetShapeFn([](InferenceContext* c) {
    951       c->set_output(0, c->input(1));
    952       return Status::OK();
    953     });
    954 
    955 // --------------------------------------------------------------------------
    956 
    957 REGISTER_OP("Relu")
    958     .Input("features: T")
    959     .Output("activations: T")
    960     .Attr("T: realnumbertype")
    961     .SetShapeFn(shape_inference::UnchangedShape);
    962 
    963 REGISTER_OP("ReluGrad")
    964     .Input("gradients: T")
    965     .Input("features: T")
    966     .Output("backprops: T")
    967     .Attr("T: realnumbertype")
    968     .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
    969 
    970 REGISTER_OP("Relu6")
    971     .Input("features: T")
    972     .Output("activations: T")
    973     .Attr("T: realnumbertype")
    974     .SetShapeFn(shape_inference::UnchangedShape);
    975 
    976 REGISTER_OP("Relu6Grad")
    977     .Input("gradients: T")
    978     .Input("features: T")
    979     .Output("backprops: T")
    980     .Attr("T: realnumbertype")
    981     .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
    982 
    983 REGISTER_OP("Elu")
    984     .Input("features: T")
    985     .Output("activations: T")
    986     .Attr("T: {half, bfloat16, float, double}")
    987     .SetShapeFn(shape_inference::UnchangedShape);
    988 
    989 REGISTER_OP("EluGrad")
    990     .Input("gradients: T")
    991     .Input("outputs: T")
    992     .Output("backprops: T")
    993     .Attr("T: {half, bfloat16, float, double}")
    994     .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
    995 
    996 REGISTER_OP("Selu")
    997     .Input("features: T")
    998     .Output("activations: T")
    999     .Attr("T: {half, bfloat16, float, double}")
   1000     .SetShapeFn(shape_inference::UnchangedShape);
   1001 
   1002 REGISTER_OP("SeluGrad")
   1003     .Input("gradients: T")
   1004     .Input("outputs: T")
   1005     .Output("backprops: T")
   1006     .Attr("T: {half, bfloat16, float, double}")
   1007     .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
   1008 
   1009 REGISTER_OP("Softplus")
   1010     .Input("features: T")
   1011     .Output("activations: T")
   1012     .Attr("T: realnumbertype")
   1013     .SetShapeFn(shape_inference::UnchangedShape);
   1014 
   1015 REGISTER_OP("SoftplusGrad")
   1016     .Input("gradients: T")
   1017     .Input("features: T")
   1018     .Output("backprops: T")
   1019     .Attr("T: realnumbertype")
   1020     .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
   1021 
   1022 REGISTER_OP("Softsign")
   1023     .Input("features: T")
   1024     .Output("activations: T")
   1025     .Attr("T: realnumbertype")
   1026     .SetShapeFn(shape_inference::UnchangedShape);
   1027 
   1028 REGISTER_OP("SoftsignGrad")
   1029     .Input("gradients: T")
   1030     .Input("features: T")
   1031     .Output("backprops: T")
   1032     .Attr("T: realnumbertype")
   1033     .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
   1034 
   1035 // --------------------------------------------------------------------------
   1036 
   1037 REGISTER_OP("Softmax")
   1038     .Input("logits: T")
   1039     .Output("softmax: T")
   1040     .Attr("T: {half, bfloat16, float, double}")
   1041     .SetShapeFn([](InferenceContext* c) {
   1042       return shape_inference::UnchangedShapeWithRankAtLeast(c, 1);
   1043     });
   1044 
   1045 // --------------------------------------------------------------------------
   1046 
   1047 REGISTER_OP("LogSoftmax")
   1048     .Input("logits: T")
   1049     .Output("logsoftmax: T")
   1050     .Attr("T: {half, bfloat16, float, double}")
   1051     .SetShapeFn([](InferenceContext* c) {
   1052       return shape_inference::UnchangedShapeWithRankAtLeast(c, 1);
   1053     });
   1054 
   1055 // --------------------------------------------------------------------------
   1056 
   1057 REGISTER_OP("SoftmaxCrossEntropyWithLogits")
   1058     .Input("features: T")
   1059     .Input("labels: T")
   1060     .Output("loss: T")
   1061     .Output("backprop: T")
   1062     .Attr("T: {half, bfloat16, float, double}")
   1063     .SetShapeFn([](InferenceContext* c) {
   1064       ShapeHandle input;
   1065       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &input));
   1066       TF_RETURN_IF_ERROR(c->Merge(input, c->input(1), &input));
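               // features and labels must both be [batch_size, num_classes];
               // loss is emitted as [batch_size] and backprop keeps the merged
               // 2-D shape.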
   1067 
   1068       DimensionHandle batch_size = c->Dim(input, 0);
   1069       c->set_output(0, c->Vector(batch_size));
   1070       c->set_output(1, input);
   1071       return Status::OK();
   1072     });
   1073 
   1074 REGISTER_OP("SparseSoftmaxCrossEntropyWithLogits")
   1075     .Input("features: T")
   1076     .Input("labels: Tlabels")
   1077     .Output("loss: T")
   1078     .Output("backprop: T")
   1079     .Attr("T: {half, bfloat16, float, double}")
   1080     .Attr("Tlabels: {int32, int64} = DT_INT64")
   1081     .SetShapeFn([](InferenceContext* c) {
   1082       ShapeHandle features;
   1083       ShapeHandle labels;
   1084       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &features));
   1085       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &labels));
   1086 
   1087       DimensionHandle batch_size;
   1088       TF_RETURN_IF_ERROR(
   1089           c->Merge(c->Dim(features, 0), c->Dim(labels, 0), &batch_size));
   1090       TF_RETURN_IF_ERROR(c->ReplaceDim(features, 0, batch_size, &features));
   1091 
   1092       c->set_output(0, c->Vector(batch_size));
   1093       c->set_output(1, features);
   1094       return Status::OK();
   1095     });
   1096 
   1097 // --------------------------------------------------------------------------
   1098 
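         // Note (editorial): the shape function below requires predictions to
         // be [batch_size, num_classes] and targets to be [batch_size]; the
         // precision output is a [batch_size] vector of booleans.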
   1099 REGISTER_OP("InTopK")
   1100     .Input("predictions: float")
   1101     .Input("targets: T")
   1102     .Output("precision: bool")
   1103     .Attr("k: int")
   1104     .Attr("T: {int32, int64} = DT_INT32")
   1105     .SetShapeFn([](InferenceContext* c) {
   1106       ShapeHandle predictions;
   1107       ShapeHandle targets;
   1108       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &predictions));
   1109       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &targets));
   1110       DimensionHandle batch_size;
   1111       TF_RETURN_IF_ERROR(
   1112           c->Merge(c->Dim(predictions, 0), c->Dim(targets, 0), &batch_size));
   1113       c->set_output(0, c->Vector(batch_size));
   1114       return Status::OK();
   1115     });
   1116 
    1117 // This is the same as `InTopK`, but takes `k` as an input rather than an attr.
   1118 REGISTER_OP("InTopKV2")
   1119     .Input("predictions: float")
   1120     .Input("targets: T")
   1121     .Input("k: T")
   1122     .Output("precision: bool")
   1123     .Attr("T: {int32, int64} = DT_INT32")
   1124     .SetShapeFn([](InferenceContext* c) {
   1125       ShapeHandle predictions;
   1126       ShapeHandle targets;
   1127       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &predictions));
   1128       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &targets));
   1129       DimensionHandle batch_size;
   1130       TF_RETURN_IF_ERROR(
   1131           c->Merge(c->Dim(predictions, 0), c->Dim(targets, 0), &batch_size));
   1132       c->set_output(0, c->Vector(batch_size));
   1133       return Status::OK();
   1134     });
   1135 
   1136 namespace {
   1137 
   1138 Status TopKShapeFn(InferenceContext* c) {
   1139   ShapeHandle input;
   1140   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input));
   1141 
    1142   // Get the k value, either from the input tensor or the attribute.
   1143   DimensionHandle k_dim;
   1144   if (c->num_inputs() >= 2) {
   1145     TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(1, &k_dim));
   1146   } else {
   1147     int32 k;
   1148     TF_RETURN_IF_ERROR(c->GetAttr("k", &k));
   1149     if (k < 0) {
   1150       return errors::InvalidArgument("Need k >= 0, got ", k);
   1151     }
   1152     k_dim = c->MakeDim(k);
   1153   }
   1154 
   1155   DimensionHandle last_dim = c->Dim(input, -1);
   1156   if (c->ValueKnown(last_dim) && c->ValueKnown(k_dim) &&
   1157       c->Value(last_dim) < c->Value(k_dim)) {
   1158     return errors::InvalidArgument(
   1159         "input must have last dimension >= k = ", c->Value(k_dim), " but is ",
   1160         c->Value(last_dim));
   1161   }
   1162 
   1163   // Replace last_dim with k_dim.
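           // Illustrative example: an input of shape [batch, 50] with k = 5
           // yields values and indices of shape [batch, 5].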
   1164   ShapeHandle s;
   1165   TF_RETURN_IF_ERROR(c->Subshape(input, 0, -1, &s));
   1166   TF_RETURN_IF_ERROR(c->Concatenate(s, c->Vector(k_dim), &s));
   1167   c->set_output(0, s);
   1168   c->set_output(1, s);
   1169   return Status::OK();
   1170 }
   1171 
   1172 }  // namespace
   1173 
   1174 REGISTER_OP("TopK")
   1175     .Input("input: T")
   1176     .Output("values: T")
   1177     .Output("indices: int32")
   1178     .Attr("k: int >= 0")
   1179     .Attr("sorted: bool = true")
   1180     .Attr("T: realnumbertype")
   1181     .Deprecated(7, "Use TopKV2 instead")
   1182     .SetShapeFn(TopKShapeFn);
   1183 
    1184 // This is the same as `TopK`, but takes `k` as an input rather than an attr.
   1185 REGISTER_OP("TopKV2")
   1186     .Input("input: T")
   1187     .Input("k: int32")
   1188     .Output("values: T")
   1189     .Output("indices: int32")
   1190     .Attr("sorted: bool = true")
   1191     .Attr("T: realnumbertype")
   1192     .SetShapeFn(TopKShapeFn);
   1193 
   1194 // --------------------------------------------------------------------------
   1195 
   1196 REGISTER_OP("NthElement")
   1197     .Input("input: T")
   1198     .Input("n: int32")
   1199     .Output("values: T")
   1200     .Attr("reverse: bool = false")
   1201     .Attr("T: realnumbertype")
   1202     .SetShapeFn([](InferenceContext* c) {
   1203       ShapeHandle input;
   1204       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input));
   1205 
    1206       // Get the n value from the input tensor and make sure it is a scalar.
   1207       DimensionHandle n_dim;
   1208       TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(1, &n_dim));
   1209 
    1210       // The last dimension of the input tensor must be greater than n.
   1211       DimensionHandle last_dim = c->Dim(input, -1);
   1212       if (c->ValueKnown(last_dim) && c->ValueKnown(n_dim) &&
   1213           c->Value(last_dim) <= c->Value(n_dim)) {
   1214         return errors::InvalidArgument(
   1215             "Input must have last dimension > n = ", c->Value(n_dim),
   1216             " but is ", c->Value(last_dim));
   1217       }
   1218 
   1219       // Reduce last_dim for output tensor
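               // Illustrative example: an input of shape [batch, 100] produces
               // values of shape [batch].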
   1220       ShapeHandle s;
   1221       TF_RETURN_IF_ERROR(c->Subshape(input, 0, -1, &s));
   1222       c->set_output(0, s);
   1223       return Status::OK();
   1224     });
   1225 
   1226 // --------------------------------------------------------------------------
   1227 
   1228 REGISTER_OP("FractionalMaxPool")
   1229     .Input("value: T")
   1230     .Output("output: T")
   1231     .Output("row_pooling_sequence: int64")
   1232     .Output("col_pooling_sequence: int64")
    1233     .Attr("pooling_ratio: list(float) >= 4")
   1234     .Attr("pseudo_random: bool = false")
   1235     .Attr("overlapping: bool = false")
   1236     .Attr("deterministic: bool = false")
   1237     .Attr("seed: int = 0")
   1238     .Attr("seed2: int = 0")
   1239     .Attr("T: {float, double, int32, int64}")
   1240     .SetShapeFn(FractionalPoolShapeFn);
   1241 
   1242 REGISTER_OP("FractionalMaxPoolGrad")
   1243     .Input("orig_input: T")
   1244     .Input("orig_output: T")
   1245     .Input("out_backprop: T")
   1246     .Input("row_pooling_sequence: int64")
   1247     .Input("col_pooling_sequence: int64")
   1248     .Output("output: T")
   1249     .Attr("overlapping: bool = false")
   1250     .Attr("T: {float, double, int32, int64}")
   1251     .SetShapeFn([](InferenceContext* c) {
   1252       return shape_inference::UnchangedShapeWithRank(c, 4);
   1253     });
   1254 
   1255 // --------------------------------------------------------------------------
   1256 
   1257 REGISTER_OP("FractionalAvgPool")
   1258     .Input("value: T")
   1259     .Output("output: T")
   1260     .Output("row_pooling_sequence: int64")
   1261     .Output("col_pooling_sequence: int64")
    1262     .Attr("pooling_ratio: list(float) >= 4")
   1263     .Attr("pseudo_random: bool = false")
   1264     .Attr("overlapping: bool = false")
   1265     .Attr("deterministic: bool = false")
   1266     .Attr("seed: int = 0")
   1267     .Attr("seed2: int = 0")
   1268     .Attr("T: {float, double, int32, int64}")
   1269     .SetShapeFn(FractionalPoolShapeFn);
   1270 
   1271 REGISTER_OP("FractionalAvgPoolGrad")
   1272     .Input("orig_input_tensor_shape: int64")
   1273     .Input("out_backprop: T")
   1274     .Input("row_pooling_sequence: int64")
   1275     .Input("col_pooling_sequence: int64")
   1276     .Output("output: T")
   1277     .Attr("overlapping: bool = false")
   1278     .Attr("T: {float, double, int32, int64}")
   1279     .SetShapeFn([](InferenceContext* c) {
   1280       if (c->input_tensor(0) != nullptr) {
   1281         ShapeHandle out;
   1282         TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
   1283         c->set_output(0, out);
   1284       } else {
   1285         c->set_output(0, c->UnknownShapeOfRank(4));
   1286       }
   1287       return Status::OK();
   1288     });
   1289 
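         // Editorial note on the quantized ops that follow: each reuses the
         // shape function of its float counterpart for the main output and
         // adds scalar min/max range outputs.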
   1290 REGISTER_OP("QuantizedAvgPool")
   1291     .Input("input: T")
   1292     .Input("min_input: float")
   1293     .Input("max_input: float")
   1294     .Output("output: T")
   1295     .Output("min_output: float")
   1296     .Output("max_output: float")
   1297     .Attr("T: quantizedtype")
   1298     .Attr("ksize: list(int)")
   1299     .Attr("strides: list(int)")
   1300     .Attr(GetPaddingAttrString())
   1301     .SetShapeFn([](InferenceContext* c) {
   1302       TF_RETURN_IF_ERROR(shape_inference::AvgPoolShape(c));
   1303       ShapeHandle unused;
   1304       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
   1305       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
   1306       c->set_output(1, c->Scalar());
   1307       c->set_output(2, c->Scalar());
   1308       return Status::OK();
   1309     });
   1310 
   1311 REGISTER_OP("QuantizedBiasAdd")
   1312     .Input("input: T1")
   1313     .Input("bias: T2")
   1314     .Input("min_input: float")
   1315     .Input("max_input: float")
   1316     .Input("min_bias: float")
   1317     .Input("max_bias: float")
   1318     .Output("output: out_type")
   1319     .Output("min_out: float")
   1320     .Output("max_out: float")
   1321     .Attr("T1: quantizedtype")
   1322     .Attr("T2: quantizedtype")
   1323     .Attr("out_type: quantizedtype")
   1324     .SetShapeFn([](InferenceContext* c) {
   1325       TF_RETURN_IF_ERROR(shape_inference::BiasAddShape(c));
   1326       ShapeHandle unused;
   1327       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
   1328       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
   1329       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
   1330       TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
   1331       c->set_output(1, c->Scalar());
   1332       c->set_output(2, c->Scalar());
   1333       return Status::OK();
   1334     });
   1335 
   1336 REGISTER_OP("QuantizedConv2D")
   1337     .Input("input: Tinput")
   1338     .Input("filter: Tfilter")
   1339     .Input("min_input: float")
   1340     .Input("max_input: float")
   1341     .Input("min_filter: float")
   1342     .Input("max_filter: float")
   1343     .Output("output: out_type")
   1344     .Output("min_output: float")
   1345     .Output("max_output: float")
   1346     .Attr("Tinput: quantizedtype")
   1347     .Attr("Tfilter: quantizedtype")
   1348     .Attr("out_type: quantizedtype = DT_QINT32")
   1349     .Attr("strides: list(int)")
   1350     .Attr(GetPaddingAttrString())
   1351     .Attr("dilations: list(int) = [1, 1, 1, 1]")
   1352     .SetShapeFn([](InferenceContext* c) {
   1353       TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
   1354       ShapeHandle unused;
   1355       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
   1356       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
   1357       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
   1358       TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
   1359       c->set_output(1, c->Scalar());
   1360       c->set_output(2, c->Scalar());
   1361       return Status::OK();
   1362     });
   1363 
   1364 REGISTER_OP("QuantizedMaxPool")
   1365     .Input("input: T")
   1366     .Input("min_input: float")
   1367     .Input("max_input: float")
   1368     .Output("output: T")
   1369     .Output("min_output: float")
   1370     .Output("max_output: float")
   1371     .Attr("T: quantizedtype")
   1372     .Attr("ksize: list(int)")
   1373     .Attr("strides: list(int)")
   1374     .Attr(GetPaddingAttrString())
   1375     .SetShapeFn([](InferenceContext* c) {
   1376       TF_RETURN_IF_ERROR(shape_inference::MaxPoolShape(c));
   1377       ShapeHandle unused;
   1378       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
   1379       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
   1380       c->set_output(1, c->Scalar());
   1381       c->set_output(2, c->Scalar());
   1382       return Status::OK();
   1383     });
   1384 
   1385 REGISTER_OP("QuantizedRelu")
   1386     .Input("features: Tinput")
   1387     .Input("min_features: float")
   1388     .Input("max_features: float")
   1389     .Output("activations: out_type")
   1390     .Output("min_activations: float")
   1391     .Output("max_activations: float")
   1392     .Attr("Tinput: quantizedtype")
   1393     .Attr("out_type: quantizedtype = DT_QUINT8")
   1394     .SetShapeFn([](InferenceContext* c) {
   1395       TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
   1396       ShapeHandle unused;
   1397       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
   1398       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
   1399       c->set_output(1, c->Scalar());
   1400       c->set_output(2, c->Scalar());
   1401       return Status::OK();
   1402     });
   1403 
   1404 REGISTER_OP("QuantizedRelu6")
   1405     .Input("features: Tinput")
   1406     .Input("min_features: float")
   1407     .Input("max_features: float")
   1408     .Output("activations: out_type")
   1409     .Output("min_activations: float")
   1410     .Output("max_activations: float")
   1411     .Attr("Tinput: quantizedtype")
   1412     .Attr("out_type: quantizedtype = DT_QUINT8")
   1413     .SetShapeFn([](InferenceContext* c) {
   1414       TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
   1415       ShapeHandle unused;
   1416       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
   1417       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
   1418       c->set_output(1, c->Scalar());
   1419       c->set_output(2, c->Scalar());
   1420       return Status::OK();
   1421     });
   1422 
   1423 REGISTER_OP("QuantizedReluX")
   1424     .Input("features: Tinput")
   1425     .Input("max_value: float")
   1426     .Input("min_features: float")
   1427     .Input("max_features: float")
   1428     .Output("activations: out_type")
   1429     .Output("min_activations: float")
   1430     .Output("max_activations: float")
   1431     .Attr("Tinput: quantizedtype")
   1432     .Attr("out_type: quantizedtype = DT_QUINT8")
   1433     .SetShapeFn([](InferenceContext* c) {
   1434       TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
   1435       ShapeHandle unused;
    1436       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));  // max_value
    1437       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));  // min_features
               TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));  // max_features
   1438       c->set_output(1, c->Scalar());
   1439       c->set_output(2, c->Scalar());
   1440       return Status::OK();
   1441     });
   1442 
   1443 REGISTER_OP("QuantizedBatchNormWithGlobalNormalization")
   1444     .Input("t: Tinput")
   1445     .Input("t_min: float")
   1446     .Input("t_max: float")
   1447     .Input("m: Tinput")
   1448     .Input("m_min: float")
   1449     .Input("m_max: float")
   1450     .Input("v: Tinput")
   1451     .Input("v_min: float")
   1452     .Input("v_max: float")
   1453     .Input("beta: Tinput")
   1454     .Input("beta_min: float")
   1455     .Input("beta_max: float")
   1456     .Input("gamma: Tinput")
   1457     .Input("gamma_min: float")
   1458     .Input("gamma_max: float")
   1459     .Output("result: out_type")
   1460     .Output("result_min: float")
   1461     .Output("result_max: float")
   1462     .Attr("Tinput: quantizedtype")
   1463     .Attr("out_type: quantizedtype")
   1464     .Attr("variance_epsilon: float")
   1465     .Attr("scale_after_normalization: bool")
   1466     .SetShapeFn([](InferenceContext* c) {
   1467       ShapeHandle input;
   1468       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
   1469 
   1470       DimensionHandle last_dim = c->Dim(input, 3);
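               // Each value input is followed by its min/max scalars, so
               // c->input(i * 3) selects m (3), v (6), beta (9) and gamma (12);
               // each must be a vector whose length matches the channel (last)
               // dimension of t.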
   1471       for (int i = 1; i < 5; ++i) {  // covers m, v, beta, gamma
   1472         ShapeHandle vec;
   1473         TF_RETURN_IF_ERROR(c->WithRank(c->input(i * 3), 1, &vec));
   1474         TF_RETURN_IF_ERROR(c->Merge(last_dim, c->Dim(vec, 0), &last_dim));
   1475       }
   1476 
   1477       ShapeHandle out;
   1478       TF_RETURN_IF_ERROR(c->ReplaceDim(input, 3, last_dim, &out));
   1479       c->set_output(0, out);
   1480       c->set_output(1, c->Scalar());
   1481       c->set_output(2, c->Scalar());
   1482 
   1483       return Status::OK();
   1484     });
   1485 
   1486 #ifdef INTEL_MKL
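         // The _Mkl* ops below shadow their standard counterparts; every data
         // tensor is paired with an extra uint8 "mkl_*" tensor, which the MKL
         // graph rewrite pass presumably uses to carry layout metadata alongside
         // the data.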
   1487 REGISTER_OP("_MklConv2D")
   1488     .Input("input: T")
   1489     .Input("filter: T")
   1490     .Input("mkl_input: uint8")
   1491     .Input("mkl_filter: uint8")
   1492     .Output("output: T")
   1493     .Output("filter_output: T")
   1494     .Output("mkl_output: uint8")
   1495     .Output("mkl_filter_output: uint8")
   1496     .Attr("T: {half, float, double}")
   1497     .Attr("strides: list(int)")
   1498     .Attr("use_cudnn_on_gpu: bool = true")
   1499     .Attr(GetPaddingAttrString())
   1500     .Attr(GetConvnetDataFormatAttrString())
   1501     .SetShapeFn(shape_inference::Conv2DShape)
   1502     .Doc(R"doc(
   1503 MKL version of Conv2D operator. Uses MKL DNN APIs to perform 2D convolution.
   1504 
   1505 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
   1506 expected to invoke these operators.
   1507 )doc");
   1508 
   1509 REGISTER_OP("__MklDummyConv2DWithBias")
   1510     .Input("input: T")
   1511     .Input("filter: T")
   1512     .Input("bias: T")
   1513     .Output("output: T")
   1514     .Attr("T: {half, float, double}")
   1515     .Attr("strides: list(int)")
   1516     .Attr("use_cudnn_on_gpu: bool = true")
   1517     .Attr(GetPaddingAttrString())
   1518     .Attr(GetConvnetDataFormatAttrString())
   1519     .Doc(R"doc(
    1520 Dummy node that enables fusing the Conv2D and BiasAdd operators for MKL. This
    1521 node performs no computation; it exists only as an intermediate output of
    1522 merging Conv2D and BiasAdd.
   1523 
   1524 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
   1525 expected to invoke these operators.
   1526 )doc");
   1527 
   1528 REGISTER_OP("_MklConv2DWithBias")
   1529     .Input("input: T")
   1530     .Input("filter: T")
   1531     .Input("bias: T")
   1532     .Input("mkl_input: uint8")
   1533     .Input("mkl_filter: uint8")
   1534     .Input("mkl_bias: uint8")
   1535     .Output("output: T")
   1536     .Output("filter_output: T")
   1537     .Output("mkl_output: uint8")
   1538     .Output("mkl_filter_output: uint8")
   1539     .Attr("T: {half, float, double}")
   1540     .Attr("strides: list(int)")
   1541     .Attr("use_cudnn_on_gpu: bool = true")
   1542     .Attr(GetPaddingAttrString())
   1543     .Attr(GetConvnetDataFormatAttrString())
   1544     .Doc(R"doc(
    1545 MKL version of the fused Conv2D and BiasAdd operator. Uses MKL DNN APIs to
    1546 perform 2D convolution and add bias to the output of the convolution.
   1547 
   1548 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
   1549 expected to invoke these operators.
   1550 )doc");
   1551 
   1552 REGISTER_OP("_MklConv2DBackpropFilter")
   1553     .Input("input: T")
   1554     .Input("filter_sizes: int32")
   1555     .Input("out_backprop: T")
   1556     .Input("mkl_input: uint8")
   1557     .Input("mkl_filter_size: uint8")
   1558     .Input("mkl_out_backprop: uint8")
   1559     .Output("output: T")
   1560     .Output("mkl_output: uint8")
   1561     .Attr("T: {half, float, double}")
   1562     .Attr("strides: list(int)")
   1563     .Attr("use_cudnn_on_gpu: bool = true")
   1564     .Attr(GetPaddingAttrString())
   1565     .Attr(GetConvnetDataFormatAttrString())
   1566     .SetShapeFn([](InferenceContext* c) {
   1567       ShapeHandle s;
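               // The filter gradient's shape is read from the filter_sizes input
               // (input 1), interpreted as a 4-element shape tensor.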
   1568       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
   1569       TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
   1570       c->set_output(0, s);
   1571       return Status::OK();
   1572     })
   1573     .Doc(R"doc(
   1574 MKL version of Conv2DBackpropFilter. Uses MKL DNN APIs to compute the
   1575 gradients of convolution with respect to the filter.
   1576 
   1577 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
   1578 expected to invoke these operators.
   1579 )doc");
   1580 
   1581 REGISTER_OP("__MklDummyConv2DBackpropFilterWithBias")
   1582     .Input("input: T")
   1583     .Input("filter_sizes: int32")
   1584     .Input("out_backprop: T")
   1585     .Output("output: T")
   1586     .Output("bias_grad: T")
   1587     .Attr("T: {half, float, double}")
   1588     .Attr("strides: list(int)")
   1589     .Attr("use_cudnn_on_gpu: bool = true")
   1590     .Attr(GetPaddingAttrString())
   1591     .Attr(GetConvnetDataFormatAttrString())
   1592     .SetShapeFn([](InferenceContext* c) {
   1593       ShapeHandle input_shape;
   1594       // Fetch the data_format attribute, which may not exist.
   1595       string data_format;
   1596       Status s = c->GetAttr("data_format", &data_format);
   1597 
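               // bias_grad (output 1) is a vector sized by the channel dimension
               // of the input: dim -3 (C in NCHW) or dim -1 (C in NHWC). The
               // filter gradient (output 0) is read from the filter_sizes shape
               // tensor, as in _MklConv2DBackpropFilter.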
   1598       if (s.ok() && data_format == "NCHW") {
   1599         TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
   1600         c->set_output(1, c->Vector(c->Dim(input_shape, -3)));
   1601       } else {
   1602         TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
   1603         c->set_output(1, c->Vector(c->Dim(input_shape, -1)));
   1604       }
   1605       ShapeHandle sh;
   1606       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &sh));
   1607       TF_RETURN_IF_ERROR(c->WithRank(sh, 4, &sh));
   1608       c->set_output(0, sh);
   1609       return Status::OK();
   1610     })
   1611     .Doc(R"doc(
    1612 Dummy node that enables fusing Conv2DBackpropFilter and BiasAddGrad operators
    1613 for MKL. This node performs no computation; it exists only as an intermediate
    1614 output of merging Conv2DBackpropFilter and BiasAddGrad.
   1615 
   1616 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
   1617 expected to invoke these operators.
   1618 )doc");
   1619 
   1620 REGISTER_OP("_MklConv2DBackpropFilterWithBias")
   1621     .Input("input: T")
   1622     .Input("filter_sizes: int32")
   1623     .Input("out_backprop: T")
   1624     .Input("mkl_input: uint8")
   1625     .Input("mkl_filter_size: uint8")
   1626     .Input("mkl_out_backprop: uint8")
   1627     .Output("output: T")
   1628     .Output("bias_grad: T")
   1629     .Output("mkl_output: uint8")
   1630     .Output("mkl_bias_grad: uint8")
   1631     .Attr("T: {half, float, double}")
   1632     .Attr("strides: list(int)")
   1633     .Attr("use_cudnn_on_gpu: bool = true")
   1634     .Attr(GetPaddingAttrString())
   1635     .Attr(GetConvnetDataFormatAttrString())
   1636     .SetShapeFn([](InferenceContext* c) {
   1637       ShapeHandle input_shape;
   1638       // Fetch the data_format attribute, which may not exist.
   1639       string data_format;
   1640       Status s = c->GetAttr("data_format", &data_format);
   1641 
   1642       if (s.ok() && data_format == "NCHW") {
   1643         TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
   1644         c->set_output(1, c->Vector(c->Dim(input_shape, -3)));
   1645       } else {
   1646         TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
   1647         c->set_output(1, c->Vector(c->Dim(input_shape, -1)));
   1648       }
   1649       ShapeHandle sh;
   1650       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &sh));
   1651       TF_RETURN_IF_ERROR(c->WithRank(sh, 4, &sh));
   1652       c->set_output(0, sh);
   1653       return Status::OK();
   1654     })
   1655     .Doc(R"doc(
    1656 MKL version of Conv2DBackpropFilterWithBias. Uses MKL DNN APIs to compute the
    1657 gradients of convolution with respect to the filter, and the bias gradient.
   1658 
   1659 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
   1660 expected to invoke these operators.
   1661 )doc");
   1662 
   1663 REGISTER_OP("_MklConv2DWithBiasBackpropBias")
   1664     .Input("out_backprop: T")
   1665     .Input("mkl_out_backprop: uint8")
   1666     .Output("output: T")
   1667     .Output("mkl_output: uint8")
   1668     .Attr("T: {half, float, double}")
   1669     .Attr("strides: list(int)")
   1670     .Attr(GetConvnetDataFormatAttrString())
   1671     .Doc(R"doc(
   1672 MKL version of Conv2DBackpropBias. Uses MKL DNN APIs to compute the
   1673 gradients of convolution with respect to the bias.
   1674 
   1675 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
   1676 expected to invoke these operators.
   1677 )doc");
   1678 
   1679 REGISTER_OP("_MklConv2DBackpropInput")
   1680     .Input("input_sizes: int32")
   1681     .Input("filter: T")
   1682     .Input("out_backprop: T")
   1683     .Input("mkl_input_sizes: uint8")
   1684     .Input("mkl_filter: uint8")
   1685     .Input("mkl_out_backprop: uint8")
   1686     .Output("output: T")
   1687     .Output("mkl_output: uint8")
   1688     .Attr("T: {half, float, double}")
   1689     .Attr("strides: list(int)")
   1690     .Attr("use_cudnn_on_gpu: bool = true")
   1691     .Attr(GetPaddingAttrString())
   1692     .Attr(GetConvnetDataFormatAttrString())
   1693     .SetShapeFn([](InferenceContext* c) {
   1694       ShapeHandle s;
   1695       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
   1696       TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
   1697       c->set_output(0, s);
   1698       return Status::OK();
   1699     })
   1700     .Doc(R"doc(
    1701 MKL version of Conv2DBackpropInput. Uses MKL DNN APIs to compute the
   1702 gradients of convolution with respect to the input.
   1703 
   1704 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
   1705 expected to invoke these operators.
   1706 )doc");
   1707 
   1708 REGISTER_OP("_MklRelu")
   1709     .Input("features: T")
   1710     .Input("mkl_features: uint8")
   1711     .Output("activations: T")
   1712     .Output("mkl_activations: uint8")
   1713     .Attr("T: realnumbertype")
   1714     .SetShapeFn(shape_inference::UnchangedShape)
   1715     .Doc(R"doc(
   1716 MKL version of Relu operator. Uses MKL DNN APIs to implement Relu operator.
   1717 
   1718 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
   1719 expected to invoke these operators.
   1720 )doc");
   1721 
   1722 REGISTER_OP("_MklReluGrad")
   1723     .Input("gradients: T")
   1724     .Input("features: T")
   1725     .Input("mkl_gradients: uint8")
   1726     .Input("mkl_features: uint8")
   1727     .Output("backprops: T")
   1728     .Output("mkl_backprops: uint8")
   1729     .Attr("T: realnumbertype")
   1730     .SetShapeFn(shape_inference::MergeBothInputsShapeFn)
   1731     .Doc(R"doc(
   1732 MKL version of ReluGrad operator. Uses MKL DNN APIs to compute rectified
   1733 linear gradients for Relu operation.
   1734 
   1735 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
   1736 expected to invoke these operators.
   1737 )doc");
   1738 
   1739 REGISTER_OP("_MklElu")
   1740     .Input("features: T")
   1741     .Input("mkl_features: uint8")
   1742     .Output("activations: T")
   1743     .Output("mkl_activations: uint8")
   1744     .Attr("T: realnumbertype")
   1745     .SetShapeFn(shape_inference::UnchangedShape)
   1746     .Doc(R"doc(
   1747 MKL version of Elu operator. Uses MKL DNN APIs to implement Elu operator.
   1748 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
   1749 expected to invoke these operators.
   1750 )doc");
   1751 
   1752 REGISTER_OP("_MklEluGrad")
   1753     .Input("gradients: T")
   1754     .Input("features: T")
   1755     .Input("mkl_gradients: uint8")
   1756     .Input("mkl_features: uint8")
   1757     .Output("backprops: T")
   1758     .Output("mkl_backprops: uint8")
   1759     .Attr("T: realnumbertype")
   1760     .SetShapeFn(shape_inference::MergeBothInputsShapeFn)
   1761     .Doc(R"doc(
   1762 MKL version of EluGrad operator. Uses MKL DNN APIs to compute Elu
   1763 gradients for Elu operation.
   1764 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
   1765 expected to invoke these operators.
   1766 )doc");
   1767 
   1768 REGISTER_OP("_MklSoftmax")
   1769     .Input("logits: T")
   1770     .Input("mkl_logits: uint8")
   1771     .Output("softmax: T")
   1772     .Output("mkl_softmax: uint8")
   1773     .Attr("T: {half, float, double}")
   1774     .SetShapeFn([](InferenceContext* c) {
   1775       return shape_inference::UnchangedShapeWithRankAtLeast(c, 1);
   1776     })
   1777     .Doc(R"doc(
    1778 MKL version of Softmax operator. Uses MKL DNN APIs to perform softmax
    1779 on the input logits.
   1780 )doc");
   1781 
   1782 REGISTER_OP("_MklTanh")
   1783     .Input("features: T")
   1784     .Input("mkl_features: uint8")
   1785     .Output("activations: T")
   1786     .Output("mkl_activations: uint8")
   1787     .Attr("T: realnumbertype")
   1788     .SetShapeFn(shape_inference::UnchangedShape)
   1789     .Doc(R"doc(
   1790 MKL version of Tanh operator. Uses MKL DNN APIs to implement Tanh operator.
   1791 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
   1792 expected to invoke these operators.
   1793 )doc");
   1794 
   1795 REGISTER_OP("_MklTanhGrad")
   1796     .Input("gradients: T")
   1797     .Input("features: T")
   1798     .Input("mkl_gradients: uint8")
   1799     .Input("mkl_features: uint8")
   1800     .Output("backprops: T")
   1801     .Output("mkl_backprops: uint8")
   1802     .Attr("T: realnumbertype")
   1803     .SetShapeFn(shape_inference::MergeBothInputsShapeFn)
   1804     .Doc(R"doc(
   1805 MKL version of TanhGrad operator. Uses MKL DNN APIs to compute tanh
   1806 gradients for Tanh operation.
   1807 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
   1808 expected to invoke these operators.
   1809 )doc");
   1810 
   1811 REGISTER_OP("_MklMaxPool")
   1812     .Attr("T: {float, half} = DT_FLOAT")
   1813     .Attr("ksize: list(int) >= 4")
   1814     .Attr("strides: list(int) >= 4")
   1815     .Attr(GetPaddingAttrString())
   1816     .Attr(GetConvnetDataFormatAttrString())
   1817     .Attr("workspace_enabled: bool = false")
   1818     .Input("input: T")
   1819     .Input("mkl_input: uint8")
   1820     .Output("output: T")
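             // The workspace output's type depends on the MKL backend: the
             // INTEL_MKL_ML path uses the op's data type T, while the (presumably
             // newer) MKL-DNN path emits an opaque uint8 blob.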
   1821 #ifdef INTEL_MKL_ML
   1822     .Output("workspace: T")
   1823 #else
   1824     .Output("workspace: uint8")
   1825 #endif
   1826     .Output("mkl_output: uint8")
   1827     .Output("mkl_workspace: uint8")
   1828     .SetShapeFn(shape_inference::MaxPoolShape)
   1829     .Doc(R"doc(
   1830 MKL version of MaxPool operator. Uses MKL DNN APIs to perform max pooling
   1831 on the input.
   1832 
   1833 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
   1834 expected to invoke these operators.
   1835 )doc");
   1836 
   1837 REGISTER_OP("_MklMaxPoolGrad")
   1838     .Attr("T: {float, half} = DT_FLOAT")
   1839     .Attr("ksize: list(int) >= 4")
   1840     .Attr("strides: list(int) >= 4")
   1841     .Attr("workspace_enabled: bool = false")
   1842     .Attr(GetPaddingAttrString())
   1843     .Attr(GetConvnetDataFormatAttrString())
   1844     .Input("orig_input: T")
   1845     .Input("orig_output: T")
   1846     .Input("grad: T")
   1847 #ifdef INTEL_MKL_ML
   1848     .Input("workspace: T")
   1849 #else
   1850     .Input("workspace: uint8")
   1851 #endif
   1852     .Input("mkl_orig_input: uint8")
   1853     .Input("mkl_orig_output: uint8")
   1854     .Input("mkl_grad: uint8")
   1855     .Input("mkl_workspace: uint8")
   1856     .Output("output: T")
   1857     .Output("mkl_output: uint8")
   1858     .SetShapeFn([](InferenceContext* c) {
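               // The input gradient has the same 4-D shape as orig_input (input 0).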
   1859       return UnchangedShapeWithRank(c, 4);
   1860     })
   1861     .Doc(R"doc(
   1862 MKL version of MaxPoolGrad. Uses MKL DNN APIs to compute gradients of
   1863 MaxPool operator.
   1864 
   1865 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
   1866 expected to invoke these operators.
   1867 )doc");
   1868 
   1869 REGISTER_OP("_MklAvgPool")
   1870     .Input("value: T")
   1871     .Input("mkl_input: uint8")
   1872     .Output("output: T")
   1873     .Output("mkl_output: uint8")
   1874     .Attr("ksize: list(int) >= 4")
   1875     .Attr("strides: list(int) >= 4")
   1876     .Attr(GetPaddingAttrString())
   1877     .Attr(GetConvnetDataFormatAttrString())
   1878     .Attr("T: {float, half, double}")
   1879     .SetShapeFn(shape_inference::AvgPoolShape)
   1880     .Doc(R"doc(
   1881 MKL version of AvgPool operator. Uses MKL DNN APIs to perform average pooling
   1882 on the input.
   1883 
   1884 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
   1885 expected to invoke these operators.
   1886 )doc");
   1887 
   1888 REGISTER_OP("_MklAvgPoolGrad")
   1889     .Input("orig_input_shape: int32")
   1890     .Input("grad: T")
   1891     .Input("mkl_orig_input: uint8")
   1892     .Input("mkl_grad: uint8")
   1893     .Output("output: T")
   1894     .Output("mkl_output: uint8")
   1895     .Attr("ksize: list(int) >= 4")
   1896     .Attr("strides: list(int) >= 4")
   1897     .Attr(GetPaddingAttrString())
   1898     .Attr(GetConvnetDataFormatAttrString())
   1899     .Attr("T: {float, half, double}")
   1900     .SetShapeFn([](InferenceContext* c) {
   1901       ShapeHandle s;
   1902       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
   1903       TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
   1904       c->set_output(0, s);
   1905       return Status::OK();
   1906     })
   1907     .Doc(R"doc(
   1908 MKL version of AvgPoolGrad operator. Uses MKL DNN APIs to compute gradients
   1909 of AvgPool function.
   1910 
   1911 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
   1912 expected to invoke these operators.
   1913 )doc");
   1914 
   1915 REGISTER_OP("_MklLRN")
   1916     .Input("input: T")
   1917     .Input("mkl_input: uint8")
   1918     .Output("output: T")
   1919 #ifdef INTEL_MKL_ML
   1920     .Output("workspace: T")
   1921 #else
   1922     .Output("workspace: uint8")
   1923 #endif
   1924     .Output("mkl_output: uint8")
   1925     .Output("mkl_workspace: uint8")
   1926     .Attr("depth_radius: int = 5")
   1927     .Attr("bias: float = 1.0")
   1928     .Attr("alpha: float = 1.0")
   1929     .Attr("beta: float = 0.5")
   1930     .Attr("workspace_enabled: bool = false")
   1931     .Attr("T: {float, half} = DT_FLOAT")
   1932     .SetShapeFn([](InferenceContext* c) {
   1933       return UnchangedShapeWithRank(c, 4);
   1934     })
   1935     .Doc(R"doc(
   1936 MKL version of LRN operator. Uses MKL DNN APIs to perform local response
   1937 normalization.
   1938 
   1939 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
   1940 expected to invoke these operators.
   1941 )doc");
   1942 
   1943 REGISTER_OP("_MklLRNGrad")
   1944     .Input("input_grads: T")
   1945     .Input("input_image: T")
   1946     .Input("output_image: T")
   1947 #ifdef INTEL_MKL_ML
   1948     .Input("workspace: T")
   1949 #else
   1950     .Input("workspace: uint8")
   1951 #endif
   1952     .Input("mkl_input_grads: uint8")
   1953     .Input("mkl_input_image: uint8")
   1954     .Input("mkl_output_image: uint8")
   1955     .Input("mkl_workspace: uint8")
   1956     .Output("output: T")
   1957     .Output("mkl_output: uint8")
   1958     .Attr("depth_radius: int = 5")
   1959     .Attr("bias: float = 1.0")
   1960     .Attr("alpha: float = 1.0")
   1961     .Attr("beta: float = 0.5")
   1962     .Attr("workspace_enabled: bool = false")
   1963     .Attr("T: {float, half} = DT_FLOAT")
   1964     .SetShapeFn([](InferenceContext* c) {
   1965       ShapeHandle s;
   1966       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &s));  // input_grads
   1967       TF_RETURN_IF_ERROR(c->Merge(s, c->input(1), &s));     // input_image
   1968       TF_RETURN_IF_ERROR(c->Merge(s, c->input(2), &s));     // output_image
   1969       c->set_output(0, s);
   1970       return Status::OK();
   1971     })
   1972     .Doc(R"doc(
   1973 MKL version of LRNGrad operator. Uses MKL DNN APIs to compute gradient for
   1974 local response normalization.
   1975 
   1976 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
   1977 expected to invoke these operators.
   1978 )doc");
   1979 
   1980 REGISTER_OP("_MklFusedBatchNorm")
   1981     .Input("x: T")
   1982     .Input("scale: T")
   1983     .Input("offset: T")
   1984     .Input("mean: T")
   1985     .Input("variance: T")
   1986     .Input("mkl_x: uint8")
   1987     .Input("mkl_scale: uint8")
   1988     .Input("mkl_offset: uint8")
   1989     .Input("mkl_mean: uint8")
   1990     .Input("mkl_variance: uint8")
   1991     .Output("y: T")
   1992     .Output("batch_mean: T")
   1993     .Output("batch_variance: T")
   1994     .Output("reserve_space_1: T")
   1995     .Output("reserve_space_2: T")
   1996     .Output("mkl_y: uint8")
   1997     .Output("mkl_batch_mean: uint8")
   1998     .Output("mkl_batch_variance: uint8")
   1999     .Output("mkl_reserve_space_1: uint8")
   2000     .Output("mkl_reserve_space_2: uint8")
   2001     .Attr("T: numbertype")
   2002     .Attr("epsilon: float = 0.0001")
   2003     .Attr("data_format: string = 'NHWC'")
   2004     .Attr("is_training: bool = true")
   2005     .SetShapeFn([](InferenceContext* c) {
   2006       ShapeHandle x;
   2007       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &x));
   2008 
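               // Both attrs are declared above with default values, which is
               // presumably why the Status returned by GetAttr is not checked.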
   2009       bool is_training;
   2010       c->GetAttr("is_training", &is_training);
   2011       int number_inputs = (is_training) ? 3 : 5;
   2012       string data_format;
   2013       c->GetAttr("data_format", &data_format);
   2014       DimensionHandle channel_dim =
   2015           (data_format == "NHWC") ? c->Dim(x, 3) : c->Dim(x, 1);
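               // Illustrative example: for an NHWC input x of shape
               // [8, 28, 28, 64], channel_dim is the 64-element channel
               // dimension; output 0 keeps x's shape and outputs 1-4 become
               // length-64 vectors.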
   2016 
   2017       // covers scale, offset, and if is_training is false, mean, variance
   2018       for (int i = 1; i < number_inputs; ++i) {
   2019         ShapeHandle vec;
   2020         TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
   2021         TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(vec, 0), &channel_dim));
   2022       }
   2023 
   2024       ShapeHandle y;
   2025       if (data_format == "NHWC") {
   2026         TF_RETURN_IF_ERROR(c->ReplaceDim(x, 3, channel_dim, &y));
   2027       } else {
   2028         TF_RETURN_IF_ERROR(c->ReplaceDim(x, 1, channel_dim, &y));
   2029       }
   2030       c->set_output(0, y);
   2031       ShapeHandle vector_shape = c->Vector(channel_dim);
   2032       c->set_output(1, vector_shape);
   2033       c->set_output(2, vector_shape);
   2034       c->set_output(3, vector_shape);
   2035       c->set_output(4, vector_shape);
   2036       return Status::OK();
   2037     })
   2038     .Doc(R"doc(
   2039 MKL version of FusedBatchNorm operator. Uses MKL DNN APIs to perform fused
   2040 batch normalization.
   2041 
   2042 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
   2043 expected to invoke these operators.
   2044 )doc");
   2045 
   2046 REGISTER_OP("_MklFusedBatchNormGrad")
   2047     .Input("y_backprop: T")
   2048     .Input("x: T")
   2049     .Input("scale: T")
   2050     .Input("reserve_space_1: T")
   2051     .Input("reserve_space_2: T")
   2052     .Input("mkl_y_backprop: uint8")
   2053     .Input("mkl_x: uint8")
   2054     .Input("mkl_scale: uint8")
   2055     .Input("mkl_reserve_space_1: uint8")
   2056     .Input("mkl_reserve_space_2: uint8")
   2057     .Output("x_backprop: T")
   2058     .Output("scale_backprop: T")
   2059     .Output("offset_backprop: T")
   2060     .Output("reserve_space_3: T")
   2061     .Output("reserve_space_4: T")
   2062     .Output("mkl_x_backprop: uint8")
   2063     .Output("mkl_scale_backprop: uint8")
   2064     .Output("mkl_offset_backprop: uint8")
   2065     .Output("mkl_reserve_space_3: uint8")
   2066     .Output("mkl_reserve_space_4: uint8")
   2067     .Attr("T: numbertype")
   2068     .Attr("epsilon: float = 0.0001")
   2069     .Attr("data_format: string = 'NHWC'")
   2070     .Attr("is_training: bool = true")
   2071     .SetShapeFn([](InferenceContext* c) {
   2072       ShapeHandle y_backprop;
   2073       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &y_backprop));
   2074       ShapeHandle x;
   2075       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &x));
   2076 
   2077       bool is_training;
   2078       string data_format;
   2079       c->GetAttr("is_training", &is_training);
   2080       c->GetAttr("data_format", &data_format);
   2081       DimensionHandle channel_dim = (data_format == "NHWC")
   2082                                         ? c->Dim(y_backprop, 3)
   2083                                         : c->Dim(y_backprop, 1);
   2084       if (data_format == "NHWC") {
   2085         TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(x, 3), &channel_dim));
   2086       } else {
   2087         TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(x, 1), &channel_dim));
   2088       }
   2089 
   2090       // covers scale, mean (reserve_space_1), variance (reserve_space_2)
   2091       for (int i = 2; i < 5; ++i) {
   2092         ShapeHandle vec;
   2093         TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
   2094         TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(vec, 0), &channel_dim));
   2095       }
   2096 
   2097       ShapeHandle x_backprop;
   2098       if (data_format == "NHWC") {
   2099         TF_RETURN_IF_ERROR(
   2100             c->ReplaceDim(y_backprop, 3, channel_dim, &x_backprop));
   2101       } else {
   2102         TF_RETURN_IF_ERROR(
   2103             c->ReplaceDim(y_backprop, 1, channel_dim, &x_backprop));
   2104       }
   2105       c->set_output(0, x_backprop);
   2106       c->set_output(1, c->Vector(channel_dim));
   2107       c->set_output(2, c->Vector(channel_dim));
    2108       // Set the shapes of the reserve_space outputs so that
    2109       // gradients can still be computed when the op is used
    2110       // in a symbolic context.
   2111       if (is_training) {
   2112         c->set_output(3, c->Vector(0));
   2113         c->set_output(4, c->Vector(0));
   2114       } else {
   2115         c->set_output(3, c->Vector(channel_dim));
   2116         c->set_output(4, c->Vector(channel_dim));
   2117       }
   2118       return Status::OK();
   2119     })
   2120     .Doc(R"doc(
   2121 MKL version of FusedBatchNormGrad operator. Uses MKL DNN APIs to compute
   2122 gradients for fused batch normalization.
   2123 
   2124 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
   2125 expected to invoke these operators.
   2126 )doc");
   2127 
   2128 REGISTER_OP("_MklToTf")
   2129     .Input("input: T")
   2130     .Input("mkl_input: uint8")
   2131     .Output("output: T")
   2132     .Attr("T: {half, float, double}")
   2133     .Attr(GetConvnetDataFormatAttrString())
   2134     .Doc(R"doc(
   2135 MKL operator to convert a tensor from MKL layout to TensorFlow layout.
   2136 
   2137 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
   2138 expected to invoke these operators.
   2139 )doc");
   2140 
   2141 REGISTER_OP("_MklInputConversion")
   2142     .Input("input_0: T")
   2143     .Input("input_1: T")
   2144     .Input("mkl_input_0: uint8")
   2145     .Input("mkl_input_1: uint8")
   2146     .Output("output_0: T")
   2147     .Output("output_1: T")
   2148     .Output("mkl_output_0: uint8")
   2149     .Output("mkl_output_1: uint8")
   2150     // All datatypes supported by element-wise ops
   2151     .Attr(
   2152         "T: {half, float, double, uint8, int8, uint16, int16, int32, int64, "
   2153         "complex64, complex128}")
   2154     .Attr(GetConvnetDataFormatAttrString())
   2155     .Doc(R"doc(
   2156 MKL operator to process the inputs to an elementwise MKL op. Both inputs
   2157 need to be either in TF or in MKL format. This op is added before every
   2158 element-wise MKL op.
   2159 
   2160 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
   2161 expected to invoke these operators.
   2162 )doc");
   2163 #endif  // INTEL_MKL
   2164 
   2165 }  // namespace tensorflow
   2166