1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/framework/common_shape_fns.h" 17 #include "tensorflow/core/framework/numeric_op.h" 18 #include "tensorflow/core/framework/op.h" 19 #include "tensorflow/core/framework/shape_inference.h" 20 21 namespace tensorflow { 22 23 using shape_inference::DimensionHandle; 24 using shape_inference::InferenceContext; 25 using shape_inference::ShapeHandle; 26 27 REGISTER_OP("AddN") 28 .Input("inputs: N * T") 29 .Output("sum: T") 30 .Attr("N: int >= 1") 31 .Attr("T: {numbertype, variant}") 32 .SetIsCommutative() 33 .SetIsAggregate() 34 .SetShapeFn([](InferenceContext* c) { 35 ShapeHandle cur = c->input(c->num_inputs() - 1); 36 for (int i = c->num_inputs() - 2; i >= 0; --i) { 37 TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur), 38 "From merging shape ", i, 39 " with other shapes."); 40 } 41 c->set_output(0, cur); 42 43 DataType dtype; 44 TF_RETURN_IF_ERROR(c->GetAttr("T", &dtype)); 45 46 if (dtype != DT_VARIANT) { 47 // Exit early if not DT_VARIANT. 48 return Status::OK(); 49 } else { 50 // DT_VARIANT shape handle shape inference. All sizes and dtypes must 51 // be the same; all shapes must be compatible via Merge. 
52 std::vector<shape_inference::ShapeAndType> cur_shapes_and_types; 53 auto* shapes_and_types = 54 c->input_handle_shapes_and_types(c->num_inputs() - 1); 55 if (shapes_and_types) { 56 cur_shapes_and_types = *shapes_and_types; 57 } 58 59 for (int i = c->num_inputs() - 2; i >= 0; --i) { 60 auto shapes_and_types_i = c->input_handle_shapes_and_types(i); 61 if (!shapes_and_types && shapes_and_types_i) { 62 // TODO(ebrevdo): Find cases where this happens and fix their shape 63 // inference. If we are calling AddN on variant types, they should 64 // all have consistent shape_and_type info. 65 shapes_and_types = shapes_and_types_i; 66 } else if (shapes_and_types && shapes_and_types_i) { 67 if (shapes_and_types_i->size() != shapes_and_types->size()) { 68 return errors::InvalidArgument( 69 "shapes_and_types[", i, 70 "].size() == ", shapes_and_types_i->size(), 71 " != shapes_and_types[0].size() == ", 72 shapes_and_types->size()); 73 } 74 for (int j = 0; j < shapes_and_types->size(); ++j) { 75 if (shapes_and_types->at(j).dtype != 76 shapes_and_types_i->at(j).dtype) { 77 return errors::InvalidArgument( 78 "shapes_and_types[", i, "][", j, "].dtype() == ", 79 DataTypeString(shapes_and_types_i->at(j).dtype), 80 " != shapes_and_types[0][", j, "].dtype == ", 81 DataTypeString(shapes_and_types->at(j).dtype)); 82 } 83 TF_RETURN_WITH_CONTEXT_IF_ERROR( 84 c->Merge(shapes_and_types_i->at(j).shape, 85 cur_shapes_and_types.at(j).shape, 86 &cur_shapes_and_types.at(j).shape), 87 "From merging shapes_and_types[", i, "][", j, "].shape with ", 88 "shapes_and_types[0][", j, "].shape"); 89 } 90 } 91 } 92 if (shapes_and_types) { 93 c->set_output_handle_shapes_and_types(0, cur_shapes_and_types); 94 } 95 return Status::OK(); 96 } 97 }); 98 99 // -------------------------------------------------------------------------- 100 101 // Note that the following operator is just a placeholder and has no 102 // associated kernel. 
The code in accumulate_n_optimizer.cc replaces 103 // this placeholder with a graph of operators that do have kernels. 104 // The Python code that generates instances of this op is currently in 105 // contrib/framework/python/ops/accumulate_n_v2.py 106 REGISTER_OP("AccumulateNV2") 107 .Input("inputs: N * T") 108 .Output("sum: T") 109 .Attr("N: int >= 1") 110 .Attr("T: numbertype") 111 .Attr("shape: shape") 112 .SetIsCommutative() 113 .SetIsAggregate() 114 .SetShapeFn(shape_inference::ExplicitShape); 115 116 // -------------------------------------------------------------------------- 117 118 REGISTER_OP("BatchMatMul") 119 .Input("x: T") 120 .Input("y: T") 121 .Output("output: T") 122 .Attr( 123 "T: {bfloat16, half, float, double, int32, int64, complex64, " 124 "complex128}") 125 .Attr("adj_x: bool = false") 126 .Attr("adj_y: bool = false") 127 .SetShapeFn([](InferenceContext* c) { 128 ShapeHandle a_shape; 129 ShapeHandle b_shape; 130 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &a_shape)); 131 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &b_shape)); 132 133 // Determine output rows and cols. 134 bool adj_x; 135 bool adj_y; 136 TF_RETURN_IF_ERROR(c->GetAttr("adj_x", &adj_x)); 137 TF_RETURN_IF_ERROR(c->GetAttr("adj_y", &adj_y)); 138 DimensionHandle output_rows = c->Dim(a_shape, adj_x ? -1 : -2); 139 DimensionHandle output_cols = c->Dim(b_shape, adj_y ? -2 : -1); 140 141 // Batch dims match between inputs. 142 ShapeHandle a_batch_dims; 143 ShapeHandle b_batch_dims; 144 ShapeHandle batch_dims; 145 TF_RETURN_IF_ERROR(c->Subshape(a_shape, 0, -2, &a_batch_dims)); 146 TF_RETURN_IF_ERROR(c->Subshape(b_shape, 0, -2, &b_batch_dims)); 147 TF_RETURN_IF_ERROR(c->Merge(a_batch_dims, b_batch_dims, &batch_dims)); 148 149 // Assert inner dims match. 150 DimensionHandle unused; 151 TF_RETURN_IF_ERROR(c->Merge(c->Dim(a_shape, adj_x ? -2 : -1), 152 c->Dim(b_shape, adj_y ? 
-1 : -2), &unused)); 153 154 ShapeHandle out; 155 TF_RETURN_IF_ERROR(c->Concatenate( 156 batch_dims, c->Matrix(output_rows, output_cols), &out)); 157 c->set_output(0, out); 158 return Status::OK(); 159 }); 160 161 // -------------------------------------------------------------------------- 162 // Casting Ops 163 // 164 // NOTE: Only a smaller number of types are supported by 165 // Cast. The exact casting rule is TBD. The current 166 // implementation uses C++ static cast rules for numeric 167 // types, which may be changed in the future. 168 REGISTER_OP("Cast") 169 .Input("x: SrcT") 170 .Output("y: DstT") 171 .Attr("SrcT: type") 172 .Attr("DstT: type") 173 .Attr("Truncate: bool = false") 174 .SetShapeFn(shape_inference::UnchangedShape); 175 176 REGISTER_OP("_HostCast") 177 .Input("x: SrcT") 178 .Output("y: DstT") 179 .Attr("SrcT: type") 180 .Attr("DstT: type") 181 .Attr("Truncate: bool = false") 182 .SetShapeFn(shape_inference::UnchangedShape) 183 .Doc(R"doc( 184 Cast x of type SrcT to y of DstT. 185 186 _HostCast requires its input and produces its output in host memory. 
187 )doc"); 188 189 // -------------------------------------------------------------------------- 190 191 REGISTER_OP("Abs") 192 .Input("x: T") 193 .Output("y: T") 194 .Attr("T: {bfloat16, half, float, double, int32, int64}") 195 .SetShapeFn(shape_inference::UnchangedShape); 196 197 REGISTER_OP("ComplexAbs") 198 .Input("x: T") 199 .Output("y: Tout") 200 .Attr("T: {complex64, complex128} = DT_COMPLEX64") 201 .Attr("Tout: {float, double} = DT_FLOAT") 202 .SetShapeFn(shape_inference::UnchangedShape); 203 204 // Declares cwise unary operations signature: 't -> 't 205 #define UNARY() \ 206 Input("x: T") \ 207 .Output("y: T") \ 208 .Attr( \ 209 "T: {bfloat16, half, float, double, int32, int64, complex64, " \ 210 "complex128}") \ 211 .SetShapeFn(shape_inference::UnchangedShape) 212 213 #define UNARY_REAL() \ 214 Input("x: T") \ 215 .Output("y: T") \ 216 .Attr("T: {bfloat16, half, float, double}") \ 217 .SetShapeFn(shape_inference::UnchangedShape) 218 219 #define UNARY_COMPLEX() \ 220 Input("x: T") \ 221 .Output("y: T") \ 222 .Attr("T: {bfloat16, half, float, double, complex64, complex128}") \ 223 .SetShapeFn(shape_inference::UnchangedShape) 224 225 #define UNARY_GRADIENT_COMPLEX() \ 226 Input("y: T") \ 227 .Input("dy: T") \ 228 .Output("z: T") \ 229 .Attr("T: {bfloat16, half, float, double, complex64, complex128}") \ 230 .SetShapeFn(shape_inference::UnchangedShape) 231 232 REGISTER_OP("Neg").UNARY(); 233 234 REGISTER_OP("Inv").UNARY(); 235 236 REGISTER_OP("InvGrad").UNARY_GRADIENT_COMPLEX(); 237 238 REGISTER_OP("Reciprocal").UNARY(); 239 240 REGISTER_OP("ReciprocalGrad").UNARY_GRADIENT_COMPLEX(); 241 242 REGISTER_OP("Square").UNARY(); 243 244 REGISTER_OP("Sqrt").UNARY_COMPLEX(); 245 246 REGISTER_OP("SqrtGrad").UNARY_GRADIENT_COMPLEX(); 247 248 REGISTER_OP("Rsqrt").UNARY_COMPLEX(); 249 250 REGISTER_OP("Round").UNARY(); 251 252 REGISTER_OP("RsqrtGrad").UNARY_GRADIENT_COMPLEX(); 253 254 REGISTER_OP("Exp").UNARY_COMPLEX(); 255 256 REGISTER_OP("Expm1").UNARY_COMPLEX(); 257 258 
// Logarithms.
REGISTER_OP("Log").UNARY_COMPLEX();

