1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/framework/common_shape_fns.h" 17 #include "tensorflow/core/framework/numeric_op.h" 18 #include "tensorflow/core/framework/op.h" 19 #include "tensorflow/core/framework/shape_inference.h" 20 #include "tensorflow/core/util/mirror_pad_mode.h" 21 #include "tensorflow/core/util/padding.h" 22 #include "tensorflow/core/util/tensor_format.h" 23 24 namespace tensorflow { 25 26 using shape_inference::DimensionHandle; 27 using shape_inference::InferenceContext; 28 using shape_inference::ShapeHandle; 29 30 namespace { 31 32 Status FractionalPoolShapeFn(InferenceContext* c) { 33 ShapeHandle input; 34 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input)); 35 36 std::vector<float> pooling_ratio; 37 TF_RETURN_IF_ERROR(c->GetAttr("pooling_ratio", &pooling_ratio)); 38 if (pooling_ratio.size() != 4) { 39 return errors::InvalidArgument( 40 "pooling_ratio field must specify 4 dimensions"); 41 } 42 std::vector<DimensionHandle> output_dims; 43 for (int i = 0; i < 4; ++i) { 44 DimensionHandle d = c->Dim(input, i); 45 if (c->ValueKnown(d)) { 46 // This must match the same logic in the kernel function in 47 // core/kernels/fractional_max_pool_op.cc. 
48 auto val = static_cast<int64>(floor(c->Value(d) / pooling_ratio[i])); 49 if (val < 0) { 50 return errors::InvalidArgument("Size computed for dim ", i, 51 " is negative: ", val); 52 } 53 output_dims.push_back(c->MakeDim(val)); 54 } else { 55 output_dims.push_back(c->UnknownDim()); 56 } 57 } 58 59 c->set_output(0, c->MakeShape(output_dims)); 60 c->set_output(1, c->Vector(output_dims[1])); 61 c->set_output(2, c->Vector(output_dims[2])); 62 return Status::OK(); 63 } 64 65 } // namespace 66 67 // -------------------------------------------------------------------------- 68 69 REGISTER_OP("AvgPool") 70 .Input("value: T") 71 .Output("output: T") 72 .Attr("ksize: list(int) >= 4") 73 .Attr("strides: list(int) >= 4") 74 .Attr(GetPaddingAttrString()) 75 .Attr(GetConvnetDataFormatAttrString()) 76 .Attr("T: {half, bfloat16, float, double}") 77 .SetShapeFn(shape_inference::AvgPoolShape); 78 79 REGISTER_OP("AvgPoolGrad") 80 .Input("orig_input_shape: int32") 81 .Input("grad: T") 82 .Output("output: T") 83 .Attr("ksize: list(int) >= 4") 84 .Attr("strides: list(int) >= 4") 85 .Attr(GetPaddingAttrString()) 86 .Attr(GetConvnetDataFormatAttrString()) 87 .Attr("T: {half, bfloat16, float, double}") 88 .SetShapeFn([](InferenceContext* c) { 89 ShapeHandle s; 90 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s)); 91 TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s)); 92 c->set_output(0, s); 93 return Status::OK(); 94 }); 95 96 // -------------------------------------------------------------------------- 97 98 REGISTER_OP("BatchNormWithGlobalNormalization") 99 .Input("t: T") 100 .Input("m: T") 101 .Input("v: T") 102 .Input("beta: T") 103 .Input("gamma: T") 104 .Output("result: T") 105 .Attr("T: numbertype") 106 .Attr("variance_epsilon: float") 107 .Attr("scale_after_normalization: bool") 108 .Deprecated(9, "Use tf.nn.batch_normalization()") 109 .SetShapeFn([](InferenceContext* c) { 110 ShapeHandle input; 111 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input)); 112 113 DimensionHandle 
last_dim = c->Dim(input, 3); 114 for (int i = 1; i < 5; ++i) { // covers m, v, beta, gamma 115 ShapeHandle vec; 116 TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec)); 117 TF_RETURN_IF_ERROR(c->Merge(last_dim, c->Dim(vec, 0), &last_dim)); 118 } 119 120 ShapeHandle out; 121 TF_RETURN_IF_ERROR(c->ReplaceDim(input, 3, last_dim, &out)); 122 c->set_output(0, out); 123 return Status::OK(); 124 }); 125 126 REGISTER_OP("BatchNormWithGlobalNormalizationGrad") 127 .Input("t: T") 128 .Input("m: T") 129 .Input("v: T") 130 .Input("gamma: T") 131 .Input("backprop: T") 132 .Output("dx: T") 133 .Output("dm: T") 134 .Output("dv: T") 135 .Output("db: T") 136 .Output("dg: T") 137 .Attr("T: numbertype") 138 .Attr("variance_epsilon: float") 139 .Attr("scale_after_normalization: bool") 140 .Deprecated(9, "Use tf.nn.batch_normalization()") 141 .SetShapeFn([](InferenceContext* c) { 142 ShapeHandle input; 143 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input)); 144 TF_RETURN_IF_ERROR( 145 c->Merge(input, c->input(4), &input)); // with backprop 146 147 DimensionHandle last_dim = c->Dim(input, 3); 148 for (int i = 1; i < 4; ++i) { // covers m, v, gamma 149 ShapeHandle vec; 150 TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec)); 151 TF_RETURN_IF_ERROR(c->Merge(last_dim, c->Dim(vec, 0), &last_dim)); 152 } 153 154 ShapeHandle dx; 155 TF_RETURN_IF_ERROR(c->ReplaceDim(input, 3, last_dim, &dx)); 156 c->set_output(0, dx); 157 158 ShapeHandle vector_shape = c->Vector(last_dim); 159 c->set_output(1, vector_shape); 160 c->set_output(2, vector_shape); 161 c->set_output(3, vector_shape); 162 c->set_output(4, vector_shape); 163 return Status::OK(); 164 }); 165 166 // -------------------------------------------------------------------------- 167 168 REGISTER_OP("FusedBatchNorm") 169 .Input("x: T") 170 .Input("scale: T") 171 .Input("offset: T") 172 .Input("mean: T") 173 .Input("variance: T") 174 .Output("y: T") 175 .Output("batch_mean: T") 176 .Output("batch_variance: T") 177 
.Output("reserve_space_1: T") 178 .Output("reserve_space_2: T") 179 .Attr("T: {float}") 180 .Attr("epsilon: float = 0.0001") 181 .Attr("data_format: string = 'NHWC'") 182 .Attr("is_training: bool = true") 183 .SetShapeFn(shape_inference::FusedBatchNormShape); 184 185 REGISTER_OP("FusedBatchNormV2") 186 .Input("x: T") 187 .Input("scale: U") 188 .Input("offset: U") 189 .Input("mean: U") 190 .Input("variance: U") 191 .Output("y: T") 192 .Output("batch_mean: U") 193 .Output("batch_variance: U") 194 .Output("reserve_space_1: U") 195 .Output("reserve_space_2: U") 196 .Attr("T: {half, bfloat16, float}") 197 .Attr("U: {float}") 198 .Attr("epsilon: float = 0.0001") 199 .Attr("data_format: string = 'NHWC'") 200 .Attr("is_training: bool = true") 201 .SetShapeFn(shape_inference::FusedBatchNormShape); 202 203 REGISTER_OP("FusedBatchNormGrad") 204 .Input("y_backprop: T") 205 .Input("x: T") 206 .Input("scale: T") 207 .Input("reserve_space_1: T") 208 .Input("reserve_space_2: T") 209 .Output("x_backprop: T") 210 .Output("scale_backprop: T") 211 .Output("offset_backprop: T") 212 .Output("reserve_space_3: T") 213 .Output("reserve_space_4: T") 214 .Attr("T: {float}") 215 .Attr("epsilon: float = 0.0001") 216 .Attr("data_format: string = 'NHWC'") 217 .Attr("is_training: bool = true") 218 .SetShapeFn(shape_inference::FusedBatchNormGradShape); 219 220 REGISTER_OP("FusedBatchNormGradV2") 221 .Input("y_backprop: T") 222 .Input("x: T") 223 .Input("scale: float") 224 .Input("reserve_space_1: U") 225 .Input("reserve_space_2: U") 226 .Output("x_backprop: T") 227 .Output("scale_backprop: U") 228 .Output("offset_backprop: U") 229 .Output("reserve_space_3: U") 230 .Output("reserve_space_4: U") 231 .Attr("T: {half, bfloat16, float}") 232 .Attr("U: {float}") 233 .Attr("epsilon: float = 0.0001") 234 .Attr("data_format: string = 'NHWC'") 235 .Attr("is_training: bool = true") 236 .SetShapeFn(shape_inference::FusedBatchNormGradShape); 237 238 // 
-------------------------------------------------------------------------- 239 240 REGISTER_OP("BiasAdd") 241 .Attr("T: numbertype") 242 .Input("value: T") 243 .Input("bias: T") 244 .Attr(GetConvnetDataFormatAttrString()) 245 .Output("output: T") 246 .SetShapeFn(shape_inference::BiasAddShape); 247 // -------------------------------------------------------------------------- 248 249 REGISTER_OP("BiasAddGrad") 250 .Attr("T: numbertype") 251 .Input("out_backprop: T") 252 .Attr(GetConvnetDataFormatAttrString()) 253 .Output("output: T") 254 .SetShapeFn(shape_inference::BiasAddGradShape); 255 // -------------------------------------------------------------------------- 256 257 REGISTER_OP("BiasAddV1") 258 .Attr("T: numbertype") 259 .Input("value: T") 260 .Input("bias: T") 261 .Output("output: T") 262 .SetShapeFn(shape_inference::BiasAddShape); 263 // -------------------------------------------------------------------------- 264 265 REGISTER_OP("Conv2D") 266 .Input("input: T") 267 .Input("filter: T") 268 .Output("output: T") 269 .Attr("T: {half, bfloat16, float}") 270 .Attr("strides: list(int)") 271 .Attr("use_cudnn_on_gpu: bool = true") 272 .Attr(GetPaddingAttrString()) 273 .Attr(GetConvnetDataFormatAttrString()) 274 .Attr("dilations: list(int) = [1, 1, 1, 1]") 275 .SetShapeFn(shape_inference::Conv2DShape); 276 277 REGISTER_OP("Conv2DBackpropInput") 278 .Input("input_sizes: int32") 279 .Input("filter: T") 280 .Input("out_backprop: T") 281 .Output("output: T") 282 .Attr("T: {half, bfloat16, float}") 283 .Attr("strides: list(int)") 284 .Attr("use_cudnn_on_gpu: bool = true") 285 .Attr(GetPaddingAttrString()) 286 .Attr(GetConvnetDataFormatAttrString()) 287 .Attr("dilations: list(int) = [1, 1, 1, 1]") 288 .SetShapeFn([](InferenceContext* c) { 289 ShapeHandle s; 290 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s)); 291 TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s)); 292 c->set_output(0, s); 293 return Status::OK(); 294 }); 295 296 // TODO(jeff): Instead of 
'use_cudnn_for_gpu', maybe we should have a 297 // more general string attribute ('kernel_impl'?) that can be used to 298 // select among several possible implementations. 299 REGISTER_OP("Conv2DBackpropFilter") 300 .Input("input: T") 301 .Input("filter_sizes: int32") 302 .Input("out_backprop: T") 303 .Output("output: T") 304 .Attr("T: {half, bfloat16, float}") 305 .Attr("strides: list(int)") 306 .Attr("use_cudnn_on_gpu: bool = true") 307 .Attr(GetPaddingAttrString()) 308 .Attr(GetConvnetDataFormatAttrString()) 309 .Attr("dilations: list(int) = [1, 1, 1, 1]") 310 .SetShapeFn([](InferenceContext* c) { 311 ShapeHandle s; 312 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s)); 313 TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s)); 314 c->set_output(0, s); 315 return Status::OK(); 316 }); 317 318 namespace { 319 320 Status CommonFusedConvCalculations(InferenceContext* c, bool has_resize) { 321 ShapeHandle input; 322 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input)); 323 324 ShapeHandle resized = input; 325 int paddings_index = 1; 326 int filter_index = 2; 327 if (has_resize) { 328 paddings_index = 2; 329 filter_index = 3; 330 331 ShapeHandle unused_size; 332 TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->Vector(2), &unused_size)); 333 334 const Tensor* size = c->input_tensor(1); 335 DimensionHandle new_height = c->UnknownDim(); 336 DimensionHandle new_width = c->UnknownDim(); 337 if (size != nullptr) { 338 new_height = c->MakeDim(size->flat<int32>()(0)); 339 new_width = c->MakeDim(size->flat<int32>()(1)); 340 } 341 TF_RETURN_IF_ERROR(c->ReplaceDim(resized, 1, new_height, &resized)); 342 TF_RETURN_IF_ERROR(c->ReplaceDim(resized, 2, new_width, &resized)); 343 } 344 345 ShapeHandle paddings; 346 TF_RETURN_IF_ERROR(c->WithRank(c->input(paddings_index), 2, &paddings)); 347 TF_RETURN_IF_ERROR( 348 c->WithRank(resized, c->Value(c->Dim(paddings, 0)), &resized)); 349 TF_RETURN_IF_ERROR( 350 c->Merge(paddings, c->Matrix(c->Rank(resized), 2), &paddings)); 351 352 const Tensor* 
paddings_t = c->input_tensor(paddings_index); 353 ShapeHandle padded; 354 if (paddings_t != nullptr) { 355 std::vector<DimensionHandle> output_dims; 356 for (int i = 0; i < 4; ++i) { 357 DimensionHandle dim = c->Dim(resized, i); 358 int64 p0 = static_cast<int64>(paddings_t->matrix<int32>()(i, 0)); 359 int64 p1 = static_cast<int64>(paddings_t->matrix<int32>()(i, 1)); 360 if (p0 < 0 || p1 < 0) { 361 return errors::InvalidArgument("Paddings must be non-negative"); 362 } 363 364 TF_RETURN_IF_ERROR(c->Add(dim, p0 + p1, &dim)); 365 output_dims.push_back(dim); 366 } 367 padded = c->MakeShape(output_dims); 368 } else { 369 padded = c->UnknownShapeOfRank(4); 370 } 371 372 // Work out the convolution's effect with 'padded' as the input. 373 ShapeHandle filter; 374 TF_RETURN_IF_ERROR(c->WithRank(c->input(filter_index), 4, &filter)); 375 std::vector<int32> strides; 376 TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); 377 if (strides.size() != 4) { 378 return errors::InvalidArgument( 379 "Operation requires the stride attribute to contain 4 values, but ", 380 "got: ", strides.size()); 381 } 382 383 int32 stride_rows = strides[1]; 384 int32 stride_cols = strides[2]; 385 386 DimensionHandle batch_size_dim = c->Dim(padded, 0); 387 DimensionHandle in_rows_dim = c->Dim(padded, 1); 388 DimensionHandle in_cols_dim = c->Dim(padded, 2); 389 DimensionHandle filter_rows_dim = c->Dim(filter, 0); 390 DimensionHandle filter_cols_dim = c->Dim(filter, 1); 391 DimensionHandle output_depth_dim = c->Dim(filter, 3); 392 393 DimensionHandle unused; 394 TF_RETURN_IF_ERROR(c->Merge(c->Dim(padded, 3), c->Dim(filter, 2), &unused)); 395 396 Padding padding; 397 TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); 398 399 DimensionHandle output_rows, output_cols; 400 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( 401 c, in_rows_dim, filter_rows_dim, stride_rows, padding, &output_rows)); 402 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( 403 c, in_cols_dim, filter_cols_dim, stride_cols, padding, 
&output_cols)); 404 405 ShapeHandle output_shape = c->MakeShape( 406 {batch_size_dim, output_rows, output_cols, output_depth_dim}); 407 c->set_output(0, output_shape); 408 return Status::OK(); 409 } 410 411 } // namespace 412 413 REGISTER_OP("DataFormatDimMap") 414 .Input("x: T") 415 .Output("y: T") 416 .Attr("T: {int32, int64} = DT_INT32") 417 .Attr("src_format: string = 'NHWC'") 418 .Attr("dst_format: string = 'NCHW'") 419 .SetShapeFn(shape_inference::UnchangedShape); 420 421 REGISTER_OP("DataFormatVecPermute") 422 .Input("x: T") 423 .Output("y: T") 424 .Attr("T: {int32, int64} = DT_INT32") 425 .Attr("src_format: string = 'NHWC'") 426 .Attr("dst_format: string = 'NCHW'") 427 .SetShapeFn(shape_inference::UnchangedShape); 428 429 REGISTER_OP("FusedResizeAndPadConv2D") 430 .Input("input: T") 431 .Input("size: int32") 432 .Input("paddings: int32") 433 .Input("filter: T") 434 .Output("output: T") 435 .Attr("T: {float}") 436 .Attr("resize_align_corners: bool = false") 437 .Attr(GetMirrorPadModeAttrString()) 438 .Attr("strides: list(int)") 439 .Attr(GetPaddingAttrString()) 440 .SetShapeFn([](InferenceContext* c) { 441 return CommonFusedConvCalculations(c, true /* has_resize */); 442 }); 443 444 REGISTER_OP("FusedPadConv2D") 445 .Input("input: T") 446 .Input("paddings: int32") 447 .Input("filter: T") 448 .Output("output: T") 449 .Attr("T: {float}") 450 .Attr(GetMirrorPadModeAttrString()) 451 .Attr("strides: list(int)") 452 .Attr(GetPaddingAttrString()) 453 .SetShapeFn([](InferenceContext* c) { 454 return CommonFusedConvCalculations(c, false /* has_resize */); 455 }); 456 457 // -------------------------------------------------------------------------- 458 459 REGISTER_OP("DepthwiseConv2dNative") 460 .Input("input: T") 461 .Input("filter: T") 462 .Output("output: T") 463 .Attr("T: {half, bfloat16, float, double}") 464 .Attr("strides: list(int)") 465 .Attr(GetPaddingAttrString()) 466 .Attr(GetConvnetDataFormatAttrString()) 467 .Attr("dilations: list(int) = [1, 1, 1, 1]") 
468 .SetShapeFn(shape_inference::DepthwiseConv2DNativeShape); 469 470 REGISTER_OP("DepthwiseConv2dNativeBackpropInput") 471 .Input("input_sizes: int32") 472 .Input("filter: T") 473 .Input("out_backprop: T") 474 .Output("output: T") 475 .Attr("T: {bfloat16, float, double}") 476 .Attr("strides: list(int)") 477 .Attr(GetPaddingAttrString()) 478 .Attr(GetConvnetDataFormatAttrString()) 479 .Attr("dilations: list(int) = [1, 1, 1, 1]") 480 .SetShapeFn([](InferenceContext* c) { 481 ShapeHandle s; 482 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s)); 483 TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s)); 484 c->set_output(0, s); 485 return Status::OK(); 486 }); 487 488 REGISTER_OP("DepthwiseConv2dNativeBackpropFilter") 489 .Input("input: T") 490 .Input("filter_sizes: int32") 491 .Input("out_backprop: T") 492 .Output("output: T") 493 .Attr("T: {bfloat16, float, double}") 494 .Attr("strides: list(int)") 495 .Attr(GetPaddingAttrString()) 496 .Attr(GetConvnetDataFormatAttrString()) 497 .Attr("dilations: list(int) = [1, 1, 1, 1]") 498 .SetShapeFn([](InferenceContext* c) { 499 ShapeHandle s; 500 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s)); 501 TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s)); 502 c->set_output(0, s); 503 return Status::OK(); 504 }); 505 506 // -------------------------------------------------------------------------- 507 REGISTER_OP("Conv3D") 508 .Input("input: T") 509 .Input("filter: T") 510 .Output("output: T") 511 .Attr("T: {half, bfloat16, float, double}") 512 .Attr("strides: list(int) >= 5") 513 .Attr(GetPaddingAttrString()) 514 .Attr(GetConvnet3dDataFormatAttrString()) 515 .Attr("dilations: list(int) = [1, 1, 1, 1, 1]") 516 .SetShapeFn(shape_inference::Conv3DShape); 517 518 REGISTER_OP("Conv3DBackpropInput") 519 .Input("input: T") 520 .Input("filter: T") 521 .Input("out_backprop: T") 522 .Output("output: T") 523 .Attr("T: {half, float, double}") 524 .Attr("strides: list(int) >= 5") 525 .Attr(GetPaddingAttrString()) 526 .Deprecated(10, "Use 
Conv3DBackpropInputV2") 527 .SetShapeFn([](InferenceContext* c) { 528 return UnchangedShapeWithRank(c, 5); 529 }); 530 531 REGISTER_OP("Conv3DBackpropFilter") 532 .Input("input: T") 533 .Input("filter: T") 534 .Input("out_backprop: T") 535 .Output("output: T") 536 .Attr("T: {half, float, double}") 537 .Attr("strides: list(int) >= 5") 538 .Attr(GetPaddingAttrString()) 539 .Deprecated(10, "Use Conv3DBackpropFilterV2") 540 .SetShapeFn([](InferenceContext* c) { 541 ShapeHandle out; 542 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 5, &out)); 543 c->set_output(0, out); 544 return Status::OK(); 545 }); 546 547 REGISTER_OP("Conv3DBackpropInputV2") 548 .Input("input_sizes: int32") 549 .Input("filter: T") 550 .Input("out_backprop: T") 551 .Output("output: T") 552 .Attr("T: {half, bfloat16, float, double}") 553 .Attr("strides: list(int) >= 5") 554 .Attr(GetPaddingAttrString()) 555 .Attr(GetConvnet3dDataFormatAttrString()) 556 .Attr("dilations: list(int) = [1, 1, 1, 1, 1]") 557 .SetShapeFn([](InferenceContext* c) { 558 ShapeHandle s; 559 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s)); 560 TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s)); 561 c->set_output(0, s); 562 return Status::OK(); 563 }); 564 565 REGISTER_OP("Conv3DBackpropFilterV2") 566 .Input("input: T") 567 .Input("filter_sizes: int32") 568 .Input("out_backprop: T") 569 .Output("output: T") 570 .Attr("T: {half, bfloat16, float, double}") 571 .Attr("strides: list(int) >= 5") 572 .Attr(GetPaddingAttrString()) 573 .Attr(GetConvnet3dDataFormatAttrString()) 574 .Attr("dilations: list(int) = [1, 1, 1, 1, 1]") 575 .SetShapeFn([](InferenceContext* c) { 576 ShapeHandle s; 577 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s)); 578 TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s)); 579 c->set_output(0, s); 580 return Status::OK(); 581 }); 582 583 // -------------------------------------------------------------------------- 584 585 REGISTER_OP("AvgPool3D") 586 .Input("input: T") 587 .Output("output: T") 588 .Attr("ksize: list(int) 
>= 5") 589 .Attr("strides: list(int) >= 5") 590 .Attr(GetPaddingAttrString()) 591 .Attr(GetConvnet3dDataFormatAttrString()) 592 .Attr("T: {bfloat16, float, double}") 593 .SetShapeFn(shape_inference::Pool3DShape); 594 595 REGISTER_OP("AvgPool3DGrad") 596 .Input("orig_input_shape: int32") 597 .Input("grad: T") 598 .Output("output: T") 599 .Attr("ksize: list(int) >= 5") 600 .Attr("strides: list(int) >= 5") 601 .Attr(GetPaddingAttrString()) 602 .Attr(GetConvnet3dDataFormatAttrString()) 603 .Attr("T: {bfloat16, float, double}") 604 .SetShapeFn([](InferenceContext* c) { 605 ShapeHandle s; 606 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s)); 607 TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s)); 608 c->set_output(0, s); 609 return Status::OK(); 610 }); 611 612 // -------------------------------------------------------------------------- 613 614 REGISTER_OP("MaxPool3D") 615 .Input("input: T") 616 .Output("output: T") 617 .Attr("ksize: list(int) >= 5") 618 .Attr("strides: list(int) >= 5") 619 .Attr(GetPaddingAttrString()) 620 .Attr(GetConvnet3dDataFormatAttrString()) 621 .Attr("T: {bfloat16, float}") 622 .SetShapeFn(shape_inference::Pool3DShape); 623 624 REGISTER_OP("MaxPool3DGrad") 625 .Input("orig_input: TInput") 626 .Input("orig_output: TInput") 627 .Input("grad: T") 628 .Output("output: T") 629 .Attr("ksize: list(int) >= 5") 630 .Attr("strides: list(int) >= 5") 631 .Attr(GetPaddingAttrString()) 632 .Attr(GetConvnet3dDataFormatAttrString()) 633 .Attr("T: {bfloat16, float} = DT_FLOAT") 634 .Attr("TInput: {bfloat16, float} = DT_FLOAT") 635 .SetShapeFn([](InferenceContext* c) { 636 return UnchangedShapeWithRank(c, 5); 637 }); 638 639 REGISTER_OP("MaxPool3DGradGrad") 640 .Input("orig_input: T") 641 .Input("orig_output: T") 642 .Input("grad: T") 643 .Output("output: T") 644 .Attr("ksize: list(int) >= 5 ") 645 .Attr("strides: list(int) >= 5") 646 .Attr(GetPaddingAttrString()) 647 .Attr(GetConvnet3dDataFormatAttrString()) 648 .Attr("T: {float}") 649 
.SetShapeFn([](InferenceContext* c) { 650 TF_RETURN_IF_ERROR(shape_inference::Pool3DShape(c)); 651 ShapeHandle unused; 652 // Validate 'orig_input' is the same shape as 'grad' 653 TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(2), &unused)); 654 // Validate 'orig_output' is same shape as 'output' 655 TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->output(0), &unused)); 656 return Status::OK(); 657 }); 658 659 // -------------------------------------------------------------------------- 660 661 REGISTER_OP("L2Loss") 662 .Input("t: T") 663 .Output("output: T") 664 .Attr("T: {half, bfloat16, float, double}") 665 .SetShapeFn(shape_inference::ScalarShape); 666 667 // -------------------------------------------------------------------------- 668 669 REGISTER_OP("LRN") 670 .Input("input: T") 671 .Output("output: T") 672 .Attr("depth_radius: int = 5") 673 .Attr("bias: float = 1.0") 674 .Attr("alpha: float = 1.0") 675 .Attr("beta: float = 0.5") 676 .Attr("T: {half, bfloat16, float} = DT_FLOAT") 677 .SetShapeFn([](InferenceContext* c) { 678 return UnchangedShapeWithRank(c, 4); 679 }); 680 681 REGISTER_OP("LRNGrad") 682 .Input("input_grads: T") 683 .Input("input_image: T") 684 .Input("output_image: T") 685 .Output("output: T") 686 .Attr("depth_radius: int = 5") 687 .Attr("bias: float = 1.0") 688 .Attr("alpha: float = 1.0") 689 .Attr("beta: float = 0.5") 690 .Attr("T: {half, bfloat16, float} = DT_FLOAT") 691 .SetShapeFn([](InferenceContext* c) { 692 ShapeHandle s; 693 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &s)); // input_grads 694 TF_RETURN_IF_ERROR(c->Merge(s, c->input(1), &s)); // input_image 695 TF_RETURN_IF_ERROR(c->Merge(s, c->input(2), &s)); // output_image 696 c->set_output(0, s); 697 return Status::OK(); 698 }); 699 700 // -------------------------------------------------------------------------- 701 702 REGISTER_OP("MaxPool") 703 .Attr( 704 "T: {half, bfloat16, float, double, int32, int64, uint8, int16, int8, " 705 "uint16, qint8} = DT_FLOAT") 706 .Attr("ksize: 
list(int) >= 4") 707 .Attr("strides: list(int) >= 4") 708 .Attr(GetPaddingAttrString()) 709 .Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'") 710 .Input("input: T") 711 .Output("output: T") 712 .SetShapeFn(shape_inference::MaxPoolShape); 713 714 REGISTER_OP("MaxPoolV2") 715 .Attr( 716 "T: {half, bfloat16, float, double, int32, int64, uint8, int16, int8, " 717 "uint16, qint8} = DT_FLOAT") 718 .Attr(GetPaddingAttrString()) 719 .Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'") 720 .Input("input: T") 721 .Input("ksize: int32") 722 .Input("strides: int32") 723 .Output("output: T") 724 .SetShapeFn([](InferenceContext* c) { 725 TF_RETURN_IF_ERROR(shape_inference::MaxPoolV2Shape(c, 3)); 726 return Status::OK(); 727 }); 728 729 REGISTER_OP("MaxPoolGrad") 730 .Attr("ksize: list(int) >= 4") 731 .Attr("strides: list(int) >= 4") 732 .Attr(GetPaddingAttrString()) 733 .Attr(GetConvnetDataFormatAttrString()) 734 .Input("orig_input: T") 735 .Input("orig_output: T") 736 .Input("grad: T") 737 .Output("output: T") 738 .Attr("T: realnumbertype = DT_FLOAT") 739 .SetShapeFn([](InferenceContext* c) { 740 return UnchangedShapeWithRank(c, 4); 741 }); 742 743 REGISTER_OP("MaxPoolGradV2") 744 .Attr(GetPaddingAttrString()) 745 .Attr(GetConvnetDataFormatAttrString()) 746 .Input("orig_input: T") 747 .Input("orig_output: T") 748 .Input("grad: T") 749 .Input("ksize: int32") 750 .Input("strides: int32") 751 .Output("output: T") 752 .Attr("T: realnumbertype = DT_FLOAT") 753 .SetShapeFn([](InferenceContext* c) { 754 return UnchangedShapeWithRank(c, 4); 755 }); 756 757 REGISTER_OP("MaxPoolGradGrad") 758 .Attr("ksize: list(int) >= 4") 759 .Attr("strides: list(int) >= 4") 760 .Attr(GetPaddingAttrString()) 761 .Attr(GetConvnetDataFormatAttrString()) 762 .Input("orig_input: T") 763 .Input("orig_output: T") 764 .Input("grad: T") 765 .Output("output: T") 766 .Attr("T: realnumbertype") 767 .SetShapeFn([](InferenceContext* c) { 768 
TF_RETURN_IF_ERROR(shape_inference::MaxPoolShape(c)); 769 ShapeHandle unused; 770 // Validate 'orig_input' is the same shape as 'grad' 771 TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(2), &unused)); 772 // Validate 'orig_output' is same shape as 'output' 773 TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->output(0), &unused)); 774 return Status::OK(); 775 }); 776 777 REGISTER_OP("MaxPoolGradGradV2") 778 .Attr(GetPaddingAttrString()) 779 .Attr(GetConvnetDataFormatAttrString()) 780 .Input("orig_input: T") 781 .Input("orig_output: T") 782 .Input("grad: T") 783 .Input("ksize: int32") 784 .Input("strides: int32") 785 .Output("output: T") 786 .Attr("T: realnumbertype") 787 .SetShapeFn([](InferenceContext* c) { 788 TF_RETURN_IF_ERROR(shape_inference::MaxPoolV2Shape(c, 5)); 789 ShapeHandle unused; 790 // Validate 'orig_input' is the same shape as 'grad' 791 TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(2), &unused)); 792 // Validate 'orig_output' is same shape as 'output' 793 TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->output(0), &unused)); 794 return Status::OK(); 795 }); 796 797 REGISTER_OP("MaxPoolWithArgmax") 798 .Attr("ksize: list(int) >= 4") 799 .Attr("strides: list(int) >= 4") 800 .Attr("Targmax: {int32, int64} = DT_INT64") 801 .Attr(GetPaddingAttrString()) 802 .Input("input: T") 803 .Output("output: T") 804 .Output("argmax: Targmax") 805 .Attr("T: realnumbertype") 806 .SetShapeFn([](InferenceContext* c) { 807 TF_RETURN_IF_ERROR(shape_inference::MaxPoolShape(c)); 808 c->set_output(1, c->output(0)); 809 return Status::OK(); 810 }); 811 812 REGISTER_OP("MaxPoolGradWithArgmax") 813 .Attr("ksize: list(int) >= 4") 814 .Attr("strides: list(int) >= 4") 815 .Attr(GetPaddingAttrString()) 816 .Attr("Targmax: {int32, int64}") 817 .Input("input: T") 818 .Input("grad: T") 819 .Input("argmax: Targmax") 820 .Output("output: T") 821 .Attr("T: realnumbertype") 822 .SetShapeFn([](InferenceContext* c) { 823 return UnchangedShapeWithRank(c, 4); 824 }); 825 826 
REGISTER_OP("MaxPoolGradGradWithArgmax")
    .Attr("ksize: list(int) >= 4")
    .Attr("strides: list(int) >= 4")
    .Attr(GetPaddingAttrString())
    .Attr("Targmax: {int32, int64}")
    .Input("input: T")
    .Input("grad: T")
    .Input("argmax: Targmax")
    .Output("output: T")
    .Attr("T: realnumbertype")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::MaxPoolShape(c));
      ShapeHandle unused;
      // Validate 'input' is the same shape as 'grad'
      TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(1), &unused));
      // Validate 'argmax' is same shape as 'output'
      TF_RETURN_IF_ERROR(c->Merge(c->input(2), c->output(0), &unused));
      return Status::OK();
    });