REGISTER_OP("Log1p").UNARY_COMPLEX();

// Hyperbolic functions and their inverses.
REGISTER_OP("Sinh").UNARY_COMPLEX();

REGISTER_OP("Cosh").UNARY_COMPLEX();

REGISTER_OP("Tanh").UNARY_COMPLEX();

REGISTER_OP("Asinh").UNARY_COMPLEX();

REGISTER_OP("Acosh").UNARY_COMPLEX();

REGISTER_OP("Atanh").UNARY_COMPLEX();

REGISTER_OP("TanhGrad").UNARY_GRADIENT_COMPLEX();

// Special functions, registered for real floating-point types only.
REGISTER_OP("Lgamma").UNARY_REAL();

REGISTER_OP("Digamma").UNARY_REAL();

REGISTER_OP("Erf").UNARY_REAL();

REGISTER_OP("Erfc").UNARY_REAL();

REGISTER_OP("Sigmoid").UNARY_COMPLEX();

REGISTER_OP("SigmoidGrad").UNARY_GRADIENT_COMPLEX();

// Trigonometric functions and their inverses.
REGISTER_OP("Sin").UNARY_COMPLEX();

REGISTER_OP("Cos").UNARY_COMPLEX();

REGISTER_OP("Tan").UNARY();

REGISTER_OP("Asin").UNARY();

REGISTER_OP("Acos").UNARY();

REGISTER_OP("Atan").UNARY();

REGISTER_OP("BesselI0e").UNARY_REAL();

REGISTER_OP("BesselI1e").UNARY_REAL();

// Takes a list of op names in `op_names`; output shape equals input shape.
REGISTER_OP("_UnaryOpsComposition")
    .Input("x: T")
    .Output("y: T")
    .Attr("T: {float, half, double}")
    .Attr("op_names: list(string)")
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to create these operators.
)doc");

#undef UNARY
#undef UNARY_REAL
#undef UNARY_COMPLEX

// Element-wise predicates: floating-point input, bool output, same shape.
REGISTER_OP("IsNan")
    .Input("x: T")
    .Output("y: bool")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("IsInf")
    .Input("x: T")
    .Output("y: bool")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("IsFinite")
    .Input("x: T")
    .Output("y: bool")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("Sign")
    .Input("x: T")
    .Output("y: T")
    .Attr(
        "T: {bfloat16, half, float, double, int32, int64, complex64, "
        "complex128}")
    .SetShapeFn(shape_inference::UnchangedShape);

// Rounding ops over real floating-point types; output shape equals input.
REGISTER_OP("Floor")
    .Input("x: T")
    .Output("y: T")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("Ceil")
    .Input("x: T")
    .Output("y: T")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("Rint")
    .Input("x: T")
    .Output("y: T")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);