// --------------------------------------------------------------------------

REGISTER_OP("Dilation2D")
    .Input("input: T")
    .Input("filter: T")
    .Output("output: T")
    .Attr("T: realnumbertype")
    .Attr("strides: list(int) >= 4")
    .Attr("rates: list(int) >= 4")
    .Attr(GetPaddingAttrString())
    .SetShapeFn([](InferenceContext* c) {
      // Shape and attr requirements:
      //   * 'input' must be rank 4 and 'filter' rank 3.
      //   * Input dim 3 must equal filter dim 2.
      //   * 'strides' and 'rates' must each hold 4 values; only entries 1
      //     and 2 (rows/cols) are used.
      // Output: [batch, out_rows, out_cols, depth]. The spatial sizes are
      // computed from the dilated ("effective") filter size and stay unknown
      // unless input and filter spatial dims are all known.
      ShapeHandle input_shape;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
      ShapeHandle filter_shape;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &filter_shape));

      std::vector<int32> strides;
      TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
      if (strides.size() != 4) {
        return errors::InvalidArgument(
            "Dilation2D requires the stride attribute to contain 4 values, but "
            "got: ",
            strides.size());
      }

      std::vector<int32> rates;
      TF_RETURN_IF_ERROR(c->GetAttr("rates", &rates));
      if (rates.size() != 4) {
        return errors::InvalidArgument(
            "Dilation2D requires the rates attribute to contain 4 values, but "
            "got: ",
            rates.size());
      }

      int32 stride_rows = strides[1];
      int32 stride_cols = strides[2];

      int32 rate_rows = rates[1];
      int32 rate_cols = rates[2];

      DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
      DimensionHandle in_rows_dim = c->Dim(input_shape, 1);
      DimensionHandle in_cols_dim = c->Dim(input_shape, 2);
      DimensionHandle filter_rows_dim = c->Dim(filter_shape, 0);
      DimensionHandle filter_cols_dim = c->Dim(filter_shape, 1);

      // Check input depth against filter depth before the early return
      // below. Otherwise a mismatch is silently accepted whenever a spatial
      // dim is unknown. The merged handle becomes the output depth, so a
      // depth known on either side is propagated.
      DimensionHandle output_depth_dim;
      TF_RETURN_IF_ERROR(c->Merge(c->Dim(input_shape, 3),
                                  c->Dim(filter_shape, 2), &output_depth_dim));

      if (!c->ValueKnown(in_rows_dim) || !c->ValueKnown(in_cols_dim) ||
          !c->ValueKnown(filter_rows_dim) || !c->ValueKnown(filter_cols_dim)) {
        ShapeHandle output_shape =
            c->MakeShape({batch_size_dim, InferenceContext::kUnknownDim,
                          InferenceContext::kUnknownDim, output_depth_dim});
        c->set_output(0, output_shape);
        return Status::OK();
      }

      auto in_rows = c->Value(in_rows_dim);
      auto in_cols = c->Value(in_cols_dim);
      auto filter_rows = c->Value(filter_rows_dim);
      auto filter_cols = c->Value(filter_cols_dim);
      // Effective filter extent once (rate - 1) gaps are inserted between
      // adjacent filter taps.
      auto filter_rows_eff = filter_rows + (filter_rows - 1) * (rate_rows - 1);
      auto filter_cols_eff = filter_cols + (filter_cols - 1) * (rate_cols - 1);

      Padding padding;
      TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));

      int64 output_rows, output_cols;
      int64 padding_before, padding_after;
      TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
          in_rows, filter_rows_eff, stride_rows, padding, &output_rows,
          &padding_before, &padding_after));
      TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
          in_cols, filter_cols_eff, stride_cols, padding, &output_cols,
          &padding_before, &padding_after));

      ShapeHandle output_shape = c->MakeShape(
          {batch_size_dim, output_rows, output_cols, output_depth_dim});
      c->set_output(0, output_shape);
      return Status::OK();
    });

REGISTER_OP("Dilation2DBackpropInput")
    .Input("input: T")
    .Input("filter: T")
    .Input("out_backprop: T")
    .Output("in_backprop: T")
    .Attr("T: realnumbertype")
    .Attr("strides: list(int) >= 4")
    .Attr("rates: list(int) >= 4")
    .Attr(GetPaddingAttrString())
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("Dilation2DBackpropFilter")
    .Input("input: T")
    .Input("filter: T")
    .Input("out_backprop: T")
    .Output("filter_backprop: T")
    .Attr("T: realnumbertype")
    .Attr("strides: list(int) >= 4")
    .Attr("rates: list(int) >= 4")
    .Attr(GetPaddingAttrString())
    .SetShapeFn([](InferenceContext* c) {
      // The filter gradient has the shape of 'filter' (input 1).
      c->set_output(0, c->input(1));
      return Status::OK();
    });

// --------------------------------------------------------------------------

REGISTER_OP("Relu")
    .Input("features: T")
    .Output("activations: T")
    .Attr("T: realnumbertype")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("ReluGrad")
    .Input("gradients: T")
    .Input("features: T")
    .Output("backprops: T")
    .Attr("T: realnumbertype")
    .SetShapeFn(shape_inference::MergeBothInputsShapeFn);