// Declares cwise binary operations signature: 't, 't -> 't.
364 365 #define BINARY_MORE() \ 366 Input("x: T").Input("y: T").Output("z: T").Attr( \ 367 "T: {bfloat16, half, float, double, uint8, int8, uint16, int16, int32, " \ 368 "int64, complex64, complex128}") 369 370 #define BINARY_FEWER() \ 371 Input("x: T").Input("y: T").Output("z: T").Attr( \ 372 "T: {bfloat16, half, float, double, int32, int64, complex64, " \ 373 "complex128}") 374 375 REGISTER_OP("Add") 376 .Input("x: T") 377 .Input("y: T") 378 .Output("z: T") 379 .Attr( 380 "T: {bfloat16, half, float, double, uint8, int8, int16, int32, int64, " 381 "complex64, complex128, string}") 382 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); 383 384 // TODO(rmlarsen): Add a Python wrapper that swiches non-string instances to 385 // use AddV2 (b/68646025). 386 REGISTER_OP("AddV2") 387 .Input("x: T") 388 .Input("y: T") 389 .Output("z: T") 390 .Attr( 391 "T: {bfloat16, half, float, double, uint8, int8, int16, int32, int64, " 392 "complex64, complex128}") 393 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn) 394 .SetIsAggregate() 395 .SetIsCommutative(); 396 397 REGISTER_OP("_MklAdd") 398 .Input("x: T") 399 .Input("y: T") 400 .Input("mkl_x: uint8") 401 .Input("mkl_y: uint8") 402 .Output("z: T") 403 .Output("mkl_z: uint8") 404 .Attr( 405 "T: {half, float, double, uint8, int8, int16, int32, int64, complex64, " 406 "complex128, string}") 407 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn) 408 .Doc(R"doc( 409 Returns `x` + `y` element-wise. 410 411 *NOTE*: `tf.math.add` supports broadcasting. `tf.math.add_n` does not. More about broadcasting 412 [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html). 413 )doc"); 414 415 REGISTER_OP("Sub").BINARY_MORE().SetShapeFn( 416 shape_inference::BroadcastBinaryOpShapeFn); 417 418 REGISTER_OP("_MklSub") 419 .BINARY_FEWER() 420 .Input("mkl_x: uint8") 421 .Input("mkl_y: uint8") 422 .Output("mkl_z: uint8") 423 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn) 424 .Doc(R"doc( 425 Returns x - y element-wise. 
426 427 *NOTE*: `Sub` supports broadcasting. More about broadcasting 428 [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) 429 )doc"); 430 431 REGISTER_OP("Mul").BINARY_MORE().SetIsCommutative().SetShapeFn( 432 shape_inference::BroadcastBinaryOpShapeFn); 433 434 REGISTER_OP("MulNoNan") 435 .Input("x: T") 436 .Input("y: T") 437 .Output("z: T") 438 .Attr("T: {half, float, double, complex64, complex128}") 439 .SetIsCommutative() 440 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); 441 442 REGISTER_OP("_MklMul") 443 .BINARY_MORE() 444 .Input("mkl_x: uint8") 445 .Input("mkl_y: uint8") 446 .Output("mkl_z: uint8") 447 .SetIsCommutative() 448 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn) 449 .Doc(R"doc( 450 Returns x * y element-wise. 451 452 *NOTE*: `Mul` supports broadcasting. More about broadcasting 453 [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) 454 )doc"); 455 456 REGISTER_OP("Div").BINARY_MORE().SetShapeFn( 457 shape_inference::BroadcastBinaryOpShapeFn); 458 459 REGISTER_OP("DivNoNan") 460 .Input("x: T") 461 .Input("y: T") 462 .Output("z: T") 463 .Attr("T: {half, float, double, complex64, complex128}") 464 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); 465 466 REGISTER_OP("FloorDiv") 467 .BINARY_MORE() 468 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); 469 470 REGISTER_OP("TruncateDiv") 471 .BINARY_MORE() 472 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); 473 474 REGISTER_OP("RealDiv").BINARY_MORE().SetShapeFn( 475 shape_inference::BroadcastBinaryOpShapeFn); 476 477 REGISTER_OP("SquaredDifference") 478 .BINARY_FEWER() 479 .SetIsCommutative() 480 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); 481 482 REGISTER_OP("_MklSquaredDifference") 483 .BINARY_FEWER() 484 .Input("mkl_x: uint8") 485 .Input("mkl_y: uint8") 486 .Output("mkl_z: uint8") 487 .SetIsCommutative() 488 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn) 489 .Doc(R"doc( 490 Returns (x - y)(x - y) 
element-wise. 491 492 *NOTE*: `SquaredDifference` supports broadcasting. More about broadcasting 493 [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) 494 )doc"); 495 496 REGISTER_OP("Xlogy") 497 .Input("x: T") 498 .Input("y: T") 499 .Output("z: T") 500 .Attr("T: {half, float, double, complex64, complex128}") 501 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); 502 503 REGISTER_OP("Xdivy") 504 .Input("x: T") 505 .Input("y: T") 506 .Output("z: T") 507 .Attr("T: {half, float, double, complex64, complex128}") 508 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); 509 510 #undef BINARY_FEWER 511 #undef BINARY_MORE 512 513 REGISTER_OP("Maximum") 514 .Input("x: T") 515 .Input("y: T") 516 .Output("z: T") 517 .Attr("T: {bfloat16, half, float, double, int32, int64}") 518 .SetIsCommutative() 519 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); 520 521 REGISTER_OP("_MklMaximum") 522 .Input("x: T") 523 .Input("y: T") 524 .Input("mkl_x: uint8") 525 .Input("mkl_y: uint8") 526 .Output("z: T") 527 .Output("mkl_z: uint8") 528 .Attr("T: {half, float, double, int32, int64}") 529 .SetIsCommutative() 530 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn) 531 .Doc(R"doc( 532 Returns the max of x and y (i.e. x > y ? x : y) element-wise. 533 534 *NOTE*: `Maximum` supports broadcasting. 
More about broadcasting 535 [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) 536 )doc"); 537 538 REGISTER_OP("Minimum") 539 .Input("x: T") 540 .Input("y: T") 541 .Output("z: T") 542 .Attr("T: {bfloat16, half, float, double, int32, int64}") 543 .SetIsCommutative() 544 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); 545 546 REGISTER_OP("Mod") 547 .Input("x: T") 548 .Input("y: T") 549 .Output("z: T") 550 .Attr("T: {int32, int64, float16, half, bfloat16, float, double}") 551 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); 552 553 REGISTER_OP("FloorMod") 554 .Input("x: T") 555 .Input("y: T") 556 .Output("z: T") 557 .Attr("T: {int32, int64, bfloat16, half, float, double}") 558 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); 559 560 REGISTER_OP("TruncateMod") 561 .Input("x: T") 562 .Input("y: T") 563 .Output("z: T") 564 .Attr("T: {int32, int64, bfloat16, half, float, double}") 565 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); 566 567 REGISTER_OP("Pow") 568 .Input("x: T") 569 .Input("y: T") 570 .Output("z: T") 571 .Attr( 572 "T: {bfloat16, float, half, double, int32, int64, complex64, " 573 "complex128}") 574 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); 575 576 REGISTER_OP("Igammac") 577 .Input("a: T") 578 .Input("x: T") 579 .Output("z: T") 580 .Attr("T: {float, double}") 581 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); 582 583 REGISTER_OP("Igamma") 584 .Input("a: T") 585 .Input("x: T") 586 .Output("z: T") 587 .Attr("T: {float, double}") 588 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); 589 590 REGISTER_OP("IgammaGradA") 591 .Input("a: T") 592 .Input("x: T") 593 .Output("z: T") 594 .Attr("T: {float, double}") 595 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); 596 597 REGISTER_OP("Zeta") 598 .Input("x: T") 599 .Input("q: T") 600 .Output("z: T") 601 .Attr("T: {float, double}") 602 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); 603 604 REGISTER_OP("Polygamma") 605 
.Input("a: T") 606 .Input("x: T") 607 .Output("z: T") 608 .Attr("T: {float, double}") 609 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); 610 611 REGISTER_OP("Atan2") 612 .Input("y: T") 613 .Input("x: T") 614 .Output("z: T") 615 .Attr("T: {bfloat16, half, float, double}") 616 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); 617 618 REGISTER_OP("Betainc") 619 .Input("a: T") 620 .Input("b: T") 621 .Input("x: T") 622 .Output("z: T") 623 .Attr("T: {float, double}") 624 .SetShapeFn([](InferenceContext* c) { 625 const int num_inputs = 3; 626 ShapeHandle output = c->UnknownShape(); 627 int num_scalars = 0; 628 ShapeHandle some_non_scalar; 629 for (int i = 0; i < num_inputs; ++i) { 630 ShapeHandle in = c->input(i); 631 if (!c->RankKnown(in)) { 632 some_non_scalar = in; 633 // An input with unknown rank could be either a scalar (to be 634 // broadcast) or some other shape. 635 } else if (c->Rank(in) == 0) { 636 // Input is a scalar, it will be broadcast to the output shape. 637 ++num_scalars; 638 } else { 639 TF_RETURN_IF_ERROR(c->Merge(output, in, &output)); 640 some_non_scalar = output; 641 } 642 } 643 644 if (num_scalars == num_inputs - 1) { 645 // If all but one input is known to be a scalar, then output is the 646 // remaining input. 647 output = some_non_scalar; 648 } else if (num_scalars == num_inputs) { 649 // If all are scalars, output is scalar; pick the first one arbitrarily. 650 output = c->input(0); 651 } 652 653 c->set_output(0, output); 654 return Status::OK(); 655 }); 656 657 // -------------------------------------------------------------------------- 658 659 // Declares cwise binary comparison operations signature: 't, 't -> bool, 660 // where 't has a natural total order. 
661 #define COMPARISON() \ 662 Input("x: T") \ 663 .Input("y: T") \ 664 .Output("z: bool") \ 665 .Attr("T: realnumbertype") \ 666 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn) 667 668 REGISTER_OP("Less").COMPARISON(); 669 670 REGISTER_OP("LessEqual").COMPARISON(); 671 672 REGISTER_OP("Greater").COMPARISON(); 673 674 REGISTER_OP("GreaterEqual").COMPARISON(); 675 676 #undef COMPARISON 677 678 // -------------------------------------------------------------------------- 679 680 #define EQUALITY_COMPARISON() \ 681 Input("x: T") \ 682 .Input("y: T") \ 683 .Output("z: bool") \ 684 .SetIsCommutative() \ 685 .Attr( \ 686 "T: {bfloat16, half, float, double, uint8, int8, int16, int32, " \ 687 "int64, complex64, quint8, qint8, qint32, string, bool, " \ 688 "complex128}") \ 689 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn) 690 691 REGISTER_OP("Equal").EQUALITY_COMPARISON(); 692 693 REGISTER_OP("NotEqual").EQUALITY_COMPARISON(); 694 695 #undef EQUALITY_COMPARISON 696 697 REGISTER_OP("ApproximateEqual") 698 .Input("x: T") 699 .Input("y: T") 700 .Output("z: bool") 701 .SetIsCommutative() 702 .Attr("T: numbertype") 703 .Attr("tolerance: float = 0.00001") 704 .SetShapeFn([](InferenceContext* c) { 705 // The inputs 'x' and 'y' must have the same shape. 
706 ShapeHandle data_x = c->input(0); 707 ShapeHandle data_y = c->input(1); 708 TF_RETURN_IF_ERROR(c->Merge(data_x, data_y, &data_x)); 709 return shape_inference::UnchangedShape(c); 710 }); 711 712 // -------------------------------------------------------------------------- 713 714 REGISTER_OP("LogicalNot") 715 .Input("x: bool") 716 .Output("y: bool") 717 .SetShapeFn(shape_inference::UnchangedShape); 718 719 #define BINARY_LOGICAL() \ 720 Input("x: bool") \ 721 .Input("y: bool") \ 722 .Output("z: bool") \ 723 .SetIsCommutative() \ 724 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn) 725 726 REGISTER_OP("LogicalAnd").BINARY_LOGICAL(); 727 728 REGISTER_OP("LogicalOr").BINARY_LOGICAL(); 729 730 #undef BINARY_LOGICAL 731 732 // -------------------------------------------------------------------------- 733 734 REGISTER_OP("Select") 735 .Input("condition: bool") 736 .Input("t: T") 737 .Input("e: T") 738 .Output("output: T") 739 .Attr("T: type") 740 .SetShapeFn([](InferenceContext* c) { 741 auto* handle_data_1 = c->input_handle_shapes_and_types(1); 742 auto* handle_data_2 = c->input_handle_shapes_and_types(2); 743 // Merge handle shape and dtype if applicable. 
744 if (handle_data_1 != nullptr && handle_data_2 != nullptr) { 745 const auto size = handle_data_1->size(); 746 std::vector<shape_inference::ShapeAndType> merged_handle_data(size); 747 if (size != handle_data_2->size()) { 748 return errors::InvalidArgument( 749 "Trying to merge handles pointing to different numbers of " 750 "tensors."); 751 } 752 753 for (int i = 0; i < size; ++i) { 754 const shape_inference::ShapeAndType& s1 = (*handle_data_1)[i]; 755 const shape_inference::ShapeAndType& s2 = (*handle_data_2)[i]; 756 if (s1.dtype != s2.dtype) { 757 // TODO(apassos) resolve this in the manner of b/32476923 758 return errors::InvalidArgument( 759 "Trying to merge handles pointing to different dtypes."); 760 } 761 merged_handle_data[i].dtype = s1.dtype; 762 TF_RETURN_IF_ERROR( 763 c->Merge(s1.shape, s2.shape, &merged_handle_data[i].shape)); 764 } 765 766 c->set_output_handle_shapes_and_types(0, merged_handle_data); 767 } 768 769 // The inputs 'then' and 'else' must have the same shape. 770 ShapeHandle data = c->input(1); 771 ShapeHandle other = c->input(2); 772 TF_RETURN_IF_ERROR(c->Merge(data, other, &data)); 773 774 // The input 'cond' must either have the same shape as 'then' and 775 // 'else', or be a vector if 'then' and 'else' are at least vectors. 776 ShapeHandle cond = c->input(0); 777 778 if (!c->RankKnown(cond) || !c->RankKnown(data)) { 779 c->set_output(0, data); 780 return Status::OK(); 781 } 782 783 // rank of shape and data is known. 784 785 const int32 cond_rank = c->Rank(cond); 786 const int32 data_rank = c->Rank(data); 787 788 if (cond_rank == 0) { 789 // The rank of 'cond' is a scalar. 790 // t and e can have any shape. 
791 c->set_output(0, data); 792 return Status::OK(); 793 } 794 795 if (cond_rank != 1) { 796 // If 'cond' is not a vector, and not a scalar, 797 // then shape must match 'then' and 'else' 798 TF_RETURN_IF_ERROR(c->Merge(data, cond, &data)); 799 c->set_output(0, data); 800 return Status::OK(); 801 } 802 803 if (data_rank == 0) { 804 // if 'then' and 'else' are scalar also the cond must be 805 TF_RETURN_IF_ERROR(c->Merge(data, cond, &data)); 806 c->set_output(0, data); 807 return Status::OK(); 808 } 809 810 if (cond_rank == 1) { 811 // if the cond is a vector and the 'then' is not a scalar, 812 // the first dimension of 'then' and 'else' 813 TF_RETURN_IF_ERROR(c->Merge(cond, c->Vector(c->Dim(data, 0)), &cond)); 814 c->set_output(0, data); 815 return Status::OK(); 816 } 817 818 c->set_output(0, data); 819 820 return Status::OK(); 821 }); 822 823 // -------------------------------------------------------------------------- 824 825 REGISTER_OP("MatMul") 826 .Input("a: T") 827 .Input("b: T") 828 .Output("product: T") 829 .Attr("transpose_a: bool = false") 830 .Attr("transpose_b: bool = false") 831 .Attr( 832 "T: {bfloat16, half, float, double, int32, int64, complex64, " 833 "complex128}") 834 .SetShapeFn(shape_inference::MatMulShape); 835 836 REGISTER_OP("SparseMatMul") 837 .Input("a: Ta") 838 .Input("b: Tb") 839 .Output("product: float") 840 .Attr("transpose_a: bool = false") 841 .Attr("transpose_b: bool = false") 842 .Attr("a_is_sparse: bool = false") 843 .Attr("b_is_sparse: bool = false") 844 .Attr("Ta: {float, bfloat16} = DT_FLOAT") 845 .Attr("Tb: {float, bfloat16} = DT_FLOAT") 846 .SetShapeFn(shape_inference::MatMulShape); 847 848 // -------------------------------------------------------------------------- 849 850 // For operations where the output is a reduction function along some 851 // dimensions of the input. 
852 REGISTER_OP("Sum") 853 .Input("input: T") 854 .Input("reduction_indices: Tidx") 855 .Output("output: T") 856 .Attr("keep_dims: bool = false") 857 .Attr("T: numbertype") 858 .Attr("Tidx: {int32, int64} = DT_INT32") 859 .SetShapeFn(shape_inference::ReductionShape); 860 861 REGISTER_OP("EuclideanNorm") 862 .Input("input: T") 863 .Input("reduction_indices: Tidx") 864 .Output("output: T") 865 .Attr("keep_dims: bool = false") 866 .Attr("T: numbertype") 867 .Attr("Tidx: {int32, int64} = DT_INT32") 868 .SetShapeFn(shape_inference::ReductionShape); 869 870 REGISTER_OP("Mean") 871 .Input("input: T") 872 .Input("reduction_indices: Tidx") 873 .Output("output: T") 874 .Attr("keep_dims: bool = false") 875 .Attr("T: numbertype") 876 .Attr("Tidx: {int32, int64} = DT_INT32") 877 .SetShapeFn(shape_inference::ReductionShape); 878 879 REGISTER_OP("Prod") 880 .Input("input: T") 881 .Input("reduction_indices: Tidx") 882 .Output("output: T") 883 .Attr("keep_dims: bool = false") 884 .Attr("T: numbertype") 885 .Attr("Tidx: {int32, int64} = DT_INT32") 886 .SetShapeFn(shape_inference::ReductionShape); 887 888 REGISTER_OP("Min") 889 .Input("input: T") 890 .Input("reduction_indices: Tidx") 891 .Output("output: T") 892 .Attr("keep_dims: bool = false") 893 .Attr("T: numbertype") 894 .Attr("Tidx: {int32, int64} = DT_INT32") 895 .SetShapeFn(shape_inference::ReductionShape); 896 897 REGISTER_OP("Max") 898 .Input("input: T") 899 .Input("reduction_indices: Tidx") 900 .Output("output: T") 901 .Attr("keep_dims: bool = false") 902 .Attr("T: numbertype") 903 .Attr("Tidx: {int32, int64} = DT_INT32") 904 .SetShapeFn(shape_inference::ReductionShape); 905 906 namespace { 907 908 Status ArgOpShape(shape_inference::InferenceContext* c) { 909 ShapeHandle dimension_shape; 910 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &dimension_shape)); 911 912 ShapeHandle input_shape = c->input(0); 913 if (!c->RankKnown(input_shape)) { 914 return shape_inference::UnknownShape(c); 915 } 916 917 const int32 input_rank = 
c->Rank(input_shape); 918 if (input_rank <= 1) { 919 // Reducing a scalar/vector must return a scalar. 920 return shape_inference::ScalarShape(c); 921 } 922 923 const Tensor* dim_t = c->input_tensor(1); 924 if (dim_t == nullptr) { 925 // We don't know the value of the dimension, but we 926 // know the rank of the input, so return the correct 927 // rank with unknown dimensions. 928 std::vector<DimensionHandle> dims(input_rank - 1); 929 for (int i = 0; i < dims.size(); ++i) { 930 dims[i] = c->UnknownDim(); 931 } 932 933 c->set_output(0, c->MakeShape(dims)); 934 return Status::OK(); 935 } 936 937 int64 dimension_val; 938 if (dim_t->dtype() == DT_INT32) { 939 dimension_val = dim_t->scalar<int32>()(); 940 } else { 941 dimension_val = dim_t->scalar<int64>()(); 942 } 943 944 int64 axis = dimension_val < 0 ? dimension_val + input_rank : dimension_val; 945 if (axis < 0 || axis >= input_rank) { 946 return errors::InvalidArgument( 947 "Dimension (", dimension_val, ") must be in the range [", -input_rank, 948 ", ", input_rank, "), where ", input_rank, 949 " is the number of dimensions in the input."); 950 } 951 952 // Return the input shape without the dimension being reduced. 
953 std::vector<DimensionHandle> dims; 954 for (int i = 0; i < input_rank; ++i) { 955 if (axis != i) { 956 dims.emplace_back(c->Dim(input_shape, i)); 957 } 958 } 959 c->set_output(0, c->MakeShape(dims)); 960 return Status::OK(); 961 } 962 963 } // namespace 964 965 REGISTER_OP("ArgMax") 966 .Input("input: T") 967 .Input("dimension: Tidx") 968 .Output("output: output_type") 969 .Attr("T: numbertype") 970 .Attr("Tidx: {int32, int64} = DT_INT32") 971 .Attr("output_type: {int32, int64} = DT_INT64") 972 .SetShapeFn(ArgOpShape); 973 974 REGISTER_OP("ArgMin") 975 .Input("input: T") 976 .Input("dimension: Tidx") 977 .Output("output: output_type") 978 .Attr("T: numbertype") 979 .Attr("Tidx: {int32, int64} = DT_INT32") 980 .Attr("output_type: {int32, int64} = DT_INT64") 981 .SetShapeFn(ArgOpShape); 982 983 namespace { 984 985 Status SegmentReductionShapeFn(InferenceContext* c) { 986 ShapeHandle data_shape; 987 ShapeHandle segment_ids_shape; 988 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape)); 989 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &segment_ids_shape)); 990 991 ShapeHandle subshape; 992 TF_RETURN_IF_ERROR(c->Subshape(data_shape, 1, &subshape)); 993 994 ShapeHandle out; 995 TF_RETURN_IF_ERROR( 996 c->Concatenate(c->Vector(InferenceContext::kUnknownDim), subshape, &out)); 997 c->set_output(0, out); 998 return Status::OK(); 999 } 1000 1001 Status SparseSegmentReductionShapeFn(InferenceContext* c) { 1002 ShapeHandle data_shape; 1003 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape)); 1004 1005 ShapeHandle indices_shape; 1006 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices_shape)); 1007 1008 ShapeHandle segment_ids_shape; 1009 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &segment_ids_shape)); 1010 1011 // indices and segment_ids should merge cleanly. 
1012 ShapeHandle unused; 1013 TF_RETURN_IF_ERROR(c->Merge(indices_shape, segment_ids_shape, &unused)); 1014 1015 ShapeHandle subshape; 1016 TF_RETURN_IF_ERROR(c->Subshape(data_shape, 1, &subshape)); 1017 1018 ShapeHandle out; 1019 TF_RETURN_IF_ERROR( 1020 c->Concatenate(c->Vector(InferenceContext::kUnknownDim), subshape, &out)); 1021 c->set_output(0, out); 1022 return Status::OK(); 1023 } 1024 1025 Status SparseSegmentReductionGradShapeFn(InferenceContext* c) { 1026 ShapeHandle data_shape; 1027 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape)); 1028 1029 ShapeHandle indices_shape; 1030 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices_shape)); 1031 1032 // indices and segment_ids should merge cleanly. 1033 ShapeHandle unused; 1034 TF_RETURN_IF_ERROR(c->Merge(c->input(2), indices_shape, &unused)); 1035 1036 // output_dim0 should be a scalar 1037 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); 1038 1039 ShapeHandle subshape; 1040 TF_RETURN_IF_ERROR(c->Subshape(data_shape, 1, &subshape)); 1041 1042 const Tensor* dim0 = c->input_tensor(3); 1043 ShapeHandle dim0_shape; 1044 if (dim0 == nullptr) { 1045 // We don't have the value at inference time, so the output 1046 // shape is unknown. 
1047 dim0_shape = c->Vector(InferenceContext::kUnknownDim); 1048 } else { 1049 auto dim0_value = dim0->scalar<int32>()(); 1050 if (dim0_value < 0) { 1051 return errors::InvalidArgument( 1052 "Cannot specify a negative value for output_dim0"); 1053 } 1054 dim0_shape = c->Vector(dim0_value); 1055 } 1056 1057 ShapeHandle out; 1058 TF_RETURN_IF_ERROR(c->Concatenate(dim0_shape, subshape, &out)); 1059 c->set_output(0, out); 1060 return Status::OK(); 1061 } 1062 1063 Status SparseSegmentReductionWithNumSegmentsShapeFn(InferenceContext* c) { 1064 ShapeHandle data_shape; 1065 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape)); 1066 1067 ShapeHandle indices_shape; 1068 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices_shape)); 1069 1070 ShapeHandle segment_ids_shape; 1071 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &segment_ids_shape)); 1072 1073 ShapeHandle num_segments_shape; 1074 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &num_segments_shape)); 1075 1076 // indices and segment_ids should merge cleanly. 1077 ShapeHandle unused; 1078 TF_RETURN_IF_ERROR(c->Merge(indices_shape, segment_ids_shape, &unused)); 1079 1080 ShapeHandle subshape; 1081 TF_RETURN_IF_ERROR(c->Subshape(data_shape, 1, &subshape)); 1082 1083 ShapeHandle out; 1084 const Tensor* dim0 = c->input_tensor(3); 1085 if (dim0 == nullptr) { 1086 // We don't have the value at inference time, so the output 1087 // shape is unknown. 
1088 TF_RETURN_IF_ERROR(c->Concatenate(c->Vector(InferenceContext::kUnknownDim), 1089 subshape, &out)); 1090 } else { 1091 auto dim0_value = dim0->scalar<int32>()(); 1092 if (dim0_value < 0) { 1093 return errors::InvalidArgument( 1094 "Cannot specify a negative value for num_segments"); 1095 } 1096 TF_RETURN_IF_ERROR(c->Concatenate(c->Vector(dim0_value), subshape, &out)); 1097 } 1098 c->set_output(0, out); 1099 return Status::OK(); 1100 } 1101 1102 Status UnsortedSegmentReductionShapeFn(InferenceContext* c) { 1103 ShapeHandle s_data = c->input(0); 1104 ShapeHandle s_segment_ids = c->input(1); 1105 ShapeHandle s_num_segments = c->input(2); 1106 TF_RETURN_IF_ERROR(c->WithRank(s_num_segments, 0, &s_num_segments)); 1107 1108 ShapeHandle out; 1109 1110 // Leading dimensions of data must be compatible with dimensions of 1111 // <s_segment_ids>. 1112 if (c->RankKnown(s_segment_ids)) { 1113 TF_RETURN_IF_ERROR( 1114 c->MergePrefix(s_data, s_segment_ids, &s_data, &s_segment_ids)); 1115 1116 // Get the value of the num_segments input tensor. 1117 DimensionHandle num_segments_dim; 1118 TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &num_segments_dim)); 1119 1120 // Output is {segment_id_rank} + s_data[segment_id_rank:]. 
1121 ShapeHandle s_data_suffix; 1122 TF_RETURN_IF_ERROR( 1123 c->Subshape(s_data, c->Rank(s_segment_ids), &s_data_suffix)); 1124 TF_RETURN_IF_ERROR( 1125 c->Concatenate(c->Vector(num_segments_dim), s_data_suffix, &out)); 1126 } else { 1127 out = c->UnknownShape(); 1128 } 1129 c->set_output(0, out); 1130 return Status::OK(); 1131 } 1132 } // namespace 1133 1134 REGISTER_OP("SegmentSum") 1135 .Input("data: T") 1136 .Input("segment_ids: Tindices") 1137 .Output("output: T") 1138 .Attr("T: numbertype") 1139 .Attr("Tindices: {int32,int64}") 1140 .SetShapeFn(SegmentReductionShapeFn); 1141 1142 REGISTER_OP("SegmentMean") 1143 .Input("data: T") 1144 .Input("segment_ids: Tindices") 1145 .Output("output: T") 1146 .Attr("T: numbertype") 1147 .Attr("Tindices: {int32,int64}") 1148 .SetShapeFn(SegmentReductionShapeFn); 1149 1150 REGISTER_OP("SegmentProd") 1151 .Input("data: T") 1152 .Input("segment_ids: Tindices") 1153 .Output("output: T") 1154 .Attr("T: numbertype") 1155 .Attr("Tindices: {int32,int64}") 1156 .SetShapeFn(SegmentReductionShapeFn); 1157 1158 REGISTER_OP("SegmentMin") 1159 .Input("data: T") 1160 .Input("segment_ids: Tindices") 1161 .Output("output: T") 1162 .Attr("T: realnumbertype") 1163 .Attr("Tindices: {int32,int64}") 1164 .SetShapeFn(SegmentReductionShapeFn); 1165 1166 REGISTER_OP("SegmentMax") 1167 .Input("data: T") 1168 .Input("segment_ids: Tindices") 1169 .Output("output: T") 1170 .Attr("T: realnumbertype") 1171 .Attr("Tindices: {int32,int64}") 1172 .SetShapeFn(SegmentReductionShapeFn); 1173 1174 REGISTER_OP("UnsortedSegmentSum") 1175 .Input("data: T") 1176 .Input("segment_ids: Tindices") 1177 .Input("num_segments: Tnumsegments") 1178 .Output("output: T") 1179 .Attr("T: numbertype") 1180 .Attr("Tindices: {int32,int64}") 1181 .Attr("Tnumsegments: {int32,int64} = DT_INT32") 1182 .SetShapeFn(UnsortedSegmentReductionShapeFn); 1183 1184 REGISTER_OP("UnsortedSegmentMax") 1185 .Input("data: T") 1186 .Input("segment_ids: Tindices") 1187 .Input("num_segments: 
Tnumsegments") 1188 .Output("output: T") 1189 .Attr("T: realnumbertype") 1190 .Attr("Tindices: {int32,int64}") 1191 .Attr("Tnumsegments: {int32,int64} = DT_INT32") 1192 .SetShapeFn(UnsortedSegmentReductionShapeFn); 1193 1194 REGISTER_OP("UnsortedSegmentMin") 1195 .Input("data: T") 1196 .Input("segment_ids: Tindices") 1197 .Input("num_segments: Tnumsegments") 1198 .Output("output: T") 1199 .Attr("T: realnumbertype") 1200 .Attr("Tindices: {int32,int64}") 1201 .Attr("Tnumsegments: {int32,int64} = DT_INT32") 1202 .SetShapeFn(UnsortedSegmentReductionShapeFn); 1203 1204 REGISTER_OP("UnsortedSegmentProd") 1205 .Input("data: T") 1206 .Input("segment_ids: Tindices") 1207 .Input("num_segments: Tnumsegments") 1208 .Output("output: T") 1209 .Attr("T: numbertype") 1210 .Attr("Tindices: {int32,int64}") 1211 .Attr("Tnumsegments: {int32,int64} = DT_INT32") 1212 .SetShapeFn(UnsortedSegmentReductionShapeFn); 1213 1214 REGISTER_OP("SparseSegmentSum") 1215 .Input("data: T") 1216 .Input("indices: Tidx") 1217 .Input("segment_ids: int32") 1218 .Output("output: T") 1219 .Attr("T: realnumbertype") 1220 .Attr("Tidx: {int32, int64} = DT_INT32") 1221 .SetShapeFn(SparseSegmentReductionShapeFn); 1222 1223 REGISTER_OP("SparseSegmentSumWithNumSegments") 1224 .Input("data: T") 1225 .Input("indices: Tidx") 1226 .Input("segment_ids: int32") 1227 .Input("num_segments: Tnumsegments") 1228 .Output("output: T") 1229 .Attr("T: realnumbertype") 1230 .Attr("Tidx: {int32, int64} = DT_INT32") 1231 .Attr("Tnumsegments: {int32,int64} = DT_INT32") 1232 .SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn); 1233 1234 REGISTER_OP("SparseSegmentMean") 1235 .Input("data: T") 1236 .Input("indices: Tidx") 1237 .Input("segment_ids: int32") 1238 .Output("output: T") 1239 .Attr("T: {float, double}") 1240 .Attr("Tidx: {int32, int64} = DT_INT32") 1241 .SetShapeFn(SparseSegmentReductionShapeFn); 1242 1243 REGISTER_OP("SparseSegmentMeanWithNumSegments") 1244 .Input("data: T") 1245 .Input("indices: Tidx") 1246 
.Input("segment_ids: int32") 1247 .Input("num_segments: Tnumsegments") 1248 .Output("output: T") 1249 .Attr("T: {float, double}") 1250 .Attr("Tidx: {int32, int64} = DT_INT32") 1251 .Attr("Tnumsegments: {int32,int64} = DT_INT32") 1252 .SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn); 1253 1254 REGISTER_OP("SparseSegmentMeanGrad") 1255 .Input("grad: T") 1256 .Input("indices: Tidx") 1257 .Input("segment_ids: int32") 1258 .Input("output_dim0: int32") 1259 .Output("output: T") 1260 .Attr("T: {float, double}") 1261 .Attr("Tidx: {int32, int64} = DT_INT32") 1262 .SetShapeFn(SparseSegmentReductionGradShapeFn); 1263 1264 REGISTER_OP("SparseSegmentSqrtN") 1265 .Input("data: T") 1266 .Input("indices: Tidx") 1267 .Input("segment_ids: int32") 1268 .Output("output: T") 1269 .Attr("T: {float, double}") 1270 .Attr("Tidx: {int32, int64} = DT_INT32") 1271 .SetShapeFn(SparseSegmentReductionShapeFn); 1272 1273 REGISTER_OP("SparseSegmentSqrtNWithNumSegments") 1274 .Input("data: T") 1275 .Input("indices: Tidx") 1276 .Input("segment_ids: int32") 1277 .Input("num_segments: Tnumsegments") 1278 .Output("output: T") 1279 .Attr("T: {float, double}") 1280 .Attr("Tidx: {int32, int64} = DT_INT32") 1281 .Attr("Tnumsegments: {int32,int64} = DT_INT32") 1282 .SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn); 1283 1284 REGISTER_OP("SparseSegmentSqrtNGrad") 1285 .Input("grad: T") 1286 .Input("indices: Tidx") 1287 .Input("segment_ids: int32") 1288 .Input("output_dim0: int32") 1289 .Output("output: T") 1290 .Attr("T: {float, double}") 1291 .Attr("Tidx: {int32, int64} = DT_INT32") 1292 .SetShapeFn(SparseSegmentReductionGradShapeFn); 1293 1294 REGISTER_OP("All") 1295 .Input("input: bool") 1296 .Input("reduction_indices: Tidx") 1297 .Output("output: bool") 1298 .Attr("keep_dims: bool = false") 1299 .Attr("Tidx: {int32, int64} = DT_INT32") 1300 .SetShapeFn(shape_inference::ReductionShape); 1301 1302 REGISTER_OP("Any") 1303 .Input("input: bool") 1304 .Input("reduction_indices: Tidx") 1305 
.Attr("keep_dims: bool = false") 1306 .Output("output: bool") 1307 .Attr("Tidx: {int32, int64} = DT_INT32") 1308 .SetShapeFn(shape_inference::ReductionShape); 1309 1310 // -------------------------------------------------------------------------- 1311 1312 namespace { 1313 1314 template <typename T> 1315 Status RangeSize(const Tensor* start_t, const Tensor* limit_t, 1316 const Tensor* delta_t, InferenceContext* const c) { 1317 T start = start_t->scalar<T>()(); 1318 T limit = limit_t->scalar<T>()(); 1319 T delta = delta_t->scalar<T>()(); 1320 if (start > limit && delta > 0) { 1321 return errors::InvalidArgument( 1322 "Requires start <= limit when delta > 0: ", start, "/", limit); 1323 } 1324 if (start < limit && delta < 0) { 1325 return errors::InvalidArgument( 1326 "Requires start >= limit when delta < 0: ", start, "/", limit); 1327 } 1328 if (delta == 0) { 1329 return errors::InvalidArgument("Requires delta != 0"); 1330 } 1331 1332 int64 size = 1333 (std::is_integral<T>::value 1334 ? ((std::abs(limit - start) + std::abs(delta) - 1) / std::abs(delta)) 1335 : std::ceil(std::abs((limit - start) / delta))); 1336 c->set_output(0, c->Vector(size)); 1337 return Status::OK(); 1338 } 1339 1340 } // namespace 1341 1342 REGISTER_OP("Range") 1343 .Input("start: Tidx") 1344 .Input("limit: Tidx") 1345 .Input("delta: Tidx") 1346 .Output("output: Tidx") 1347 .Attr("Tidx: {bfloat16, float, double, int32, int64} = DT_INT32") 1348 .SetShapeFn([](InferenceContext* c) { 1349 ShapeHandle unused; 1350 TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(0), 0, &unused), 1351 " for 'start'"); 1352 TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(1), 0, &unused), 1353 " for 'limit'"); 1354 TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(2), 0, &unused), 1355 " for 'delta'"); 1356 const Tensor* start_t = c->input_tensor(0); 1357 const Tensor* limit_t = c->input_tensor(1); 1358 const Tensor* delta_t = c->input_tensor(2); 1359 DataType dtype; 1360 
TF_RETURN_IF_ERROR(c->GetAttr("Tidx", &dtype)); 1361 if (start_t == nullptr || limit_t == nullptr || delta_t == nullptr) { 1362 c->set_output(0, c->Vector(InferenceContext::kUnknownDim)); 1363 return Status::OK(); 1364 } 1365 if (dtype == DT_INT32) { 1366 return RangeSize<int32>(start_t, limit_t, delta_t, c); 1367 } else if (dtype == DT_INT64) { 1368 return RangeSize<int64>(start_t, limit_t, delta_t, c); 1369 } else if (dtype == DT_FLOAT) { 1370 return RangeSize<float>(start_t, limit_t, delta_t, c); 1371 } else { 1372 return RangeSize<double>(start_t, limit_t, delta_t, c); 1373 } 1374 return Status::OK(); 1375 }); 1376 1377 REGISTER_OP("LinSpace") 1378 .Input("start: T") 1379 .Input("stop: T") 1380 .Input("num: Tidx") 1381 .Output("output: T") 1382 .Attr("T: {bfloat16, float, double}") 1383 .Attr("Tidx: {int32, int64} = DT_INT32") 1384 .SetShapeFn([](InferenceContext* c) { 1385 ShapeHandle unused; 1386 TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(0), 0, &unused), 1387 " for 'start'"); 1388 TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(1), 0, &unused), 1389 " for 'stop'"); 1390 TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(2), 0, &unused), 1391 " for 'num'"); 1392 const Tensor* num_t = c->input_tensor(2); 1393 if (num_t == nullptr) { 1394 c->set_output(0, c->Vector(InferenceContext::kUnknownDim)); 1395 return Status::OK(); 1396 } 1397 1398 int64 num; 1399 if (num_t->dtype() == DT_INT32) { 1400 num = num_t->scalar<int32>()(); 1401 } else { 1402 num = num_t->scalar<int64>()(); 1403 } 1404 if (num <= 0) return errors::InvalidArgument("Requires num > 0: ", num); 1405 c->set_output(0, c->Vector(num)); 1406 return Status::OK(); 1407 }); 1408 1409 REGISTER_OP("Complex") 1410 .Input("real: T") 1411 .Input("imag: T") 1412 .Output("out: Tout") 1413 .Attr("T: {float, double} = DT_FLOAT") 1414 .Attr("Tout: {complex64, complex128} = DT_COMPLEX64") 1415 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); 1416 1417 REGISTER_OP("Real") 1418 .Input("input: 
T") 1419 .Output("output: Tout") 1420 .Attr("T: {complex64, complex128} = DT_COMPLEX64") 1421 .Attr("Tout: {float, double} = DT_FLOAT") 1422 .SetShapeFn(shape_inference::UnchangedShape); 1423 1424 REGISTER_OP("Imag") 1425 .Input("input: T") 1426 .Output("output: Tout") 1427 .Attr("T: {complex64, complex128} = DT_COMPLEX64") 1428 .Attr("Tout: {float, double} = DT_FLOAT") 1429 .SetShapeFn(shape_inference::UnchangedShape); 1430 1431 REGISTER_OP("Angle") 1432 .Input("input: T") 1433 .Output("output: Tout") 1434 .Attr("T: {complex64, complex128} = DT_COMPLEX64") 1435 .Attr("Tout: {float, double} = DT_FLOAT") 1436 .SetShapeFn(shape_inference::UnchangedShape); 1437 1438 REGISTER_OP("Conj") 1439 .Input("input: T") 1440 .Output("output: T") 1441 .Attr("T: {complex64, complex128, variant} = DT_COMPLEX64") 1442 .SetShapeFn([](InferenceContext* c) { 1443 c->set_output(0, c->input(0)); 1444 auto* handle_data = c->input_handle_shapes_and_types(0); 1445 if (handle_data != nullptr) { 1446 c->set_output_handle_shapes_and_types(0, *handle_data); 1447 } 1448 return Status::OK(); 1449 }); 1450 1451 // -------------------------------------------------------------------------- 1452 1453 REGISTER_OP("Cross") 1454 .Input("a: T") 1455 .Input("b: T") 1456 .Output("product: T") 1457 .Attr("T: realnumbertype") 1458 .SetShapeFn([](InferenceContext* c) { 1459 ShapeHandle a_shape; 1460 ShapeHandle b_shape; 1461 // * Input rank >= 1. 1462 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &a_shape)); 1463 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &b_shape)); 1464 1465 // * Both inputs have the same shape. 1466 TF_RETURN_IF_ERROR(c->Merge(a_shape, b_shape, &a_shape)); 1467 1468 // * input_shape[-1] == 3. 
1469 if (c->RankKnown(a_shape)) { 1470 int rank = c->Rank(a_shape); 1471 auto dim = c->Dim(a_shape, rank - 1); 1472 TF_RETURN_IF_ERROR(c->WithValue(dim, 3, &dim)); 1473 } 1474 c->set_output(0, a_shape); 1475 return Status::OK(); 1476 }); 1477 1478 // -------------------------------------------------------------------------- 1479 1480 REGISTER_OP("HistogramFixedWidth") 1481 .Input("values: T") 1482 .Input("value_range: T") 1483 .Input("nbins: int32") 1484 .Output("out: dtype") 1485 .Attr("T: {int32, int64, float32, float64}") 1486 .Attr("dtype: {int32, int64} = DT_INT32") 1487 .SetShapeFn([](InferenceContext* c) { 1488 // value_range should be a vector. 1489 ShapeHandle value_range_shape; 1490 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &value_range_shape)); 1491 // value_range should have two elements. 1492 DimensionHandle unused; 1493 TF_RETURN_IF_ERROR( 1494 c->WithValue(c->Dim(value_range_shape, 0), 2, &unused)); 1495 // nbins should be a scalar. 1496 ShapeHandle nbins_shape; 1497 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &nbins_shape)); 1498 1499 // If nbins is available, set the shape from nbins. 1500 const Tensor* nbins_input = c->input_tensor(2); 1501 if (nbins_input != nullptr) { 1502 int64 nbins; 1503 TF_RETURN_IF_ERROR(c->GetScalarFromTensor(nbins_input, &nbins)); 1504 // nbins has to be positive. 1505 if (nbins <= 0) { 1506 return errors::InvalidArgument("Requires nbins > 0: ", nbins); 1507 } 1508 c->set_output(0, c->Vector(nbins)); 1509 } else { 1510 c->set_output(0, c->UnknownShapeOfRank(1)); 1511 } 1512 return Status::OK(); 1513 }); 1514 1515 REGISTER_OP("Bincount") 1516 .Input("arr: int32") 1517 .Input("size: int32") 1518 .Input("weights: T") 1519 .Attr("T: {int32, int64, float32, float64}") 1520 .Output("bins: T") 1521 .SetShapeFn([](InferenceContext* c) { 1522 ShapeHandle unused; 1523 // The input `size` must be a scalar. 
1524 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 1525 1526 const Tensor* size_tensor = c->input_tensor(1); 1527 if (size_tensor == nullptr) { 1528 // Return unknown shape if size is not known. 1529 c->set_output(0, c->UnknownShapeOfRank(1)); 1530 return Status::OK(); 1531 } 1532 1533 // Return `[size]` shape if size is known. 1534 int32 size_val = size_tensor->scalar<int32>()(); 1535 if (size_val < 0) { 1536 return errors::InvalidArgument("size (", size_val, 1537 ") must be non-negative"); 1538 } 1539 c->set_output(0, c->MakeShape({size_val})); 1540 return Status::OK(); 1541 }); 1542 1543 REGISTER_OP("Cumsum") 1544 .Input("x: T") 1545 .Input("axis: Tidx") 1546 .Attr("exclusive: bool = false") 1547 .Attr("reverse: bool = false") 1548 .Output("out: T") 1549 .Attr("T: numbertype") 1550 .Attr("Tidx: {int32, int64} = DT_INT32") 1551 .SetShapeFn(shape_inference::UnchangedShape); 1552 1553 REGISTER_OP("Cumprod") 1554 .Input("x: T") 1555 .Input("axis: Tidx") 1556 .Attr("exclusive: bool = false") 1557 .Attr("reverse: bool = false") 1558 .Output("out: T") 1559 .Attr("T: numbertype") 1560 .Attr("Tidx: {int32, int64} = DT_INT32") 1561 .SetShapeFn(shape_inference::UnchangedShape); 1562 1563 REGISTER_OP("QuantizedMatMul") 1564 .Input("a: T1") 1565 .Input("b: T2") 1566 .Input("min_a: float") 1567 .Input("max_a: float") 1568 .Input("min_b: float") 1569 .Input("max_b: float") 1570 .Output("out: Toutput") 1571 .Output("min_out: float") 1572 .Output("max_out: float") 1573 .Attr("T1: quantizedtype") 1574 .Attr("T2: quantizedtype") 1575 .Attr("Toutput: quantizedtype = DT_QINT32") 1576 .Attr("transpose_a: bool = false") 1577 .Attr("transpose_b: bool = false") 1578 .Attr("Tactivation: quantizedtype = DT_QUINT8") 1579 .SetShapeFn([](InferenceContext* c) { 1580 TF_RETURN_IF_ERROR(shape_inference::MatMulShape(c)); 1581 ShapeHandle unused; 1582 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 1583 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); 1584 
TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); 1585 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); 1586 1587 c->set_output(1, c->Scalar()); 1588 c->set_output(2, c->Scalar()); 1589 return Status::OK(); 1590 }); 1591 1592 REGISTER_OP("QuantizedMul") 1593 .Input("x: T1") 1594 .Input("y: T2") 1595 .Input("min_x: float") 1596 .Input("max_x: float") 1597 .Input("min_y: float") 1598 .Input("max_y: float") 1599 .Output("z: Toutput") 1600 .Output("min_z: float") 1601 .Output("max_z: float") 1602 .Attr("T1: quantizedtype") 1603 .Attr("T2: quantizedtype") 1604 .Attr("Toutput: quantizedtype = DT_QINT32") 1605 .SetIsCommutative() 1606 .SetShapeFn([](InferenceContext* c) { 1607 TF_RETURN_IF_ERROR(shape_inference::BroadcastBinaryOpShapeFn(c)); 1608 c->set_output(1, c->Scalar()); 1609 c->set_output(2, c->Scalar()); 1610 return Status::OK(); 1611 }); 1612 1613 REGISTER_OP("QuantizedAdd") 1614 .Input("x: T1") 1615 .Input("y: T2") 1616 .Input("min_x: float") 1617 .Input("max_x: float") 1618 .Input("min_y: float") 1619 .Input("max_y: float") 1620 .Output("z: Toutput") 1621 .Output("min_z: float") 1622 .Output("max_z: float") 1623 .Attr("T1: quantizedtype") 1624 .Attr("T2: quantizedtype") 1625 .Attr("Toutput: quantizedtype = DT_QINT32") 1626 .SetIsCommutative() 1627 .SetShapeFn([](InferenceContext* c) { 1628 TF_RETURN_IF_ERROR(shape_inference::BroadcastBinaryOpShapeFn(c)); 1629 // min_x, max_x, min_y, max_y should be scalar. 
1630 ShapeHandle unused; 1631 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 1632 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); 1633 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); 1634 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); 1635 1636 c->set_output(1, c->Scalar()); 1637 c->set_output(2, c->Scalar()); 1638 return Status::OK(); 1639 }); 1640 1641 REGISTER_OP("QuantizeDownAndShrinkRange") 1642 .Input("input: Tinput") 1643 .Input("input_min: float") 1644 .Input("input_max: float") 1645 .Output("output: out_type") 1646 .Output("output_min: float") 1647 .Output("output_max: float") 1648 .Attr("Tinput: quantizedtype") 1649 .Attr("out_type: quantizedtype") 1650 .SetShapeFn([](InferenceContext* c) { 1651 TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c)); 1652 ShapeHandle unused; 1653 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 1654 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 1655 c->set_output(1, c->Scalar()); 1656 c->set_output(2, c->Scalar()); 1657 return Status::OK(); 1658 }); 1659 1660 REGISTER_OP("Requantize") 1661 .Input("input: Tinput") 1662 .Input("input_min: float") 1663 .Input("input_max: float") 1664 .Input("requested_output_min: float") 1665 .Input("requested_output_max: float") 1666 .Output("output: out_type") 1667 .Output("output_min: float") 1668 .Output("output_max: float") 1669 .Attr("Tinput: quantizedtype") 1670 .Attr("out_type: quantizedtype") 1671 .SetShapeFn([](InferenceContext* c) { 1672 TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c)); 1673 ShapeHandle unused; 1674 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 1675 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 1676 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); 1677 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); 1678 c->set_output(1, c->Scalar()); 1679 c->set_output(2, c->Scalar()); 1680 return Status::OK(); 1681 }); 1682 1683 REGISTER_OP("CompareAndBitpack") 1684 
.Input("input: T") 1685 .Input("threshold: T") 1686 .Output("output: uint8") 1687 .Attr("T: {bool, float16, float32, float64, int8, int16, int32, int64}") 1688 .SetShapeFn([](InferenceContext* c) { 1689 ShapeHandle input; 1690 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input)); 1691 ShapeHandle unused; 1692 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 1693 ShapeHandle output = input; 1694 if (c->RankKnown(input)) { 1695 int rank = c->Rank(input); 1696 auto inner_dim = c->Dim(input, rank - 1); 1697 DimensionHandle inferred_dim; 1698 TF_RETURN_IF_ERROR(c->Divide(inner_dim, 8, 1699 /* evenly_divisible */ true, 1700 &inferred_dim)); 1701 TF_RETURN_IF_ERROR( 1702 c->ReplaceDim(output, rank - 1, inferred_dim, &output)); 1703 } 1704 c->set_output(0, output); 1705 1706 return Status::OK(); 1707 }); 1708 1709 REGISTER_OP("RequantizationRange") 1710 .Input("input: Tinput") 1711 .Input("input_min: float") 1712 .Input("input_max: float") 1713 .Output("output_min: float") 1714 .Output("output_max: float") 1715 .Attr("Tinput: quantizedtype") 1716 .SetShapeFn([](InferenceContext* c) { 1717 ShapeHandle unused; 1718 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 1719 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 1720 c->set_output(0, c->Scalar()); 1721 c->set_output(1, c->Scalar()); 1722 return Status::OK(); 1723 }); 1724 1725 // -------------------------------------------------------------------------- 1726 1727 REGISTER_OP("Bucketize") 1728 .Input("input: T") 1729 .Output("output: int32") 1730 .Attr("T: {int32, int64, float, double}") 1731 .Attr("boundaries: list(float)") 1732 .SetShapeFn(shape_inference::UnchangedShape); 1733 1734 REGISTER_OP("ClipByValue") 1735 .Input("t: T") 1736 .Input("clip_value_min: T") 1737 .Input("clip_value_max: T") 1738 .Output("output: T") 1739 .Attr("T: numbertype") 1740 .SetShapeFn(shape_inference::UnchangedShape); 1741 1742 #ifdef INTEL_MKL 1743 REGISTER_OP("_MklAddN") 1744 .Input("inputs: N * T") 
1745 .Input("mkl_input: N * uint8") 1746 .Output("sum: T") 1747 .Output("mkl_sum: uint8") 1748 .Attr("N: int >= 1") 1749 .Attr("T: numbertype") 1750 .SetIsCommutative() 1751 .SetIsAggregate() 1752 .SetShapeFn([](InferenceContext* c) { 1753 ShapeHandle cur = c->input(c->num_inputs() - 1); 1754 for (int i = c->num_inputs() - 2; i >= 0; --i) { 1755 TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur), 1756 "From merging shape ", i, 1757 " with other shapes."); 1758 } 1759 c->set_output(0, cur); 1760 return Status::OK(); 1761 }) 1762 .Doc(R"doc( 1763 Add two input tensors element wise using mkl kernel sum. 1764 inputs: Must all be the same size and shape. 1765 )doc"); 1766 1767 #endif // INTEL_MKL 1768 1769 REGISTER_OP("RequantizePerChannel") 1770 .Input("input: T") 1771 .Input("input_min: float") 1772 .Input("input_max: float") 1773 .Input("requested_output_min: float") 1774 .Input("requested_output_max: float") 1775 .Output("output: out_type") 1776 .Output("output_min: float") 1777 .Output("output_max: float") 1778 .Attr("T: quantizedtype = DT_QINT32") 1779 .Attr("out_type: quantizedtype = DT_QUINT8") 1780 .SetShapeFn([](InferenceContext* c) { 1781 TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c)); 1782 ShapeHandle unused; 1783 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); 1784 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); 1785 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); 1786 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); 1787 c->set_output(1, c->Scalar()); 1788 c->set_output(2, c->Scalar()); 1789 return Status::OK(); 1790 }); 1791 REGISTER_OP("RequantizationRangePerChannel") 1792 .Input("input: T") 1793 .Input("input_min: float") 1794 .Input("input_max: float") 1795 .Output("output_min: float") 1796 .Output("output_max: float") 1797 .Attr("T: quantizedtype = DT_QINT32") 1798 .Attr("clip_value_max: float") 1799 .SetShapeFn([](InferenceContext* c) { 1800 ShapeHandle unused; 1801 
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); 1802 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); 1803 c->set_output(0, c->Scalar()); 1804 c->set_output(1, c->Scalar()); 1805 return Status::OK(); 1806 }); 1807 1808 REGISTER_OP("NextAfter") 1809 .Attr("T: {float64, float32} = DT_FLOAT") 1810 .Input("x1: T") 1811 .Input("x2: T") 1812 .Output("output: T") 1813 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); 1814 1815 } // namespace tensorflow 1816