REGISTER_OP("Relu6")
    .Input("features: T")
    .Output("activations: T")
    .Attr("T: realnumbertype")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("Relu6Grad")
    .Input("gradients: T")
    .Input("features: T")
    .Output("backprops: T")
    .Attr("T: realnumbertype")
    .SetShapeFn(shape_inference::MergeBothInputsShapeFn);

REGISTER_OP("Elu")
    .Input("features: T")
    .Output("activations: T")
    .Attr("T: {half, bfloat16, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("EluGrad")
    .Input("gradients: T")
    .Input("outputs: T")
    .Output("backprops: T")
    .Attr("T: {half, bfloat16, float, double}")
    .SetShapeFn(shape_inference::MergeBothInputsShapeFn);

REGISTER_OP("Selu")
    .Input("features: T")
    .Output("activations: T")
    .Attr("T: {half, bfloat16, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);

1002 REGISTER_OP("SeluGrad") 1003 .Input("gradients: T") 1004 .Input("outputs: T") 1005 .Output("backprops: T") 1006 .Attr("T: {half, bfloat16, float, double}") 1007 .SetShapeFn(shape_inference::MergeBothInputsShapeFn); 1008 1009 REGISTER_OP("Softplus") 1010 .Input("features: T") 1011 .Output("activations: T") 1012 .Attr("T: realnumbertype") 1013 .SetShapeFn(shape_inference::UnchangedShape); 1014 1015 REGISTER_OP("SoftplusGrad") 1016 .Input("gradients: T") 1017 .Input("features: T") 1018 .Output("backprops: T") 1019 .Attr("T: realnumbertype") 1020 .SetShapeFn(shape_inference::MergeBothInputsShapeFn); 1021 1022 REGISTER_OP("Softsign") 1023 .Input("features: T") 1024 .Output("activations: T") 1025 .Attr("T: realnumbertype") 1026 .SetShapeFn(shape_inference::UnchangedShape); 1027 1028 REGISTER_OP("SoftsignGrad") 1029 .Input("gradients: T") 1030 .Input("features: T") 1031 .Output("backprops: T") 1032 .Attr("T: realnumbertype") 1033 .SetShapeFn(shape_inference::MergeBothInputsShapeFn); 1034 1035 // -------------------------------------------------------------------------- 1036 1037 REGISTER_OP("Softmax") 1038 .Input("logits: T") 1039 .Output("softmax: T") 1040 .Attr("T: {half, bfloat16, float, double}") 1041 .SetShapeFn([](InferenceContext* c) { 1042 return shape_inference::UnchangedShapeWithRankAtLeast(c, 1); 1043 }); 1044 1045 // -------------------------------------------------------------------------- 1046 1047 REGISTER_OP("LogSoftmax") 1048 .Input("logits: T") 1049 .Output("logsoftmax: T") 1050 .Attr("T: {half, bfloat16, float, double}") 1051 .SetShapeFn([](InferenceContext* c) { 1052 return shape_inference::UnchangedShapeWithRankAtLeast(c, 1); 1053 }); 1054 1055 // -------------------------------------------------------------------------- 1056 1057 REGISTER_OP("SoftmaxCrossEntropyWithLogits") 1058 .Input("features: T") 1059 .Input("labels: T") 1060 .Output("loss: T") 1061 .Output("backprop: T") 1062 .Attr("T: {half, bfloat16, float, double}") 1063 
.SetShapeFn([](InferenceContext* c) { 1064 ShapeHandle input; 1065 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &input)); 1066 TF_RETURN_IF_ERROR(c->Merge(input, c->input(1), &input)); 1067 1068 DimensionHandle batch_size = c->Dim(input, 0); 1069 c->set_output(0, c->Vector(batch_size)); 1070 c->set_output(1, input); 1071 return Status::OK(); 1072 }); 1073 1074 REGISTER_OP("SparseSoftmaxCrossEntropyWithLogits") 1075 .Input("features: T") 1076 .Input("labels: Tlabels") 1077 .Output("loss: T") 1078 .Output("backprop: T") 1079 .Attr("T: {half, bfloat16, float, double}") 1080 .Attr("Tlabels: {int32, int64} = DT_INT64") 1081 .SetShapeFn([](InferenceContext* c) { 1082 ShapeHandle features; 1083 ShapeHandle labels; 1084 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &features)); 1085 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &labels)); 1086 1087 DimensionHandle batch_size; 1088 TF_RETURN_IF_ERROR( 1089 c->Merge(c->Dim(features, 0), c->Dim(labels, 0), &batch_size)); 1090 TF_RETURN_IF_ERROR(c->ReplaceDim(features, 0, batch_size, &features)); 1091 1092 c->set_output(0, c->Vector(batch_size)); 1093 c->set_output(1, features); 1094 return Status::OK(); 1095 }); 1096 1097 // -------------------------------------------------------------------------- 1098 1099 REGISTER_OP("InTopK") 1100 .Input("predictions: float") 1101 .Input("targets: T") 1102 .Output("precision: bool") 1103 .Attr("k: int") 1104 .Attr("T: {int32, int64} = DT_INT32") 1105 .SetShapeFn([](InferenceContext* c) { 1106 ShapeHandle predictions; 1107 ShapeHandle targets; 1108 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &predictions)); 1109 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &targets)); 1110 DimensionHandle batch_size; 1111 TF_RETURN_IF_ERROR( 1112 c->Merge(c->Dim(predictions, 0), c->Dim(targets, 0), &batch_size)); 1113 c->set_output(0, c->Vector(batch_size)); 1114 return Status::OK(); 1115 }); 1116 1117 // This is the same as `InTopK`, but takes `k` as in input rather than an attr. 
1118 REGISTER_OP("InTopKV2") 1119 .Input("predictions: float") 1120 .Input("targets: T") 1121 .Input("k: T") 1122 .Output("precision: bool") 1123 .Attr("T: {int32, int64} = DT_INT32") 1124 .SetShapeFn([](InferenceContext* c) { 1125 ShapeHandle predictions; 1126 ShapeHandle targets; 1127 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &predictions)); 1128 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &targets)); 1129 DimensionHandle batch_size; 1130 TF_RETURN_IF_ERROR( 1131 c->Merge(c->Dim(predictions, 0), c->Dim(targets, 0), &batch_size)); 1132 c->set_output(0, c->Vector(batch_size)); 1133 return Status::OK(); 1134 }); 1135 1136 namespace { 1137 1138 Status TopKShapeFn(InferenceContext* c) { 1139 ShapeHandle input; 1140 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input)); 1141 1142 // Get the k value, either from input tensor or attribute. 1143 DimensionHandle k_dim; 1144 if (c->num_inputs() >= 2) { 1145 TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(1, &k_dim)); 1146 } else { 1147 int32 k; 1148 TF_RETURN_IF_ERROR(c->GetAttr("k", &k)); 1149 if (k < 0) { 1150 return errors::InvalidArgument("Need k >= 0, got ", k); 1151 } 1152 k_dim = c->MakeDim(k); 1153 } 1154 1155 DimensionHandle last_dim = c->Dim(input, -1); 1156 if (c->ValueKnown(last_dim) && c->ValueKnown(k_dim) && 1157 c->Value(last_dim) < c->Value(k_dim)) { 1158 return errors::InvalidArgument( 1159 "input must have last dimension >= k = ", c->Value(k_dim), " but is ", 1160 c->Value(last_dim)); 1161 } 1162 1163 // Replace last_dim with k_dim. 
1164 ShapeHandle s; 1165 TF_RETURN_IF_ERROR(c->Subshape(input, 0, -1, &s)); 1166 TF_RETURN_IF_ERROR(c->Concatenate(s, c->Vector(k_dim), &s)); 1167 c->set_output(0, s); 1168 c->set_output(1, s); 1169 return Status::OK(); 1170 } 1171 1172 } // namespace 1173 1174 REGISTER_OP("TopK") 1175 .Input("input: T") 1176 .Output("values: T") 1177 .Output("indices: int32") 1178 .Attr("k: int >= 0") 1179 .Attr("sorted: bool = true") 1180 .Attr("T: realnumbertype") 1181 .Deprecated(7, "Use TopKV2 instead") 1182 .SetShapeFn(TopKShapeFn); 1183 1184 // This is the same as `TopK`, but takes `k` as in input rather than an attr. 1185 REGISTER_OP("TopKV2") 1186 .Input("input: T") 1187 .Input("k: int32") 1188 .Output("values: T") 1189 .Output("indices: int32") 1190 .Attr("sorted: bool = true") 1191 .Attr("T: realnumbertype") 1192 .SetShapeFn(TopKShapeFn); 1193 1194 // -------------------------------------------------------------------------- 1195 1196 REGISTER_OP("NthElement") 1197 .Input("input: T") 1198 .Input("n: int32") 1199 .Output("values: T") 1200 .Attr("reverse: bool = false") 1201 .Attr("T: realnumbertype") 1202 .SetShapeFn([](InferenceContext* c) { 1203 ShapeHandle input; 1204 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input)); 1205 1206 // Get the n value from input tensor, and make sure which is a scalar. 1207 DimensionHandle n_dim; 1208 TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(1, &n_dim)); 1209 1210 // The last dimension of input tensor must be greater than N. 
1211 DimensionHandle last_dim = c->Dim(input, -1); 1212 if (c->ValueKnown(last_dim) && c->ValueKnown(n_dim) && 1213 c->Value(last_dim) <= c->Value(n_dim)) { 1214 return errors::InvalidArgument( 1215 "Input must have last dimension > n = ", c->Value(n_dim), 1216 " but is ", c->Value(last_dim)); 1217 } 1218 1219 // Reduce last_dim for output tensor 1220 ShapeHandle s; 1221 TF_RETURN_IF_ERROR(c->Subshape(input, 0, -1, &s)); 1222 c->set_output(0, s); 1223 return Status::OK(); 1224 }); 1225 1226 // -------------------------------------------------------------------------- 1227 1228 REGISTER_OP("FractionalMaxPool") 1229 .Input("value: T") 1230 .Output("output: T") 1231 .Output("row_pooling_sequence: int64") 1232 .Output("col_pooling_sequence: int64") 1233 .Attr("pooling_ratio: list(float) >=4") 1234 .Attr("pseudo_random: bool = false") 1235 .Attr("overlapping: bool = false") 1236 .Attr("deterministic: bool = false") 1237 .Attr("seed: int = 0") 1238 .Attr("seed2: int = 0") 1239 .Attr("T: {float, double, int32, int64}") 1240 .SetShapeFn(FractionalPoolShapeFn); 1241 1242 REGISTER_OP("FractionalMaxPoolGrad") 1243 .Input("orig_input: T") 1244 .Input("orig_output: T") 1245 .Input("out_backprop: T") 1246 .Input("row_pooling_sequence: int64") 1247 .Input("col_pooling_sequence: int64") 1248 .Output("output: T") 1249 .Attr("overlapping: bool = false") 1250 .Attr("T: {float, double, int32, int64}") 1251 .SetShapeFn([](InferenceContext* c) { 1252 return shape_inference::UnchangedShapeWithRank(c, 4); 1253 }); 1254 1255 // -------------------------------------------------------------------------- 1256 1257 REGISTER_OP("FractionalAvgPool") 1258 .Input("value: T") 1259 .Output("output: T") 1260 .Output("row_pooling_sequence: int64") 1261 .Output("col_pooling_sequence: int64") 1262 .Attr("pooling_ratio: list(float) >=4") 1263 .Attr("pseudo_random: bool = false") 1264 .Attr("overlapping: bool = false") 1265 .Attr("deterministic: bool = false") 1266 .Attr("seed: int = 0") 1267 
.Attr("seed2: int = 0") 1268 .Attr("T: {float, double, int32, int64}") 1269 .SetShapeFn(FractionalPoolShapeFn); 1270 1271 REGISTER_OP("FractionalAvgPoolGrad") 1272 .Input("orig_input_tensor_shape: int64") 1273 .Input("out_backprop: T") 1274 .Input("row_pooling_sequence: int64") 1275 .Input("col_pooling_sequence: int64") 1276 .Output("output: T") 1277 .Attr("overlapping: bool = false") 1278 .Attr("T: {float, double, int32, int64}") 1279 .SetShapeFn([](InferenceContext* c) { 1280 if (c->input_tensor(0) != nullptr) { 1281 ShapeHandle out; 1282 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out)); 1283 c->set_output(0, out); 1284 } else { 1285 c->set_output(0, c->UnknownShapeOfRank(4)); 1286 } 1287 return Status::OK(); 1288 }); 1289 1290 REGISTER_OP("QuantizedAvgPool") 1291 .Input("input: T") 1292 .Input("min_input: float") 1293 .Input("max_input: float") 1294 .Output("output: T") 1295 .Output("min_output: float") 1296 .Output("max_output: float") 1297 .Attr("T: quantizedtype") 1298 .Attr("ksize: list(int)") 1299 .Attr("strides: list(int)") 1300 .Attr(GetPaddingAttrString()) 1301 .SetShapeFn([](InferenceContext* c) { 1302 TF_RETURN_IF_ERROR(shape_inference::AvgPoolShape(c)); 1303 ShapeHandle unused; 1304 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 1305 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 1306 c->set_output(1, c->Scalar()); 1307 c->set_output(2, c->Scalar()); 1308 return Status::OK(); 1309 }); 1310 1311 REGISTER_OP("QuantizedBiasAdd") 1312 .Input("input: T1") 1313 .Input("bias: T2") 1314 .Input("min_input: float") 1315 .Input("max_input: float") 1316 .Input("min_bias: float") 1317 .Input("max_bias: float") 1318 .Output("output: out_type") 1319 .Output("min_out: float") 1320 .Output("max_out: float") 1321 .Attr("T1: quantizedtype") 1322 .Attr("T2: quantizedtype") 1323 .Attr("out_type: quantizedtype") 1324 .SetShapeFn([](InferenceContext* c) { 1325 TF_RETURN_IF_ERROR(shape_inference::BiasAddShape(c)); 1326 ShapeHandle unused; 1327 
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 1328 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); 1329 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); 1330 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); 1331 c->set_output(1, c->Scalar()); 1332 c->set_output(2, c->Scalar()); 1333 return Status::OK(); 1334 }); 1335 1336 REGISTER_OP("QuantizedConv2D") 1337 .Input("input: Tinput") 1338 .Input("filter: Tfilter") 1339 .Input("min_input: float") 1340 .Input("max_input: float") 1341 .Input("min_filter: float") 1342 .Input("max_filter: float") 1343 .Output("output: out_type") 1344 .Output("min_output: float") 1345 .Output("max_output: float") 1346 .Attr("Tinput: quantizedtype") 1347 .Attr("Tfilter: quantizedtype") 1348 .Attr("out_type: quantizedtype = DT_QINT32") 1349 .Attr("strides: list(int)") 1350 .Attr(GetPaddingAttrString()) 1351 .Attr("dilations: list(int) = [1, 1, 1, 1]") 1352 .SetShapeFn([](InferenceContext* c) { 1353 TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); 1354 ShapeHandle unused; 1355 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 1356 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); 1357 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); 1358 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); 1359 c->set_output(1, c->Scalar()); 1360 c->set_output(2, c->Scalar()); 1361 return Status::OK(); 1362 }); 1363 1364 REGISTER_OP("QuantizedMaxPool") 1365 .Input("input: T") 1366 .Input("min_input: float") 1367 .Input("max_input: float") 1368 .Output("output: T") 1369 .Output("min_output: float") 1370 .Output("max_output: float") 1371 .Attr("T: quantizedtype") 1372 .Attr("ksize: list(int)") 1373 .Attr("strides: list(int)") 1374 .Attr(GetPaddingAttrString()) 1375 .SetShapeFn([](InferenceContext* c) { 1376 TF_RETURN_IF_ERROR(shape_inference::MaxPoolShape(c)); 1377 ShapeHandle unused; 1378 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 1379 
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 1380 c->set_output(1, c->Scalar()); 1381 c->set_output(2, c->Scalar()); 1382 return Status::OK(); 1383 }); 1384 1385 REGISTER_OP("QuantizedRelu") 1386 .Input("features: Tinput") 1387 .Input("min_features: float") 1388 .Input("max_features: float") 1389 .Output("activations: out_type") 1390 .Output("min_activations: float") 1391 .Output("max_activations: float") 1392 .Attr("Tinput: quantizedtype") 1393 .Attr("out_type: quantizedtype = DT_QUINT8") 1394 .SetShapeFn([](InferenceContext* c) { 1395 TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c)); 1396 ShapeHandle unused; 1397 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 1398 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 1399 c->set_output(1, c->Scalar()); 1400 c->set_output(2, c->Scalar()); 1401 return Status::OK(); 1402 }); 1403 1404 REGISTER_OP("QuantizedRelu6") 1405 .Input("features: Tinput") 1406 .Input("min_features: float") 1407 .Input("max_features: float") 1408 .Output("activations: out_type") 1409 .Output("min_activations: float") 1410 .Output("max_activations: float") 1411 .Attr("Tinput: quantizedtype") 1412 .Attr("out_type: quantizedtype = DT_QUINT8") 1413 .SetShapeFn([](InferenceContext* c) { 1414 TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c)); 1415 ShapeHandle unused; 1416 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 1417 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 1418 c->set_output(1, c->Scalar()); 1419 c->set_output(2, c->Scalar()); 1420 return Status::OK(); 1421 }); 1422 1423 REGISTER_OP("QuantizedReluX") 1424 .Input("features: Tinput") 1425 .Input("max_value: float") 1426 .Input("min_features: float") 1427 .Input("max_features: float") 1428 .Output("activations: out_type") 1429 .Output("min_activations: float") 1430 .Output("max_activations: float") 1431 .Attr("Tinput: quantizedtype") 1432 .Attr("out_type: quantizedtype = DT_QUINT8") 1433 .SetShapeFn([](InferenceContext* c) { 1434 
TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c)); 1435 ShapeHandle unused; 1436 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 1437 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 1438 c->set_output(1, c->Scalar()); 1439 c->set_output(2, c->Scalar()); 1440 return Status::OK(); 1441 }); 1442 1443 REGISTER_OP("QuantizedBatchNormWithGlobalNormalization") 1444 .Input("t: Tinput") 1445 .Input("t_min: float") 1446 .Input("t_max: float") 1447 .Input("m: Tinput") 1448 .Input("m_min: float") 1449 .Input("m_max: float") 1450 .Input("v: Tinput") 1451 .Input("v_min: float") 1452 .Input("v_max: float") 1453 .Input("beta: Tinput") 1454 .Input("beta_min: float") 1455 .Input("beta_max: float") 1456 .Input("gamma: Tinput") 1457 .Input("gamma_min: float") 1458 .Input("gamma_max: float") 1459 .Output("result: out_type") 1460 .Output("result_min: float") 1461 .Output("result_max: float") 1462 .Attr("Tinput: quantizedtype") 1463 .Attr("out_type: quantizedtype") 1464 .Attr("variance_epsilon: float") 1465 .Attr("scale_after_normalization: bool") 1466 .SetShapeFn([](InferenceContext* c) { 1467 ShapeHandle input; 1468 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input)); 1469 1470 DimensionHandle last_dim = c->Dim(input, 3); 1471 for (int i = 1; i < 5; ++i) { // covers m, v, beta, gamma 1472 ShapeHandle vec; 1473 TF_RETURN_IF_ERROR(c->WithRank(c->input(i * 3), 1, &vec)); 1474 TF_RETURN_IF_ERROR(c->Merge(last_dim, c->Dim(vec, 0), &last_dim)); 1475 } 1476 1477 ShapeHandle out; 1478 TF_RETURN_IF_ERROR(c->ReplaceDim(input, 3, last_dim, &out)); 1479 c->set_output(0, out); 1480 c->set_output(1, c->Scalar()); 1481 c->set_output(2, c->Scalar()); 1482 1483 return Status::OK(); 1484 }); 1485 1486 #ifdef INTEL_MKL 1487 REGISTER_OP("_MklConv2D") 1488 .Input("input: T") 1489 .Input("filter: T") 1490 .Input("mkl_input: uint8") 1491 .Input("mkl_filter: uint8") 1492 .Output("output: T") 1493 .Output("filter_output: T") 1494 .Output("mkl_output: uint8") 1495 
.Output("mkl_filter_output: uint8") 1496 .Attr("T: {half, float, double}") 1497 .Attr("strides: list(int)") 1498 .Attr("use_cudnn_on_gpu: bool = true") 1499 .Attr(GetPaddingAttrString()) 1500 .Attr(GetConvnetDataFormatAttrString()) 1501 .SetShapeFn(shape_inference::Conv2DShape) 1502 .Doc(R"doc( 1503 MKL version of Conv2D operator. Uses MKL DNN APIs to perform 2D convolution. 1504 1505 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is 1506 expected to invoke these operators. 1507 )doc"); 1508 1509 REGISTER_OP("__MklDummyConv2DWithBias") 1510 .Input("input: T") 1511 .Input("filter: T") 1512 .Input("bias: T") 1513 .Output("output: T") 1514 .Attr("T: {half, float, double}") 1515 .Attr("strides: list(int)") 1516 .Attr("use_cudnn_on_gpu: bool = true") 1517 .Attr(GetPaddingAttrString()) 1518 .Attr(GetConvnetDataFormatAttrString()) 1519 .Doc(R"doc( 1520 Dummy node that enables fusing Conv2D and BiasAdd operator for MKL. This node 1521 does not perform anything. It is just created as an intermediate output of 1522 merging Conv2D and BiasAdd. 1523 1524 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is 1525 expected to invoke these operators. 1526 )doc"); 1527 1528 REGISTER_OP("_MklConv2DWithBias") 1529 .Input("input: T") 1530 .Input("filter: T") 1531 .Input("bias: T") 1532 .Input("mkl_input: uint8") 1533 .Input("mkl_filter: uint8") 1534 .Input("mkl_bias: uint8") 1535 .Output("output: T") 1536 .Output("filter_output: T") 1537 .Output("mkl_output: uint8") 1538 .Output("mkl_filter_output: uint8") 1539 .Attr("T: {half, float, double}") 1540 .Attr("strides: list(int)") 1541 .Attr("use_cudnn_on_gpu: bool = true") 1542 .Attr(GetPaddingAttrString()) 1543 .Attr(GetConvnetDataFormatAttrString()) 1544 .Doc(R"doc( 1545 MKL version of Conv2D and BiasAdd operator. Uses MKL DNN APIs to perform 1546 2D convolution and add Bias to the output of convolution. 1547 1548 NOTE Do not invoke this operator directly in Python. 
Graph rewrite pass is 1549 expected to invoke these operators. 1550 )doc"); 1551 1552 REGISTER_OP("_MklConv2DBackpropFilter") 1553 .Input("input: T") 1554 .Input("filter_sizes: int32") 1555 .Input("out_backprop: T") 1556 .Input("mkl_input: uint8") 1557 .Input("mkl_filter_size: uint8") 1558 .Input("mkl_out_backprop: uint8") 1559 .Output("output: T") 1560 .Output("mkl_output: uint8") 1561 .Attr("T: {half, float, double}") 1562 .Attr("strides: list(int)") 1563 .Attr("use_cudnn_on_gpu: bool = true") 1564 .Attr(GetPaddingAttrString()) 1565 .Attr(GetConvnetDataFormatAttrString()) 1566 .SetShapeFn([](InferenceContext* c) { 1567 ShapeHandle s; 1568 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s)); 1569 TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s)); 1570 c->set_output(0, s); 1571 return Status::OK(); 1572 }) 1573 .Doc(R"doc( 1574 MKL version of Conv2DBackpropFilter. Uses MKL DNN APIs to compute the 1575 gradients of convolution with respect to the filter. 1576 1577 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is 1578 expected to invoke these operators. 1579 )doc"); 1580 1581 REGISTER_OP("__MklDummyConv2DBackpropFilterWithBias") 1582 .Input("input: T") 1583 .Input("filter_sizes: int32") 1584 .Input("out_backprop: T") 1585 .Output("output: T") 1586 .Output("bias_grad: T") 1587 .Attr("T: {half, float, double}") 1588 .Attr("strides: list(int)") 1589 .Attr("use_cudnn_on_gpu: bool = true") 1590 .Attr(GetPaddingAttrString()) 1591 .Attr(GetConvnetDataFormatAttrString()) 1592 .SetShapeFn([](InferenceContext* c) { 1593 ShapeHandle input_shape; 1594 // Fetch the data_format attribute, which may not exist. 
1595 string data_format; 1596 Status s = c->GetAttr("data_format", &data_format); 1597 1598 if (s.ok() && data_format == "NCHW") { 1599 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape)); 1600 c->set_output(1, c->Vector(c->Dim(input_shape, -3))); 1601 } else { 1602 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape)); 1603 c->set_output(1, c->Vector(c->Dim(input_shape, -1))); 1604 } 1605 ShapeHandle sh; 1606 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &sh)); 1607 TF_RETURN_IF_ERROR(c->WithRank(sh, 4, &sh)); 1608 c->set_output(0, sh); 1609 return Status::OK(); 1610 }) 1611 .Doc(R"doc( 1612 Dummy node that enables fusing Conv2DBackpropFilter and BiasAddGrad operator 1613 for MKL. This node does not perform anything. It is just created as an 1614 intermediate output of merging Conv2DBackpropFilter and BiasAddGrad. 1615 1616 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is 1617 expected to invoke these operators. 1618 )doc"); 1619 1620 REGISTER_OP("_MklConv2DBackpropFilterWithBias") 1621 .Input("input: T") 1622 .Input("filter_sizes: int32") 1623 .Input("out_backprop: T") 1624 .Input("mkl_input: uint8") 1625 .Input("mkl_filter_size: uint8") 1626 .Input("mkl_out_backprop: uint8") 1627 .Output("output: T") 1628 .Output("bias_grad: T") 1629 .Output("mkl_output: uint8") 1630 .Output("mkl_bias_grad: uint8") 1631 .Attr("T: {half, float, double}") 1632 .Attr("strides: list(int)") 1633 .Attr("use_cudnn_on_gpu: bool = true") 1634 .Attr(GetPaddingAttrString()) 1635 .Attr(GetConvnetDataFormatAttrString()) 1636 .SetShapeFn([](InferenceContext* c) { 1637 ShapeHandle input_shape; 1638 // Fetch the data_format attribute, which may not exist. 
1639 string data_format; 1640 Status s = c->GetAttr("data_format", &data_format); 1641 1642 if (s.ok() && data_format == "NCHW") { 1643 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape)); 1644 c->set_output(1, c->Vector(c->Dim(input_shape, -3))); 1645 } else { 1646 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape)); 1647 c->set_output(1, c->Vector(c->Dim(input_shape, -1))); 1648 } 1649 ShapeHandle sh; 1650 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &sh)); 1651 TF_RETURN_IF_ERROR(c->WithRank(sh, 4, &sh)); 1652 c->set_output(0, sh); 1653 return Status::OK(); 1654 }) 1655 .Doc(R"doc( 1656 MKL version of Conv2DBackpropFilterWithBias. Uses MKL DNN APIs to compute the 1657 gradients of convolution with respect to the filter. 1658 1659 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is 1660 expected to invoke these operators. 1661 )doc"); 1662 1663 REGISTER_OP("_MklConv2DWithBiasBackpropBias") 1664 .Input("out_backprop: T") 1665 .Input("mkl_out_backprop: uint8") 1666 .Output("output: T") 1667 .Output("mkl_output: uint8") 1668 .Attr("T: {half, float, double}") 1669 .Attr("strides: list(int)") 1670 .Attr(GetConvnetDataFormatAttrString()) 1671 .Doc(R"doc( 1672 MKL version of Conv2DBackpropBias. Uses MKL DNN APIs to compute the 1673 gradients of convolution with respect to the bias. 1674 1675 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is 1676 expected to invoke these operators. 
1677 )doc"); 1678 1679 REGISTER_OP("_MklConv2DBackpropInput") 1680 .Input("input_sizes: int32") 1681 .Input("filter: T") 1682 .Input("out_backprop: T") 1683 .Input("mkl_input_sizes: uint8") 1684 .Input("mkl_filter: uint8") 1685 .Input("mkl_out_backprop: uint8") 1686 .Output("output: T") 1687 .Output("mkl_output: uint8") 1688 .Attr("T: {half, float, double}") 1689 .Attr("strides: list(int)") 1690 .Attr("use_cudnn_on_gpu: bool = true") 1691 .Attr(GetPaddingAttrString()) 1692 .Attr(GetConvnetDataFormatAttrString()) 1693 .SetShapeFn([](InferenceContext* c) { 1694 ShapeHandle s; 1695 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s)); 1696 TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s)); 1697 c->set_output(0, s); 1698 return Status::OK(); 1699 }) 1700 .Doc(R"doc( 1701 MKL version of Convolution2D backward input. Uses MKL DNN APIs to compute the 1702 gradients of convolution with respect to the input. 1703 1704 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is 1705 expected to invoke these operators. 1706 )doc"); 1707 1708 REGISTER_OP("_MklRelu") 1709 .Input("features: T") 1710 .Input("mkl_features: uint8") 1711 .Output("activations: T") 1712 .Output("mkl_activations: uint8") 1713 .Attr("T: realnumbertype") 1714 .SetShapeFn(shape_inference::UnchangedShape) 1715 .Doc(R"doc( 1716 MKL version of Relu operator. Uses MKL DNN APIs to implement Relu operator. 1717 1718 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is 1719 expected to invoke these operators. 1720 )doc"); 1721 1722 REGISTER_OP("_MklReluGrad") 1723 .Input("gradients: T") 1724 .Input("features: T") 1725 .Input("mkl_gradients: uint8") 1726 .Input("mkl_features: uint8") 1727 .Output("backprops: T") 1728 .Output("mkl_backprops: uint8") 1729 .Attr("T: realnumbertype") 1730 .SetShapeFn(shape_inference::MergeBothInputsShapeFn) 1731 .Doc(R"doc( 1732 MKL version of ReluGrad operator. Uses MKL DNN APIs to compute rectified 1733 linear gradients for Relu operation. 
1734 1735 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is 1736 expected to invoke these operators. 1737 )doc"); 1738 1739 REGISTER_OP("_MklElu") 1740 .Input("features: T") 1741 .Input("mkl_features: uint8") 1742 .Output("activations: T") 1743 .Output("mkl_activations: uint8") 1744 .Attr("T: realnumbertype") 1745 .SetShapeFn(shape_inference::UnchangedShape) 1746 .Doc(R"doc( 1747 MKL version of Elu operator. Uses MKL DNN APIs to implement Elu operator. 1748 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is 1749 expected to invoke these operators. 1750 )doc"); 1751 1752 REGISTER_OP("_MklEluGrad") 1753 .Input("gradients: T") 1754 .Input("features: T") 1755 .Input("mkl_gradients: uint8") 1756 .Input("mkl_features: uint8") 1757 .Output("backprops: T") 1758 .Output("mkl_backprops: uint8") 1759 .Attr("T: realnumbertype") 1760 .SetShapeFn(shape_inference::MergeBothInputsShapeFn) 1761 .Doc(R"doc( 1762 MKL version of EluGrad operator. Uses MKL DNN APIs to compute Elu 1763 gradients for Elu operation. 1764 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is 1765 expected to invoke these operators. 1766 )doc"); 1767 1768 REGISTER_OP("_MklSoftmax") 1769 .Input("logits: T") 1770 .Input("mkl_logits: uint8") 1771 .Output("softmax: T") 1772 .Output("mkl_softmax: uint8") 1773 .Attr("T: {half, float, double}") 1774 .SetShapeFn([](InferenceContext* c) { 1775 return shape_inference::UnchangedShapeWithRankAtLeast(c, 1); 1776 }) 1777 .Doc(R"doc( 1778 MKL version of ReluGrad operator. Uses MKL DNN APIs to compute rectified 1779 linear gradients for Relu operation. 1780 )doc"); 1781 1782 REGISTER_OP("_MklTanh") 1783 .Input("features: T") 1784 .Input("mkl_features: uint8") 1785 .Output("activations: T") 1786 .Output("mkl_activations: uint8") 1787 .Attr("T: realnumbertype") 1788 .SetShapeFn(shape_inference::UnchangedShape) 1789 .Doc(R"doc( 1790 MKL version of Tanh operator. Uses MKL DNN APIs to implement Tanh operator. 
1791 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is 1792 expected to invoke these operators. 1793 )doc"); 1794 1795 REGISTER_OP("_MklTanhGrad") 1796 .Input("gradients: T") 1797 .Input("features: T") 1798 .Input("mkl_gradients: uint8") 1799 .Input("mkl_features: uint8") 1800 .Output("backprops: T") 1801 .Output("mkl_backprops: uint8") 1802 .Attr("T: realnumbertype") 1803 .SetShapeFn(shape_inference::MergeBothInputsShapeFn) 1804 .Doc(R"doc( 1805 MKL version of TanhGrad operator. Uses MKL DNN APIs to compute tanh 1806 gradients for Tanh operation. 1807 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is 1808 expected to invoke these operators. 1809 )doc"); 1810 1811 REGISTER_OP("_MklMaxPool") 1812 .Attr("T: {float, half} = DT_FLOAT") 1813 .Attr("ksize: list(int) >= 4") 1814 .Attr("strides: list(int) >= 4") 1815 .Attr(GetPaddingAttrString()) 1816 .Attr(GetConvnetDataFormatAttrString()) 1817 .Attr("workspace_enabled: bool = false") 1818 .Input("input: T") 1819 .Input("mkl_input: uint8") 1820 .Output("output: T") 1821 #ifdef INTEL_MKL_ML 1822 .Output("workspace: T") 1823 #else 1824 .Output("workspace: uint8") 1825 #endif 1826 .Output("mkl_output: uint8") 1827 .Output("mkl_workspace: uint8") 1828 .SetShapeFn(shape_inference::MaxPoolShape) 1829 .Doc(R"doc( 1830 MKL version of MaxPool operator. Uses MKL DNN APIs to perform max pooling 1831 on the input. 1832 1833 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is 1834 expected to invoke these operators. 
1835 )doc"); 1836 1837 REGISTER_OP("_MklMaxPoolGrad") 1838 .Attr("T: {float, half} = DT_FLOAT") 1839 .Attr("ksize: list(int) >= 4") 1840 .Attr("strides: list(int) >= 4") 1841 .Attr("workspace_enabled: bool = false") 1842 .Attr(GetPaddingAttrString()) 1843 .Attr(GetConvnetDataFormatAttrString()) 1844 .Input("orig_input: T") 1845 .Input("orig_output: T") 1846 .Input("grad: T") 1847 #ifdef INTEL_MKL_ML 1848 .Input("workspace: T") 1849 #else 1850 .Input("workspace: uint8") 1851 #endif 1852 .Input("mkl_orig_input: uint8") 1853 .Input("mkl_orig_output: uint8") 1854 .Input("mkl_grad: uint8") 1855 .Input("mkl_workspace: uint8") 1856 .Output("output: T") 1857 .Output("mkl_output: uint8") 1858 .SetShapeFn([](InferenceContext* c) { 1859 return UnchangedShapeWithRank(c, 4); 1860 }) 1861 .Doc(R"doc( 1862 MKL version of MaxPoolGrad. Uses MKL DNN APIs to compute gradients of 1863 MaxPool operator. 1864 1865 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is 1866 expected to invoke these operators. 1867 )doc"); 1868 1869 REGISTER_OP("_MklAvgPool") 1870 .Input("value: T") 1871 .Input("mkl_input: uint8") 1872 .Output("output: T") 1873 .Output("mkl_output: uint8") 1874 .Attr("ksize: list(int) >= 4") 1875 .Attr("strides: list(int) >= 4") 1876 .Attr(GetPaddingAttrString()) 1877 .Attr(GetConvnetDataFormatAttrString()) 1878 .Attr("T: {float, half, double}") 1879 .SetShapeFn(shape_inference::AvgPoolShape) 1880 .Doc(R"doc( 1881 MKL version of AvgPool operator. Uses MKL DNN APIs to perform average pooling 1882 on the input. 1883 1884 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is 1885 expected to invoke these operators. 
1886 )doc"); 1887 1888 REGISTER_OP("_MklAvgPoolGrad") 1889 .Input("orig_input_shape: int32") 1890 .Input("grad: T") 1891 .Input("mkl_orig_input: uint8") 1892 .Input("mkl_grad: uint8") 1893 .Output("output: T") 1894 .Output("mkl_output: uint8") 1895 .Attr("ksize: list(int) >= 4") 1896 .Attr("strides: list(int) >= 4") 1897 .Attr(GetPaddingAttrString()) 1898 .Attr(GetConvnetDataFormatAttrString()) 1899 .Attr("T: {float, half, double}") 1900 .SetShapeFn([](InferenceContext* c) { 1901 ShapeHandle s; 1902 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s)); 1903 TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s)); 1904 c->set_output(0, s); 1905 return Status::OK(); 1906 }) 1907 .Doc(R"doc( 1908 MKL version of AvgPoolGrad operator. Uses MKL DNN APIs to compute gradients 1909 of AvgPool function. 1910 1911 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is 1912 expected to invoke these operators. 1913 )doc"); 1914 1915 REGISTER_OP("_MklLRN") 1916 .Input("input: T") 1917 .Input("mkl_input: uint8") 1918 .Output("output: T") 1919 #ifdef INTEL_MKL_ML 1920 .Output("workspace: T") 1921 #else 1922 .Output("workspace: uint8") 1923 #endif 1924 .Output("mkl_output: uint8") 1925 .Output("mkl_workspace: uint8") 1926 .Attr("depth_radius: int = 5") 1927 .Attr("bias: float = 1.0") 1928 .Attr("alpha: float = 1.0") 1929 .Attr("beta: float = 0.5") 1930 .Attr("workspace_enabled: bool = false") 1931 .Attr("T: {float, half} = DT_FLOAT") 1932 .SetShapeFn([](InferenceContext* c) { 1933 return UnchangedShapeWithRank(c, 4); 1934 }) 1935 .Doc(R"doc( 1936 MKL version of LRN operator. Uses MKL DNN APIs to perform local response 1937 normalization. 1938 1939 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is 1940 expected to invoke these operators. 
1941 )doc"); 1942 1943 REGISTER_OP("_MklLRNGrad") 1944 .Input("input_grads: T") 1945 .Input("input_image: T") 1946 .Input("output_image: T") 1947 #ifdef INTEL_MKL_ML 1948 .Input("workspace: T") 1949 #else 1950 .Input("workspace: uint8") 1951 #endif 1952 .Input("mkl_input_grads: uint8") 1953 .Input("mkl_input_image: uint8") 1954 .Input("mkl_output_image: uint8") 1955 .Input("mkl_workspace: uint8") 1956 .Output("output: T") 1957 .Output("mkl_output: uint8") 1958 .Attr("depth_radius: int = 5") 1959 .Attr("bias: float = 1.0") 1960 .Attr("alpha: float = 1.0") 1961 .Attr("beta: float = 0.5") 1962 .Attr("workspace_enabled: bool = false") 1963 .Attr("T: {float, half} = DT_FLOAT") 1964 .SetShapeFn([](InferenceContext* c) { 1965 ShapeHandle s; 1966 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &s)); // input_grads 1967 TF_RETURN_IF_ERROR(c->Merge(s, c->input(1), &s)); // input_image 1968 TF_RETURN_IF_ERROR(c->Merge(s, c->input(2), &s)); // output_image 1969 c->set_output(0, s); 1970 return Status::OK(); 1971 }) 1972 .Doc(R"doc( 1973 MKL version of LRNGrad operator. Uses MKL DNN APIs to compute gradient for 1974 local response normalization. 1975 1976 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is 1977 expected to invoke these operators. 
1978 )doc"); 1979 1980 REGISTER_OP("_MklFusedBatchNorm") 1981 .Input("x: T") 1982 .Input("scale: T") 1983 .Input("offset: T") 1984 .Input("mean: T") 1985 .Input("variance: T") 1986 .Input("mkl_x: uint8") 1987 .Input("mkl_scale: uint8") 1988 .Input("mkl_offset: uint8") 1989 .Input("mkl_mean: uint8") 1990 .Input("mkl_variance: uint8") 1991 .Output("y: T") 1992 .Output("batch_mean: T") 1993 .Output("batch_variance: T") 1994 .Output("reserve_space_1: T") 1995 .Output("reserve_space_2: T") 1996 .Output("mkl_y: uint8") 1997 .Output("mkl_batch_mean: uint8") 1998 .Output("mkl_batch_variance: uint8") 1999 .Output("mkl_reserve_space_1: uint8") 2000 .Output("mkl_reserve_space_2: uint8") 2001 .Attr("T: numbertype") 2002 .Attr("epsilon: float = 0.0001") 2003 .Attr("data_format: string = 'NHWC'") 2004 .Attr("is_training: bool = true") 2005 .SetShapeFn([](InferenceContext* c) { 2006 ShapeHandle x; 2007 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &x)); 2008 2009 bool is_training; 2010 c->GetAttr("is_training", &is_training); 2011 int number_inputs = (is_training) ? 3 : 5; 2012 string data_format; 2013 c->GetAttr("data_format", &data_format); 2014 DimensionHandle channel_dim = 2015 (data_format == "NHWC") ? 
c->Dim(x, 3) : c->Dim(x, 1); 2016 2017 // covers scale, offset, and if is_training is false, mean, variance 2018 for (int i = 1; i < number_inputs; ++i) { 2019 ShapeHandle vec; 2020 TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec)); 2021 TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(vec, 0), &channel_dim)); 2022 } 2023 2024 ShapeHandle y; 2025 if (data_format == "NHWC") { 2026 TF_RETURN_IF_ERROR(c->ReplaceDim(x, 3, channel_dim, &y)); 2027 } else { 2028 TF_RETURN_IF_ERROR(c->ReplaceDim(x, 1, channel_dim, &y)); 2029 } 2030 c->set_output(0, y); 2031 ShapeHandle vector_shape = c->Vector(channel_dim); 2032 c->set_output(1, vector_shape); 2033 c->set_output(2, vector_shape); 2034 c->set_output(3, vector_shape); 2035 c->set_output(4, vector_shape); 2036 return Status::OK(); 2037 }) 2038 .Doc(R"doc( 2039 MKL version of FusedBatchNorm operator. Uses MKL DNN APIs to perform fused 2040 batch normalization. 2041 2042 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is 2043 expected to invoke these operators. 
2044 )doc"); 2045 2046 REGISTER_OP("_MklFusedBatchNormGrad") 2047 .Input("y_backprop: T") 2048 .Input("x: T") 2049 .Input("scale: T") 2050 .Input("reserve_space_1: T") 2051 .Input("reserve_space_2: T") 2052 .Input("mkl_y_backprop: uint8") 2053 .Input("mkl_x: uint8") 2054 .Input("mkl_scale: uint8") 2055 .Input("mkl_reserve_space_1: uint8") 2056 .Input("mkl_reserve_space_2: uint8") 2057 .Output("x_backprop: T") 2058 .Output("scale_backprop: T") 2059 .Output("offset_backprop: T") 2060 .Output("reserve_space_3: T") 2061 .Output("reserve_space_4: T") 2062 .Output("mkl_x_backprop: uint8") 2063 .Output("mkl_scale_backprop: uint8") 2064 .Output("mkl_offset_backprop: uint8") 2065 .Output("mkl_reserve_space_3: uint8") 2066 .Output("mkl_reserve_space_4: uint8") 2067 .Attr("T: numbertype") 2068 .Attr("epsilon: float = 0.0001") 2069 .Attr("data_format: string = 'NHWC'") 2070 .Attr("is_training: bool = true") 2071 .SetShapeFn([](InferenceContext* c) { 2072 ShapeHandle y_backprop; 2073 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &y_backprop)); 2074 ShapeHandle x; 2075 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &x)); 2076 2077 bool is_training; 2078 string data_format; 2079 c->GetAttr("is_training", &is_training); 2080 c->GetAttr("data_format", &data_format); 2081 DimensionHandle channel_dim = (data_format == "NHWC") 2082 ? 
c->Dim(y_backprop, 3) 2083 : c->Dim(y_backprop, 1); 2084 if (data_format == "NHWC") { 2085 TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(x, 3), &channel_dim)); 2086 } else { 2087 TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(x, 1), &channel_dim)); 2088 } 2089 2090 // covers scale, mean (reserve_space_1), variance (reserve_space_2) 2091 for (int i = 2; i < 5; ++i) { 2092 ShapeHandle vec; 2093 TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec)); 2094 TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(vec, 0), &channel_dim)); 2095 } 2096 2097 ShapeHandle x_backprop; 2098 if (data_format == "NHWC") { 2099 TF_RETURN_IF_ERROR( 2100 c->ReplaceDim(y_backprop, 3, channel_dim, &x_backprop)); 2101 } else { 2102 TF_RETURN_IF_ERROR( 2103 c->ReplaceDim(y_backprop, 1, channel_dim, &x_backprop)); 2104 } 2105 c->set_output(0, x_backprop); 2106 c->set_output(1, c->Vector(channel_dim)); 2107 c->set_output(2, c->Vector(channel_dim)); 2108 // Set the correct shapes for reserve_spaces 2109 // so that gradients can be performed when 2110 // the op is in a symbolic condition. 2111 if (is_training) { 2112 c->set_output(3, c->Vector(0)); 2113 c->set_output(4, c->Vector(0)); 2114 } else { 2115 c->set_output(3, c->Vector(channel_dim)); 2116 c->set_output(4, c->Vector(channel_dim)); 2117 } 2118 return Status::OK(); 2119 }) 2120 .Doc(R"doc( 2121 MKL version of FusedBatchNormGrad operator. Uses MKL DNN APIs to compute 2122 gradients for fused batch normalization. 2123 2124 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is 2125 expected to invoke these operators. 2126 )doc"); 2127 2128 REGISTER_OP("_MklToTf") 2129 .Input("input: T") 2130 .Input("mkl_input: uint8") 2131 .Output("output: T") 2132 .Attr("T: {half, float, double}") 2133 .Attr(GetConvnetDataFormatAttrString()) 2134 .Doc(R"doc( 2135 MKL operator to convert a tensor from MKL layout to TensorFlow layout. 2136 2137 NOTE Do not invoke this operator directly in Python. 
Graph rewrite pass is 2138 expected to invoke these operators. 2139 )doc"); 2140 2141 REGISTER_OP("_MklInputConversion") 2142 .Input("input_0: T") 2143 .Input("input_1: T") 2144 .Input("mkl_input_0: uint8") 2145 .Input("mkl_input_1: uint8") 2146 .Output("output_0: T") 2147 .Output("output_1: T") 2148 .Output("mkl_output_0: uint8") 2149 .Output("mkl_output_1: uint8") 2150 // All datatypes supported by element-wise ops 2151 .Attr( 2152 "T: {half, float, double, uint8, int8, uint16, int16, int32, int64, " 2153 "complex64, complex128}") 2154 .Attr(GetConvnetDataFormatAttrString()) 2155 .Doc(R"doc( 2156 MKL operator to process the inputs to an elementwise MKL op. Both inputs 2157 need to be either in TF or in MKL format. This op is added before every 2158 element-wise MKL op. 2159 2160 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is 2161 expected to invoke these operators. 2162 )doc"); 2163 #endif // INTEL_MKL 2164 2165 } // namespace tensorflow 2166