1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/framework/common_shape_fns.h" 17 #include "tensorflow/core/framework/numeric_op.h" 18 #include "tensorflow/core/framework/op.h" 19 #include "tensorflow/core/framework/shape_inference.h" 20 21 namespace tensorflow { 22 23 using shape_inference::DimensionHandle; 24 using shape_inference::InferenceContext; 25 using shape_inference::ShapeHandle; 26 27 REGISTER_OP("AddN") 28 .Input("inputs: N * T") 29 .Output("sum: T") 30 .Attr("N: int >= 1") 31 .Attr("T: {numbertype, variant}") 32 .SetIsCommutative() 33 .SetIsAggregate() 34 .SetShapeFn([](InferenceContext* c) { 35 ShapeHandle cur = c->input(c->num_inputs() - 1); 36 for (int i = c->num_inputs() - 2; i >= 0; --i) { 37 TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur), 38 "From merging shape ", i, 39 " with other shapes."); 40 } 41 c->set_output(0, cur); 42 return Status::OK(); 43 }); 44 45 // -------------------------------------------------------------------------- 46 47 // Note that the following operator is just a placeholder and has no 48 // associated kernel. The code in accumulate_n_optimizer.cc replaces 49 // this placeholder with a graph of operators that do have kernels. 50 // The Python code that generates instances of this op is currently in 51 // contrib/framework/python/ops/accumulate_n_v2.py 52 REGISTER_OP("AccumulateNV2") 53 .Input("inputs: N * T") 54 .Output("sum: T") 55 .Attr("N: int >= 1") 56 .Attr("T: numbertype") 57 .Attr("shape: shape") 58 .SetIsCommutative() 59 .SetIsAggregate() 60 .SetShapeFn(shape_inference::ExplicitShape); 61 62 // -------------------------------------------------------------------------- 63 64 REGISTER_OP("BatchMatMul") 65 .Input("x: T") 66 .Input("y: T") 67 .Output("output: T") 68 .Attr("T: {half, bfloat16, float, double, int32, complex64, complex128}") 69 .Attr("adj_x: bool = false") 70 .Attr("adj_y: bool = false") 71 .SetShapeFn([](InferenceContext* c) { 72 ShapeHandle a_shape; 73 ShapeHandle b_shape; 74 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &a_shape)); 75 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &b_shape)); 76 77 // Determine output rows and cols. 78 bool adj_x; 79 bool adj_y; 80 TF_RETURN_IF_ERROR(c->GetAttr("adj_x", &adj_x)); 81 TF_RETURN_IF_ERROR(c->GetAttr("adj_y", &adj_y)); 82 DimensionHandle output_rows = c->Dim(a_shape, adj_x ? -1 : -2); 83 DimensionHandle output_cols = c->Dim(b_shape, adj_y ? -2 : -1); 84 85 // Batch dims match between inputs. 86 ShapeHandle a_batch_dims; 87 ShapeHandle b_batch_dims; 88 ShapeHandle batch_dims; 89 TF_RETURN_IF_ERROR(c->Subshape(a_shape, 0, -2, &a_batch_dims)); 90 TF_RETURN_IF_ERROR(c->Subshape(b_shape, 0, -2, &b_batch_dims)); 91 TF_RETURN_IF_ERROR(c->Merge(a_batch_dims, b_batch_dims, &batch_dims)); 92 93 // Assert inner dims match. 94 DimensionHandle unused; 95 TF_RETURN_IF_ERROR(c->Merge(c->Dim(a_shape, adj_x ? -2 : -1), 96 c->Dim(b_shape, adj_y ? -1 : -2), &unused)); 97 98 ShapeHandle out; 99 TF_RETURN_IF_ERROR(c->Concatenate( 100 batch_dims, c->Matrix(output_rows, output_cols), &out)); 101 c->set_output(0, out); 102 return Status::OK(); 103 }); 104 105 // -------------------------------------------------------------------------- 106 // Casting Ops 107 // 108 // NOTE: Only a smaller number of types are supported by 109 // Cast. The exact casting rule is TBD. The current 110 // implementation uses C++ static cast rules for numeric 111 // types, which may be changed in the future. 112 REGISTER_OP("Cast") 113 .Input("x: SrcT") 114 .Output("y: DstT") 115 .Attr("SrcT: type") 116 .Attr("DstT: type") 117 .SetShapeFn(shape_inference::UnchangedShape); 118 119 REGISTER_OP("_HostCast") 120 .Input("x: SrcT") 121 .Output("y: DstT") 122 .Attr("SrcT: type") 123 .Attr("DstT: type") 124 .SetShapeFn(shape_inference::UnchangedShape) 125 .Doc(R"doc( 126 Cast x of type SrcT to y of DstT. 127 128 _HostCast requires its input and produces its output in host memory. 129 )doc"); 130 131 // -------------------------------------------------------------------------- 132 133 REGISTER_OP("Abs") 134 .Input("x: T") 135 .Output("y: T") 136 .Attr("T: {half, bfloat16, float, double, int32, int64}") 137 .SetShapeFn(shape_inference::UnchangedShape); 138 139 REGISTER_OP("ComplexAbs") 140 .Input("x: T") 141 .Output("y: Tout") 142 .Attr("T: {complex64, complex128} = DT_COMPLEX64") 143 .Attr("Tout: {float, double} = DT_FLOAT") 144 .SetShapeFn(shape_inference::UnchangedShape); 145 146 // Declares cwise unary operations signature: 't -> 't 147 #define UNARY() \ 148 Input("x: T") \ 149 .Output("y: T") \ 150 .Attr( \ 151 "T: {half, bfloat16, float, double, int32, int64, complex64, " \ 152 "complex128}") \ 153 .SetShapeFn(shape_inference::UnchangedShape) 154 155 #define UNARY_REAL() \ 156 Input("x: T") \ 157 .Output("y: T") \ 158 .Attr("T: {half, bfloat16, float, double}") \ 159 .SetShapeFn(shape_inference::UnchangedShape) 160 161 #define UNARY_COMPLEX() \ 162 Input("x: T") \ 163 .Output("y: T") \ 164 .Attr("T: {half, bfloat16, float, double, complex64, complex128}") \ 165 .SetShapeFn(shape_inference::UnchangedShape) 166 167 #define UNARY_GRADIENT_COMPLEX() \ 168 Input("y: T") \ 169 .Input("dy: T") \ 170 .Output("z: T") \ 171 .Attr("T: {half, bfloat16, float, double, complex64, complex128}") \ 172 .SetShapeFn(shape_inference::UnchangedShape) 173 174 REGISTER_OP("Neg").UNARY(); 175 176 REGISTER_OP("Inv").UNARY(); 177 178 REGISTER_OP("InvGrad").UNARY_GRADIENT_COMPLEX(); 179 180 REGISTER_OP("Reciprocal").UNARY(); 181 182 REGISTER_OP("ReciprocalGrad").UNARY_GRADIENT_COMPLEX(); 183 184 REGISTER_OP("Square").UNARY(); 185 186 REGISTER_OP("Sqrt").UNARY_COMPLEX(); 187 188 REGISTER_OP("SqrtGrad").UNARY_GRADIENT_COMPLEX(); 189 190 REGISTER_OP("Rsqrt").UNARY_COMPLEX(); 191 192 REGISTER_OP("Round").UNARY(); 193 194 REGISTER_OP("RsqrtGrad").UNARY_GRADIENT_COMPLEX(); 195 196 REGISTER_OP("Exp").UNARY_COMPLEX(); 197 198 REGISTER_OP("Expm1").UNARY_COMPLEX(); 199 200 REGISTER_OP("Log").UNARY_COMPLEX(); 201 202 REGISTER_OP("Log1p").UNARY_COMPLEX(); 203 204 REGISTER_OP("Sinh").UNARY_COMPLEX(); 205 206 REGISTER_OP("Cosh").UNARY_COMPLEX(); 207 208 REGISTER_OP("Tanh").UNARY_COMPLEX(); 209 210 REGISTER_OP("Asinh").UNARY_COMPLEX(); 211 212 REGISTER_OP("Acosh").UNARY_COMPLEX(); 213 214 REGISTER_OP("Atanh").UNARY_COMPLEX(); 215 216 REGISTER_OP("TanhGrad").UNARY_GRADIENT_COMPLEX(); 217 218 REGISTER_OP("Lgamma").UNARY_REAL(); 219 220 REGISTER_OP("Digamma").UNARY_REAL(); 221 222 REGISTER_OP("Erf").UNARY_REAL(); 223 224 REGISTER_OP("Erfc").UNARY_REAL(); 225 226 REGISTER_OP("Sigmoid").UNARY_COMPLEX(); 227 228 REGISTER_OP("SigmoidGrad").UNARY_GRADIENT_COMPLEX(); 229 230 REGISTER_OP("Sin").UNARY_COMPLEX(); 231 232 REGISTER_OP("Cos").UNARY_COMPLEX(); 233 234 REGISTER_OP("Tan").UNARY(); 235 236 REGISTER_OP("Asin").UNARY(); 237 238 REGISTER_OP("Acos").UNARY(); 239 240 REGISTER_OP("Atan").UNARY(); 241 242 #undef UNARY 243 #undef UNARY_REAL 244 #undef UNARY_COMPLEX 245 246 REGISTER_OP("IsNan") 247 .Input("x: T") 248 .Output("y: bool") 249 .Attr("T: {half, bfloat16, float, double}") 250 .SetShapeFn(shape_inference::UnchangedShape); 251 252 REGISTER_OP("IsInf") 253 .Input("x: T") 254 .Output("y: bool") 255 .Attr("T: {half, bfloat16, float, double}") 256 .SetShapeFn(shape_inference::UnchangedShape); 257 258 REGISTER_OP("IsFinite") 259 .Input("x: T") 260 .Output("y: bool") 261 .Attr("T: {half, bfloat16, float, double}") 262 .SetShapeFn(shape_inference::UnchangedShape); 263 264 REGISTER_OP("Sign") 265 .Input("x: T") 266 .Output("y: T") 267 .Attr( 268 "T: {half, bfloat16, float, double, int32, int64, complex64, " 269 "complex128}") 270 .SetShapeFn(shape_inference::UnchangedShape); 271 272 REGISTER_OP("Floor") 273 .Input("x: T") 274 .Output("y: T") 275 .Attr("T: {half, bfloat16, float, double}") 276 .SetShapeFn(shape_inference::UnchangedShape); 277 278 REGISTER_OP("Ceil") 279 .Input("x: T") 280 .Output("y: T") 281 .Attr("T: {half, bfloat16, float, double}") 282 .SetShapeFn(shape_inference::UnchangedShape); 283 284 REGISTER_OP("Rint") 285 .Input("x: T") 286 .Output("y: T") 287 .Attr("T: {bfloat16, float, double}") 288 .SetShapeFn(shape_inference::UnchangedShape); 289 290 // Declares cwise binary operations signature: 't, 't -> 't. 291 292 #define BINARY_MORE() \ 293 Input("x: T").Input("y: T").Output("z: T").Attr( \ 294 "T: {half, bfloat16, float, double, uint8, int8, uint16, int16, int32, " \ 295 "int64, complex64, complex128}") 296 297 #define BINARY_FEWER() \ 298 Input("x: T").Input("y: T").Output("z: T").Attr( \ 299 "T: {half, bfloat16, float, double, int32, int64, complex64, " \ 300 "complex128}") 301 302 REGISTER_OP("Add") 303 .Input("x: T") 304 .Input("y: T") 305 .Output("z: T") 306 .Attr( 307 "T: {half, bfloat16, float, double, uint8, int8, int16, int32, int64, " 308 "complex64, complex128, string}") 309 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); 310 311 // TODO(rmlarsen): Add a Python wrapper that swiches non-string instances to 312 // use AddV2 (b/68646025). 313 REGISTER_OP("AddV2") 314 .Input("x: T") 315 .Input("y: T") 316 .Output("z: T") 317 .Attr( 318 "T: {half, bfloat16, float, double, uint8, int8, int16, int32, int64, " 319 "complex64, complex128}") 320 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn) 321 .SetIsAggregate() 322 .SetIsCommutative(); 323 324 REGISTER_OP("_MklAdd") 325 .Input("x: T") 326 .Input("y: T") 327 .Input("mkl_x: uint8") 328 .Input("mkl_y: uint8") 329 .Output("z: T") 330 .Output("mkl_z: uint8") 331 .Attr( 332 "T: {half, float, double, uint8, int8, int16, int32, int64, complex64, " 333 "complex128, string}") 334 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn) 335 .Doc(R"doc( 336 Returns x + y element-wise. 337 338 *NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting 339 [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) 340 )doc"); 341 342 REGISTER_OP("Sub").BINARY_MORE().SetShapeFn( 343 shape_inference::BroadcastBinaryOpShapeFn); 344 345 REGISTER_OP("_MklSub") 346 .BINARY_FEWER() 347 .Input("mkl_x: uint8") 348 .Input("mkl_y: uint8") 349 .Output("mkl_z: uint8") 350 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn) 351 .Doc(R"doc( 352 Returns x - y element-wise. 353 354 *NOTE*: `Sub` supports broadcasting. More about broadcasting 355 [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) 356 )doc"); 357 358 REGISTER_OP("Mul").BINARY_MORE().SetIsCommutative().SetShapeFn( 359 shape_inference::BroadcastBinaryOpShapeFn); 360 361 REGISTER_OP("_MklMul") 362 .BINARY_MORE() 363 .Input("mkl_x: uint8") 364 .Input("mkl_y: uint8") 365 .Output("mkl_z: uint8") 366 .SetIsCommutative() 367 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn) 368 .Doc(R"doc( 369 Returns x * y element-wise. 370 371 *NOTE*: `Mul` supports broadcasting. More about broadcasting 372 [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) 373 )doc"); 374 375 REGISTER_OP("Div").BINARY_MORE().SetShapeFn( 376 shape_inference::BroadcastBinaryOpShapeFn); 377 378 REGISTER_OP("FloorDiv") 379 .BINARY_MORE() 380 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); 381 382 REGISTER_OP("TruncateDiv") 383 .BINARY_MORE() 384 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); 385 386 REGISTER_OP("RealDiv").BINARY_MORE().SetShapeFn( 387 shape_inference::BroadcastBinaryOpShapeFn); 388 389 REGISTER_OP("SquaredDifference") 390 .BINARY_FEWER() 391 .SetIsCommutative() 392 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); 393 394 REGISTER_OP("_MklSquaredDifference") 395 .BINARY_FEWER() 396 .Input("mkl_x: uint8") 397 .Input("mkl_y: uint8") 398 .Output("mkl_z: uint8") 399 .SetIsCommutative() 400 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn) 401 .Doc(R"doc( 402 Returns (x - y)(x - y) element-wise. 403 404 *NOTE*: `SquaredDifference` supports broadcasting. More about broadcasting 405 [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) 406 )doc"); 407 408 #undef BINARY_FEWER 409 #undef BINARY_MORE 410 411 REGISTER_OP("Maximum") 412 .Input("x: T") 413 .Input("y: T") 414 .Output("z: T") 415 .Attr("T: {half, bfloat16, float, double, int32, int64}") 416 .SetIsCommutative() 417 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); 418 419 REGISTER_OP("_MklMaximum") 420 .Input("x: T") 421 .Input("y: T") 422 .Input("mkl_x: uint8") 423 .Input("mkl_y: uint8") 424 .Output("z: T") 425 .Output("mkl_z: uint8") 426 .Attr("T: {half, float, double, int32, int64}") 427 .SetIsCommutative() 428 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn) 429 .Doc(R"doc( 430 Returns the max of x and y (i.e. x > y ? x : y) element-wise. 431 432 *NOTE*: `Maximum` supports broadcasting. More about broadcasting 433 [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) 434 )doc"); 435 436 REGISTER_OP("Minimum") 437 .Input("x: T") 438 .Input("y: T") 439 .Output("z: T") 440 .Attr("T: {half, bfloat16, float, double, int32, int64}") 441 .SetIsCommutative() 442 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); 443 444 REGISTER_OP("Mod") 445 .Input("x: T") 446 .Input("y: T") 447 .Output("z: T") 448 .Attr("T: {int32, int64, bfloat16, float, double}") 449 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); 450 451 REGISTER_OP("FloorMod") 452 .Input("x: T") 453 .Input("y: T") 454 .Output("z: T") 455 .Attr("T: {int32, int64, bfloat16, float, double}") 456 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); 457 458 REGISTER_OP("TruncateMod") 459 .Input("x: T") 460 .Input("y: T") 461 .Output("z: T") 462 .Attr("T: {int32, int64, bfloat16, float, double}") 463 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); 464 465 REGISTER_OP("Pow") 466 .Input("x: T") 467 .Input("y: T") 468 .Output("z: T") 469 .Attr( 470 "T: {half, bfloat16, float, double, int32, int64, complex64, " 471 "complex128}") 472 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); 473 474 REGISTER_OP("Igammac") 475 .Input("a: T") 476 .Input("x: T") 477 .Output("z: T") 478 .Attr("T: {float, double}") 479 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); 480 481 REGISTER_OP("Igamma") 482 .Input("a: T") 483 .Input("x: T") 484 .Output("z: T") 485 .Attr("T: {float, double}") 486 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); 487 488 REGISTER_OP("Zeta") 489 .Input("x: T") 490 .Input("q: T") 491 .Output("z: T") 492 .Attr("T: {float, double}") 493 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); 494 495 REGISTER_OP("Polygamma") 496 .Input("a: T") 497 .Input("x: T") 498 .Output("z: T") 499 .Attr("T: {float, double}") 500 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); 501 502 REGISTER_OP("Atan2") 503 .Input("y: T") 504 .Input("x: T") 505 .Output("z: T") 506 .Attr("T: {bfloat16, float, double}") 507 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); 508 509 REGISTER_OP("Betainc") 510 .Input("a: T") 511 .Input("b: T") 512 .Input("x: T") 513 .Output("z: T") 514 .Attr("T: {float, double}") 515 .SetShapeFn([](InferenceContext* c) { 516 const int num_inputs = 3; 517 ShapeHandle output = c->UnknownShape(); 518 int num_scalars = 0; 519 ShapeHandle some_non_scalar; 520 for (int i = 0; i < num_inputs; ++i) { 521 ShapeHandle in = c->input(i); 522 if (!c->RankKnown(in)) { 523 some_non_scalar = in; 524 // An input with unknown rank could be either a scalar (to be 525 // broadcast) or some other shape. 526 } else if (c->Rank(in) == 0) { 527 // Input is a scalar, it will be broadcast to the output shape. 528 ++num_scalars; 529 } else { 530 TF_RETURN_IF_ERROR(c->Merge(output, in, &output)); 531 some_non_scalar = output; 532 } 533 } 534 535 if (num_scalars == num_inputs - 1) { 536 // If all but one input is known to be a scalar, then output is the 537 // remaining input. 538 output = some_non_scalar; 539 } else if (num_scalars == num_inputs) { 540 // If all are scalars, output is scalar; pick the first one arbitrarily. 541 output = c->input(0); 542 } 543 544 c->set_output(0, output); 545 return Status::OK(); 546 }); 547 548 // -------------------------------------------------------------------------- 549 550 // Declares cwise binary comparison operations signature: 't, 't -> bool, 551 // where 't has a natural total order. 552 #define COMPARISON() \ 553 Input("x: T") \ 554 .Input("y: T") \ 555 .Output("z: bool") \ 556 .Attr("T: realnumbertype") \ 557 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn) 558 559 REGISTER_OP("Less").COMPARISON(); 560 561 REGISTER_OP("LessEqual").COMPARISON(); 562 563 REGISTER_OP("Greater").COMPARISON(); 564 565 REGISTER_OP("GreaterEqual").COMPARISON(); 566 567 #undef COMPARISON 568 569 // -------------------------------------------------------------------------- 570 571 #define EQUALITY_COMPARISON() \ 572 Input("x: T") \ 573 .Input("y: T") \ 574 .Output("z: bool") \ 575 .SetIsCommutative() \ 576 .Attr( \ 577 "T: {half, bfloat16, float, double, uint8, int8, int16, int32, " \ 578 "int64, complex64, quint8, qint8, qint32, string, bool, " \ 579 "complex128}") \ 580 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn) 581 582 REGISTER_OP("Equal").EQUALITY_COMPARISON(); 583 584 REGISTER_OP("NotEqual").EQUALITY_COMPARISON(); 585 586 #undef EQUALITY_COMPARISON 587 588 REGISTER_OP("ApproximateEqual") 589 .Input("x: T") 590 .Input("y: T") 591 .Output("z: bool") 592 .SetIsCommutative() 593 .Attr("T: numbertype") 594 .Attr("tolerance: float = 0.00001") 595 .SetShapeFn(shape_inference::UnchangedShape); 596 597 // -------------------------------------------------------------------------- 598 599 REGISTER_OP("LogicalNot") 600 .Input("x: bool") 601 .Output("y: bool") 602 .SetShapeFn(shape_inference::UnchangedShape); 603 604 #define BINARY_LOGICAL() \ 605 Input("x: bool") \ 606 .Input("y: bool") \ 607 .Output("z: bool") \ 608 .SetIsCommutative() \ 609 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn) 610 611 REGISTER_OP("LogicalAnd").BINARY_LOGICAL(); 612 613 REGISTER_OP("LogicalOr").BINARY_LOGICAL(); 614 615 #undef BINARY_LOGICAL 616 617 // -------------------------------------------------------------------------- 618 619 REGISTER_OP("Select") 620 .Input("condition: bool") 621 .Input("t: T") 622 .Input("e: T") 623 .Output("output: T") 624 .Attr("T: type") 625 .SetShapeFn([](InferenceContext* c) { 626 auto* handle_data_1 = c->input_handle_shapes_and_types(1); 627 auto* handle_data_2 = c->input_handle_shapes_and_types(2); 628 // Merge handle shape and dtype if applicable. 629 if (handle_data_1 != nullptr && handle_data_2 != nullptr) { 630 const auto size = handle_data_1->size(); 631 std::vector<shape_inference::ShapeAndType> merged_handle_data(size); 632 if (size != handle_data_2->size()) { 633 return errors::InvalidArgument( 634 "Trying to merge handles pointing to different numbers of " 635 "tensors."); 636 } 637 638 for (int i = 0; i < size; ++i) { 639 const shape_inference::ShapeAndType& s1 = (*handle_data_1)[i]; 640 const shape_inference::ShapeAndType& s2 = (*handle_data_2)[i]; 641 if (s1.dtype != s2.dtype) { 642 // TODO(apassos) resolve this in the manner of b/32476923 643 return errors::InvalidArgument( 644 "Trying to merge handles pointing to different dtypes."); 645 } 646 merged_handle_data[i].dtype = s1.dtype; 647 TF_RETURN_IF_ERROR( 648 c->Merge(s1.shape, s2.shape, &merged_handle_data[i].shape)); 649 } 650 651 c->set_output_handle_shapes_and_types(0, merged_handle_data); 652 } 653 654 // The inputs 'then' and 'else' must have the same shape. 655 ShapeHandle data = c->input(1); 656 ShapeHandle other = c->input(2); 657 TF_RETURN_IF_ERROR(c->Merge(data, other, &data)); 658 659 // The input 'cond' must either have the same shape as 'then' and 660 // 'else', or be a vector if 'then' and 'else' are at least vectors. 661 ShapeHandle cond = c->input(0); 662 663 if (!c->RankKnown(cond) || !c->RankKnown(data)) { 664 c->set_output(0, data); 665 return Status::OK(); 666 } 667 668 // rank of shape and data is known. 669 670 const int32 cond_rank = c->Rank(cond); 671 const int32 data_rank = c->Rank(data); 672 673 if (cond_rank == 0) { 674 // The rank of 'cond' is a scalar. 675 // t and e can have any shape. 676 c->set_output(0, data); 677 return Status::OK(); 678 } 679 680 if (cond_rank != 1) { 681 // If 'cond' is not a vector, and not a scalar, 682 // then shape must match 'then' and 'else' 683 TF_RETURN_IF_ERROR(c->Merge(data, cond, &data)); 684 c->set_output(0, data); 685 return Status::OK(); 686 } 687 688 if (data_rank == 0) { 689 // if 'then' and 'else' are scalar also the cond must be 690 TF_RETURN_IF_ERROR(c->Merge(data, cond, &data)); 691 c->set_output(0, data); 692 return Status::OK(); 693 } 694 695 if (cond_rank == 1) { 696 // if the cond is a vector and the 'then' is not a scalar, 697 // the first dimension of 'then' and 'else' 698 TF_RETURN_IF_ERROR(c->Merge(cond, c->Vector(c->Dim(data, 0)), &cond)); 699 c->set_output(0, data); 700 return Status::OK(); 701 } 702 703 c->set_output(0, data); 704 705 return Status::OK(); 706 }); 707 708 // -------------------------------------------------------------------------- 709 710 REGISTER_OP("MatMul") 711 .Input("a: T") 712 .Input("b: T") 713 .Output("product: T") 714 .Attr("transpose_a: bool = false") 715 .Attr("transpose_b: bool = false") 716 .Attr("T: {half, bfloat16, float, double, int32, complex64, complex128}") 717 .SetShapeFn(shape_inference::MatMulShape); 718 719 REGISTER_OP("SparseMatMul") 720 .Input("a: Ta") 721 .Input("b: Tb") 722 .Output("product: float") 723 .Attr("transpose_a: bool = false") 724 .Attr("transpose_b: bool = false") 725 .Attr("a_is_sparse: bool = false") 726 .Attr("b_is_sparse: bool = false") 727 .Attr("Ta: {float, bfloat16} = DT_FLOAT") 728 .Attr("Tb: {float, bfloat16} = DT_FLOAT") 729 .SetShapeFn(shape_inference::MatMulShape); 730 731 // -------------------------------------------------------------------------- 732 733 // For operations where the output is a reduction function along some 734 // dimensions of the input. 735 REGISTER_OP("Sum") 736 .Input("input: T") 737 .Input("reduction_indices: Tidx") 738 .Output("output: T") 739 .Attr("keep_dims: bool = false") 740 .Attr("T: numbertype") 741 .Attr("Tidx: {int32, int64} = DT_INT32") 742 .SetShapeFn(shape_inference::ReductionShape); 743 744 REGISTER_OP("Mean") 745 .Input("input: T") 746 .Input("reduction_indices: Tidx") 747 .Output("output: T") 748 .Attr("keep_dims: bool = false") 749 .Attr("T: numbertype") 750 .Attr("Tidx: {int32, int64} = DT_INT32") 751 .SetShapeFn(shape_inference::ReductionShape); 752 753 REGISTER_OP("Prod") 754 .Input("input: T") 755 .Input("reduction_indices: Tidx") 756 .Output("output: T") 757 .Attr("keep_dims: bool = false") 758 .Attr("T: numbertype") 759 .Attr("Tidx: {int32, int64} = DT_INT32") 760 .SetShapeFn(shape_inference::ReductionShape); 761 762 REGISTER_OP("Min") 763 .Input("input: T") 764 .Input("reduction_indices: Tidx") 765 .Output("output: T") 766 .Attr("keep_dims: bool = false") 767 .Attr("T: numbertype") 768 .Attr("Tidx: {int32, int64} = DT_INT32") 769 .SetShapeFn(shape_inference::ReductionShape); 770 771 REGISTER_OP("Max") 772 .Input("input: T") 773 .Input("reduction_indices: Tidx") 774 .Output("output: T") 775 .Attr("keep_dims: bool = false") 776 .Attr("T: numbertype") 777 .Attr("Tidx: {int32, int64} = DT_INT32") 778 .SetShapeFn(shape_inference::ReductionShape); 779 780 namespace { 781 782 Status ArgOpShape(shape_inference::InferenceContext* c) { 783 ShapeHandle dimension_shape; 784 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &dimension_shape)); 785 786 ShapeHandle input_shape = c->input(0); 787 if (!c->RankKnown(input_shape)) { 788 return shape_inference::UnknownShape(c); 789 } 790 791 const int32 input_rank = c->Rank(input_shape); 792 if (input_rank <= 1) { 793 // Reducing a scalar/vector must return a scalar. 794 return shape_inference::ScalarShape(c); 795 } 796 797 const Tensor* dim_t = c->input_tensor(1); 798 if (dim_t == nullptr) { 799 // We don't know the value of the dimension, but we 800 // know the rank of the input, so return the correct 801 // rank with unknown dimensions. 802 std::vector<DimensionHandle> dims(input_rank - 1); 803 for (int i = 0; i < dims.size(); ++i) { 804 dims[i] = c->UnknownDim(); 805 } 806 807 c->set_output(0, c->MakeShape(dims)); 808 return Status::OK(); 809 } 810 811 int64 dimension_val; 812 if (dim_t->dtype() == DT_INT32) { 813 dimension_val = dim_t->scalar<int32>()(); 814 } else { 815 dimension_val = dim_t->scalar<int64>()(); 816 } 817 818 int64 axis = dimension_val < 0 ? dimension_val + input_rank : dimension_val; 819 if (axis < 0 || axis >= input_rank) { 820 return errors::InvalidArgument( 821 "Dimension (", dimension_val, ") must be in the range [", -input_rank, 822 ", ", input_rank, "), where ", input_rank, 823 " is the number of dimensions in the input."); 824 } 825 826 // Return the input shape without the dimension being reduced. 827 std::vector<DimensionHandle> dims; 828 for (int i = 0; i < input_rank; ++i) { 829 if (axis != i) { 830 dims.emplace_back(c->Dim(input_shape, i)); 831 } 832 } 833 c->set_output(0, c->MakeShape(dims)); 834 return Status::OK(); 835 } 836 837 } // namespace 838 839 REGISTER_OP("ArgMax") 840 .Input("input: T") 841 .Input("dimension: Tidx") 842 .Output("output: output_type") 843 .Attr("T: numbertype") 844 .Attr("Tidx: {int32, int64} = DT_INT32") 845 .Attr("output_type: {int32, int64} = DT_INT64") 846 .SetShapeFn(ArgOpShape); 847 848 REGISTER_OP("ArgMin") 849 .Input("input: T") 850 .Input("dimension: Tidx") 851 .Output("output: output_type") 852 .Attr("T: numbertype") 853 .Attr("Tidx: {int32, int64} = DT_INT32") 854 .Attr("output_type: {int32, int64} = DT_INT64") 855 .SetShapeFn(ArgOpShape); 856 857 namespace { 858 859 Status SegmentReductionShapeFn(InferenceContext* c) { 860 ShapeHandle data_shape; 861 ShapeHandle segment_ids_shape; 862 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape)); 863 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &segment_ids_shape)); 864 865 ShapeHandle subshape; 866 TF_RETURN_IF_ERROR(c->Subshape(data_shape, 1, &subshape)); 867 868 ShapeHandle out; 869 TF_RETURN_IF_ERROR( 870 c->Concatenate(c->Vector(InferenceContext::kUnknownDim), subshape, &out)); 871 c->set_output(0, out); 872 return Status::OK(); 873 } 874 875 Status SparseSegmentReductionShapeFn(InferenceContext* c) { 876 ShapeHandle data_shape; 877 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape)); 878 879 ShapeHandle indices_shape; 880 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices_shape)); 881 882 ShapeHandle segment_ids_shape; 883 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &segment_ids_shape)); 884 885 // indices and segment_ids should merge cleanly. 886 ShapeHandle unused; 887 TF_RETURN_IF_ERROR(c->Merge(indices_shape, segment_ids_shape, &unused)); 888 889 ShapeHandle subshape; 890 TF_RETURN_IF_ERROR(c->Subshape(data_shape, 1, &subshape)); 891 892 ShapeHandle out; 893 TF_RETURN_IF_ERROR( 894 c->Concatenate(c->Vector(InferenceContext::kUnknownDim), subshape, &out)); 895 c->set_output(0, out); 896 return Status::OK(); 897 } 898 899 Status SparseSegmentReductionGradShapeFn(InferenceContext* c) { 900 ShapeHandle data_shape; 901 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape)); 902 903 ShapeHandle indices_shape; 904 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices_shape)); 905 906 // indices and segment_ids should merge cleanly. 907 ShapeHandle unused; 908 TF_RETURN_IF_ERROR(c->Merge(c->input(2), indices_shape, &unused)); 909 910 // output_dim0 should be a scalar 911 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); 912 913 ShapeHandle subshape; 914 TF_RETURN_IF_ERROR(c->Subshape(data_shape, 1, &subshape)); 915 916 const Tensor* dim0 = c->input_tensor(3); 917 ShapeHandle dim0_shape; 918 if (dim0 == nullptr) { 919 // We don't have the value at inference time, so the output 920 // shape is unknown. 921 dim0_shape = c->Vector(InferenceContext::kUnknownDim); 922 } else { 923 auto dim0_value = dim0->scalar<int32>()(); 924 if (dim0_value < 0) { 925 return errors::InvalidArgument( 926 "Cannot specify a negative value for output_dim0"); 927 } 928 dim0_shape = c->Vector(dim0_value); 929 } 930 931 ShapeHandle out; 932 TF_RETURN_IF_ERROR(c->Concatenate(dim0_shape, subshape, &out)); 933 c->set_output(0, out); 934 return Status::OK(); 935 } 936 937 Status SparseSegmentReductionWithNumSegmentsShapeFn(InferenceContext* c) { 938 ShapeHandle data_shape; 939 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape)); 940 941 ShapeHandle indices_shape; 942 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices_shape)); 943 944 ShapeHandle segment_ids_shape; 945 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &segment_ids_shape)); 946 947 ShapeHandle num_segments_shape; 948 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &num_segments_shape)); 949 950 // indices and segment_ids should merge cleanly. 951 ShapeHandle unused; 952 TF_RETURN_IF_ERROR(c->Merge(indices_shape, segment_ids_shape, &unused)); 953 954 ShapeHandle subshape; 955 TF_RETURN_IF_ERROR(c->Subshape(data_shape, 1, &subshape)); 956 957 ShapeHandle out; 958 const Tensor* dim0 = c->input_tensor(3); 959 if (dim0 == nullptr) { 960 // We don't have the value at inference time, so the output 961 // shape is unknown. 962 TF_RETURN_IF_ERROR(c->Concatenate(c->Vector(InferenceContext::kUnknownDim), 963 subshape, &out)); 964 } else { 965 auto dim0_value = dim0->scalar<int32>()(); 966 if (dim0_value < 0) { 967 return errors::InvalidArgument( 968 "Cannot specify a negative value for num_segments"); 969 } 970 TF_RETURN_IF_ERROR(c->Concatenate(c->Vector(dim0_value), subshape, &out)); 971 } 972 c->set_output(0, out); 973 return Status::OK(); 974 } 975 976 Status UnsortedSegmentReductionShapeFn(InferenceContext* c) { 977 ShapeHandle s_data = c->input(0); 978 ShapeHandle s_segment_ids = c->input(1); 979 ShapeHandle s_num_segments = c->input(2); 980 TF_RETURN_IF_ERROR(c->WithRank(s_num_segments, 0, &s_num_segments)); 981 982 ShapeHandle out; 983 984 // Leading dimensions of data must be compatible with dimensions of 985 // <s_segment_ids>. 986 if (c->RankKnown(s_segment_ids)) { 987 TF_RETURN_IF_ERROR( 988 c->MergePrefix(s_data, s_segment_ids, &s_data, &s_segment_ids)); 989 990 // Get the value of the num_segments input tensor. 991 DimensionHandle num_segments_dim; 992 TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &num_segments_dim)); 993 994 // Output is {segment_id_rank} + s_data[segment_id_rank:]. 995 ShapeHandle s_data_suffix; 996 TF_RETURN_IF_ERROR( 997 c->Subshape(s_data, c->Rank(s_segment_ids), &s_data_suffix)); 998 TF_RETURN_IF_ERROR( 999 c->Concatenate(c->Vector(num_segments_dim), s_data_suffix, &out)); 1000 } else { 1001 out = c->UnknownShape(); 1002 } 1003 c->set_output(0, out); 1004 return Status::OK(); 1005 } 1006 } // namespace 1007 1008 REGISTER_OP("SegmentSum") 1009 .Input("data: T") 1010 .Input("segment_ids: Tindices") 1011 .Output("output: T") 1012 .Attr("T: numbertype") 1013 .Attr("Tindices: {int32,int64}") 1014 .SetShapeFn(SegmentReductionShapeFn); 1015 1016 REGISTER_OP("SegmentMean") 1017 .Input("data: T") 1018 .Input("segment_ids: Tindices") 1019 .Output("output: T") 1020 .Attr("T: realnumbertype") 1021 .Attr("Tindices: {int32,int64}") 1022 .SetShapeFn(SegmentReductionShapeFn); 1023 1024 REGISTER_OP("SegmentProd") 1025 .Input("data: T") 1026 .Input("segment_ids: Tindices") 1027 .Output("output: T") 1028 .Attr("T: numbertype") 1029 .Attr("Tindices: {int32,int64}") 1030 .SetShapeFn(SegmentReductionShapeFn); 1031 1032 REGISTER_OP("SegmentMin") 1033 .Input("data: T") 1034 .Input("segment_ids: Tindices") 1035 .Output("output: T") 1036 .Attr("T: realnumbertype") 1037 .Attr("Tindices: {int32,int64}") 1038 .SetShapeFn(SegmentReductionShapeFn); 1039 1040 REGISTER_OP("SegmentMax") 1041 .Input("data: T") 1042 .Input("segment_ids: Tindices") 1043 .Output("output: T") 1044 .Attr("T: realnumbertype") 1045 .Attr("Tindices: {int32,int64}") 1046 .SetShapeFn(SegmentReductionShapeFn); 1047 1048 REGISTER_OP("UnsortedSegmentSum") 1049 .Input("data: T") 1050 .Input("segment_ids: Tindices") 1051 .Input("num_segments: Tnumsegments") 1052 .Output("output: T") 1053 .Attr("T: numbertype") 1054 .Attr("Tindices: {int32,int64}") 1055 .Attr("Tnumsegments: {int32,int64} = DT_INT32") 1056 .SetShapeFn(UnsortedSegmentReductionShapeFn); 1057 1058 REGISTER_OP("UnsortedSegmentMax") 1059 .Input("data: T") 1060 .Input("segment_ids: Tindices") 1061 .Input("num_segments: Tnumsegments") 1062 .Output("output: T") 1063 .Attr("T: realnumbertype") 1064 .Attr("Tindices: {int32,int64}") 1065 .Attr("Tnumsegments: {int32,int64} = DT_INT32") 1066 .SetShapeFn(UnsortedSegmentReductionShapeFn); 1067 1068 REGISTER_OP("UnsortedSegmentMin") 1069 .Input("data: T") 1070 .Input("segment_ids: Tindices") 1071 .Input("num_segments: Tnumsegments") 1072 .Output("output: T") 1073 .Attr("T: realnumbertype") 1074 .Attr("Tindices: {int32,int64}") 1075 .Attr("Tnumsegments: {int32,int64} = DT_INT32") 1076 .SetShapeFn(UnsortedSegmentReductionShapeFn); 1077 1078 REGISTER_OP("UnsortedSegmentProd") 1079 .Input("data: T") 1080 .Input("segment_ids: Tindices") 1081 .Input("num_segments: Tnumsegments") 1082 .Output("output: T") 1083 .Attr("T: realnumbertype") 1084 .Attr("Tindices: {int32,int64}") 1085 .Attr("Tnumsegments: {int32,int64} = DT_INT32") 1086 .SetShapeFn(UnsortedSegmentReductionShapeFn); 1087 1088 REGISTER_OP("SparseSegmentSum") 1089 .Input("data: T") 1090 .Input("indices: Tidx") 1091 .Input("segment_ids: int32") 1092 .Output("output: T") 1093 .Attr("T: realnumbertype") 1094 .Attr("Tidx: {int32, int64} = DT_INT32") 1095 .SetShapeFn(SparseSegmentReductionShapeFn); 1096 1097 REGISTER_OP("SparseSegmentSumWithNumSegments") 1098 .Input("data: T") 1099 .Input("indices: Tidx") 1100 .Input("segment_ids: int32") 1101 .Input("num_segments: Tnumsegments") 1102 .Output("output: T") 1103 .Attr("T: realnumbertype") 1104 .Attr("Tidx: {int32, int64} = DT_INT32") 1105 .Attr("Tnumsegments: {int32,int64} = DT_INT32") 1106 .SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn); 1107 1108 REGISTER_OP("SparseSegmentMean") 1109 .Input("data: T") 1110 .Input("indices: Tidx") 1111 .Input("segment_ids: int32") 1112 .Output("output: T") 1113 .Attr("T: {float, double}") 1114 .Attr("Tidx: {int32, int64} = DT_INT32") 1115 .SetShapeFn(SparseSegmentReductionShapeFn); 1116 1117 REGISTER_OP("SparseSegmentMeanWithNumSegments") 1118 .Input("data: T") 1119 .Input("indices: Tidx") 1120 .Input("segment_ids: int32") 1121 .Input("num_segments: Tnumsegments") 1122 .Output("output: T") 1123 .Attr("T: {float, double}") 1124 .Attr("Tidx: {int32, int64} = DT_INT32") 1125 .Attr("Tnumsegments: {int32,int64} = DT_INT32") 1126 .SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn); 1127 1128 REGISTER_OP("SparseSegmentMeanGrad") 1129 .Input("grad: T") 1130 .Input("indices: Tidx") 1131 .Input("segment_ids: int32") 1132 .Input("output_dim0: int32") 1133 .Output("output: T") 1134 .Attr("T: {float, double}") 1135 .Attr("Tidx: {int32, int64} = DT_INT32") 1136 .SetShapeFn(SparseSegmentReductionGradShapeFn); 1137 1138 REGISTER_OP("SparseSegmentSqrtN") 1139 .Input("data: T") 1140 .Input("indices: Tidx") 1141 .Input("segment_ids: int32") 1142 .Output("output: T") 1143 .Attr("T: {float, double}") 1144 .Attr("Tidx: {int32, int64} = DT_INT32") 1145 .SetShapeFn(SparseSegmentReductionShapeFn); 1146 1147 REGISTER_OP("SparseSegmentSqrtNWithNumSegments") 1148 .Input("data: T") 1149 .Input("indices: Tidx") 1150 .Input("segment_ids: int32") 1151 .Input("num_segments: Tnumsegments") 1152 .Output("output: T") 1153 .Attr("T: {float, double}") 1154 .Attr("Tidx: {int32, int64} = DT_INT32") 1155 .Attr("Tnumsegments: {int32,int64} = DT_INT32") 1156 .SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn); 1157 1158 REGISTER_OP("SparseSegmentSqrtNGrad") 1159 .Input("grad: T") 1160 .Input("indices: Tidx") 1161 .Input("segment_ids: int32") 1162 .Input("output_dim0: int32") 1163 .Output("output: T") 1164 .Attr("T: {float, double}") 1165 .Attr("Tidx: {int32, int64} = DT_INT32") 1166 .SetShapeFn(SparseSegmentReductionGradShapeFn); 1167 1168 REGISTER_OP("All") 1169 .Input("input: bool") 1170 .Input("reduction_indices: Tidx") 1171 .Output("output: bool") 1172 .Attr("keep_dims: bool = false") 1173 .Attr("Tidx: {int32, int64} = DT_INT32") 1174 .SetShapeFn(shape_inference::ReductionShape); 1175 1176 REGISTER_OP("Any") 1177 .Input("input: bool") 1178 .Input("reduction_indices: Tidx") 1179 .Attr("keep_dims: bool = false") 1180 .Output("output: bool") 1181 .Attr("Tidx: {int32, int64} = DT_INT32") 1182 .SetShapeFn(shape_inference::ReductionShape); 1183 1184 // -------------------------------------------------------------------------- 1185 1186 namespace { 1187 1188 template <typename T> 1189 Status RangeSize(const Tensor* start_t, const Tensor* limit_t, 1190 const Tensor* delta_t, InferenceContext* const c) { 1191 T start = start_t->scalar<T>()(); 1192 T limit = limit_t->scalar<T>()(); 1193 T delta = delta_t->scalar<T>()(); 1194 if (start > limit && delta > 0) { 1195 return errors::InvalidArgument( 1196 "Requires start <= limit when delta > 0: ", start, "/", limit); 1197 } 1198 if (start < limit && delta < 0) { 1199 return errors::InvalidArgument( 1200 "Requires start >= limit when delta < 0: ", start, "/", limit); 1201 } 1202 if (delta == 0) { 1203 return errors::InvalidArgument("Requires delta != 0"); 1204 } 1205 1206 int64 size = 1207 (std::is_integral<T>::value 1208 ? ((std::abs(limit - start) + std::abs(delta) - 1) / std::abs(delta)) 1209 : std::ceil(std::abs((limit - start) / delta))); 1210 c->set_output(0, c->Vector(size)); 1211 return Status::OK(); 1212 } 1213 1214 } // namespace 1215 1216 REGISTER_OP("Range") 1217 .Input("start: Tidx") 1218 .Input("limit: Tidx") 1219 .Input("delta: Tidx") 1220 .Output("output: Tidx") 1221 .Attr("Tidx: {bfloat16, float, double, int32, int64} = DT_INT32") 1222 .SetShapeFn([](InferenceContext* c) { 1223 ShapeHandle unused; 1224 TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(0), 0, &unused), 1225 " for 'start'"); 1226 TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(1), 0, &unused), 1227 " for 'limit'"); 1228 TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(2), 0, &unused), 1229 " for 'delta'"); 1230 const Tensor* start_t = c->input_tensor(0); 1231 const Tensor* limit_t = c->input_tensor(1); 1232 const Tensor* delta_t = c->input_tensor(2); 1233 DataType dtype; 1234 TF_RETURN_IF_ERROR(c->GetAttr("Tidx", &dtype)); 1235 if (start_t == nullptr || limit_t == nullptr || delta_t == nullptr) { 1236 c->set_output(0, c->Vector(InferenceContext::kUnknownDim)); 1237 return Status::OK(); 1238 } 1239 if (dtype == DT_INT32) { 1240 return RangeSize<int32>(start_t, limit_t, delta_t, c); 1241 } else if (dtype == DT_INT64) { 1242 return RangeSize<int64>(start_t, limit_t, delta_t, c); 1243 } else if (dtype == DT_FLOAT) { 1244 return RangeSize<float>(start_t, limit_t, delta_t, c); 1245 } else { 1246 return RangeSize<double>(start_t, limit_t, delta_t, c); 1247 } 1248 return Status::OK(); 1249 }); 1250 1251 REGISTER_OP("LinSpace") 1252 .Input("start: T") 1253 .Input("stop: T") 1254 .Input("num: Tidx") 1255 .Output("output: T") 1256 .Attr("T: {bfloat16, float, double}") 1257 .Attr("Tidx: {int32, int64} = DT_INT32") 1258 .SetShapeFn([](InferenceContext* c) { 1259 ShapeHandle unused; 1260 TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(0), 0, &unused), 1261 " for 'start'"); 1262 TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(1), 0, &unused), 1263 " for 'stop'"); 1264 TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(2), 0, &unused), 1265 " for 'num'"); 1266 const Tensor* num_t = c->input_tensor(2); 1267 if (num_t == nullptr) { 1268 c->set_output(0, c->Vector(InferenceContext::kUnknownDim)); 1269 return Status::OK(); 1270 } 1271 1272 int64 num; 1273 if (num_t->dtype() == DT_INT32) { 1274 num = num_t->scalar<int32>()(); 1275 } else { 1276 num = num_t->scalar<int64>()(); 1277 } 1278 if (num <= 0) return errors::InvalidArgument("Requires num > 0: ", num); 1279 c->set_output(0, c->Vector(num)); 1280 return Status::OK(); 1281 }); 1282 1283 REGISTER_OP("Complex") 1284 .Input("real: T") 1285 .Input("imag: T") 1286 .Output("out: Tout") 1287 .Attr("T: {float, double} = DT_FLOAT") 1288 .Attr("Tout: {complex64, complex128} = DT_COMPLEX64") 1289 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); 1290 1291 REGISTER_OP("Real") 1292 .Input("input: T") 1293 .Output("output: Tout") 1294 .Attr("T: {complex64, complex128} = DT_COMPLEX64") 1295 .Attr("Tout: {float, double} = DT_FLOAT") 1296 .SetShapeFn(shape_inference::UnchangedShape); 1297 1298 REGISTER_OP("Imag") 1299 .Input("input: T") 1300 .Output("output: Tout") 1301 .Attr("T: {complex64, complex128} = DT_COMPLEX64") 1302 .Attr("Tout: {float, double} = DT_FLOAT") 1303 .SetShapeFn(shape_inference::UnchangedShape); 1304 1305 REGISTER_OP("Angle") 1306 .Input("input: T") 1307 .Output("output: Tout") 1308 .Attr("T: {complex64, complex128} = DT_COMPLEX64") 1309 .Attr("Tout: {float, double} = DT_FLOAT") 1310 .SetShapeFn(shape_inference::UnchangedShape); 1311 1312 REGISTER_OP("Conj") 1313 .Input("input: T") 1314 .Output("output: T") 1315 .Attr("T: {complex64, complex128, variant} = DT_COMPLEX64") 1316 .SetShapeFn(shape_inference::UnchangedShape); 1317 1318 // -------------------------------------------------------------------------- 1319 1320 REGISTER_OP("Cross") 1321 .Input("a: T") 1322 .Input("b: T") 1323 .Output("product: T") 1324 .Attr("T: realnumbertype") 1325 .SetShapeFn([](InferenceContext* c) { 1326 ShapeHandle a_shape; 1327 ShapeHandle b_shape; 1328 // * Input rank >= 1. 1329 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &a_shape)); 1330 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &b_shape)); 1331 1332 // * Both inputs have the same shape. 1333 TF_RETURN_IF_ERROR(c->Merge(a_shape, b_shape, &a_shape)); 1334 1335 // * input_shape[-1] == 3. 1336 if (c->RankKnown(a_shape)) { 1337 int rank = c->Rank(a_shape); 1338 auto dim = c->Dim(a_shape, rank - 1); 1339 TF_RETURN_IF_ERROR(c->WithValue(dim, 3, &dim)); 1340 } 1341 c->set_output(0, a_shape); 1342 return Status::OK(); 1343 }); 1344 1345 // -------------------------------------------------------------------------- 1346 1347 REGISTER_OP("HistogramFixedWidth") 1348 .Input("values: T") 1349 .Input("value_range: T") 1350 .Input("nbins: int32") 1351 .Output("out: dtype") 1352 .Attr("T: {int32, int64, float32, float64}") 1353 .Attr("dtype: {int32, int64} = DT_INT32") 1354 .SetShapeFn([](InferenceContext* c) { 1355 const Tensor* nbins_input = c->input_tensor(2); 1356 if (nbins_input != nullptr) { 1357 int64 nbins; 1358 TF_RETURN_IF_ERROR(c->GetScalarFromTensor(nbins_input, &nbins)); 1359 c->set_output(0, c->Vector(nbins)); 1360 } else { 1361 c->set_output(0, c->UnknownShapeOfRank(1)); 1362 } 1363 return Status::OK(); 1364 }); 1365 1366 REGISTER_OP("Bincount") 1367 .Input("arr: int32") 1368 .Input("size: int32") 1369 .Input("weights: T") 1370 .Attr("T: {int32, int64, float32, float64}") 1371 .Output("bins: T") 1372 .SetShapeFn([](InferenceContext* c) { 1373 c->set_output(0, c->UnknownShapeOfRank(1)); 1374 return Status::OK(); 1375 }); 1376 1377 REGISTER_OP("Cumsum") 1378 .Input("x: T") 1379 .Input("axis: Tidx") 1380 .Attr("exclusive: bool = false") 1381 .Attr("reverse: bool = false") 1382 .Output("out: T") 1383 .Attr("T: numbertype") 1384 .Attr("Tidx: {int32, int64} = DT_INT32") 1385 .SetShapeFn(shape_inference::UnchangedShape); 1386 1387 REGISTER_OP("Cumprod") 1388 .Input("x: T") 1389 .Input("axis: Tidx") 1390 .Attr("exclusive: bool = false") 1391 .Attr("reverse: bool = false") 1392 .Output("out: T") 1393 .Attr("T: numbertype") 1394 .Attr("Tidx: {int32, int64} = DT_INT32") 1395 .SetShapeFn(shape_inference::UnchangedShape); 1396 1397 REGISTER_OP("QuantizedMatMul") 1398 .Input("a: T1") 1399 .Input("b: T2") 1400 .Input("min_a: float") 1401 .Input("max_a: float") 1402 .Input("min_b: float") 1403 .Input("max_b: float") 1404 .Output("out: Toutput") 1405 .Output("min_out: float") 1406 .Output("max_out: float") 1407 .Attr("T1: quantizedtype") 1408 .Attr("T2: quantizedtype") 1409 .Attr("Toutput: quantizedtype = DT_QINT32") 1410 .Attr("transpose_a: bool = false") 1411 .Attr("transpose_b: bool = false") 1412 .Attr("Tactivation: quantizedtype = DT_QUINT8") 1413 .SetShapeFn([](InferenceContext* c) { 1414 TF_RETURN_IF_ERROR(shape_inference::MatMulShape(c)); 1415 ShapeHandle unused; 1416 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 1417 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); 1418 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); 1419 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); 1420 1421 c->set_output(1, c->Scalar()); 1422 c->set_output(2, c->Scalar()); 1423 return Status::OK(); 1424 }); 1425 1426 REGISTER_OP("QuantizedMul") 1427 .Input("x: T1") 1428 .Input("y: T2") 1429 .Input("min_x: float") 1430 .Input("max_x: float") 1431 .Input("min_y: float") 1432 .Input("max_y: float") 1433 .Output("z: Toutput") 1434 .Output("min_z: float") 1435 .Output("max_z: float") 1436 .Attr("T1: quantizedtype") 1437 .Attr("T2: quantizedtype") 1438 .Attr("Toutput: quantizedtype = DT_QINT32") 1439 .SetIsCommutative() 1440 .SetShapeFn([](InferenceContext* c) { 1441 TF_RETURN_IF_ERROR(shape_inference::BroadcastBinaryOpShapeFn(c)); 1442 c->set_output(1, c->Scalar()); 1443 c->set_output(2, c->Scalar()); 1444 return Status::OK(); 1445 }); 1446 1447 REGISTER_OP("QuantizedAdd") 1448 .Input("x: T1") 1449 .Input("y: T2") 1450 .Input("min_x: float") 1451 .Input("max_x: float") 1452 .Input("min_y: float") 1453 .Input("max_y: float") 1454 .Output("z: Toutput") 1455 .Output("min_z: float") 1456 .Output("max_z: float") 1457 .Attr("T1: quantizedtype") 1458 .Attr("T2: quantizedtype") 1459 .Attr("Toutput: quantizedtype = DT_QINT32") 1460 .SetIsCommutative() 1461 .SetShapeFn([](InferenceContext* c) { 1462 TF_RETURN_IF_ERROR(shape_inference::BroadcastBinaryOpShapeFn(c)); 1463 c->set_output(1, c->Scalar()); 1464 c->set_output(2, c->Scalar()); 1465 return Status::OK(); 1466 }); 1467 1468 REGISTER_OP("QuantizeDownAndShrinkRange") 1469 .Input("input: Tinput") 1470 .Input("input_min: float") 1471 .Input("input_max: float") 1472 .Output("output: out_type") 1473 .Output("output_min: float") 1474 .Output("output_max: float") 1475 .Attr("Tinput: quantizedtype") 1476 .Attr("out_type: quantizedtype") 1477 .SetShapeFn([](InferenceContext* c) { 1478 TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c)); 1479 ShapeHandle unused; 1480 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 1481 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 1482 c->set_output(1, c->Scalar()); 1483 c->set_output(2, c->Scalar()); 1484 return Status::OK(); 1485 }); 1486 1487 REGISTER_OP("Requantize") 1488 .Input("input: Tinput") 1489 .Input("input_min: float") 1490 .Input("input_max: float") 1491 .Input("requested_output_min: float") 1492 .Input("requested_output_max: float") 1493 .Output("output: out_type") 1494 .Output("output_min: float") 1495 .Output("output_max: float") 1496 .Attr("Tinput: quantizedtype") 1497 .Attr("out_type: quantizedtype") 1498 .SetShapeFn([](InferenceContext* c) { 1499 TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c)); 1500 ShapeHandle unused; 1501 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 1502 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 1503 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); 1504 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); 1505 c->set_output(1, c->Scalar()); 1506 c->set_output(2, c->Scalar()); 1507 return Status::OK(); 1508 }); 1509 1510 REGISTER_OP("CompareAndBitpack") 1511 .Input("input: T") 1512 .Input("threshold: T") 1513 .Output("output: uint8") 1514 .Attr("T: {bool, float16, float32, float64, int8, int16, int32, int64}") 1515 .SetShapeFn([](InferenceContext* c) { 1516 ShapeHandle input; 1517 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input)); 1518 ShapeHandle unused; 1519 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 1520 ShapeHandle output = input; 1521 if (c->RankKnown(input)) { 1522 int rank = c->Rank(input); 1523 auto inner_dim = c->Dim(input, rank - 1); 1524 DimensionHandle inferred_dim; 1525 TF_RETURN_IF_ERROR(c->Divide(inner_dim, 8, 1526 /* evenly_divisible */ true, 1527 &inferred_dim)); 1528 TF_RETURN_IF_ERROR( 1529 c->ReplaceDim(output, rank - 1, inferred_dim, &output)); 1530 } 1531 c->set_output(0, output); 1532 1533 return Status::OK(); 1534 }); 1535 1536 REGISTER_OP("RequantizationRange") 1537 .Input("input: Tinput") 1538 .Input("input_min: float") 1539 .Input("input_max: float") 1540 .Output("output_min: float") 1541 .Output("output_max: float") 1542 .Attr("Tinput: quantizedtype") 1543 .SetShapeFn([](InferenceContext* c) { 1544 ShapeHandle unused; 1545 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 1546 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 1547 c->set_output(0, c->Scalar()); 1548 c->set_output(1, c->Scalar()); 1549 return Status::OK(); 1550 }); 1551 1552 // -------------------------------------------------------------------------- 1553 1554 REGISTER_OP("Bucketize") 1555 .Input("input: T") 1556 .Output("output: int32") 1557 .Attr("T: {int32, int64, float, double}") 1558 .Attr("boundaries: list(float)") 1559 .SetShapeFn(shape_inference::UnchangedShape); 1560 1561 #ifdef INTEL_MKL 1562 REGISTER_OP("_MklAddN") 1563 .Input("inputs: N * T") 1564 .Input("mkl_input: N * uint8") 1565 .Output("sum: T") 1566 .Output("mkl_sum: uint8") 1567 .Attr("N: int >= 1") 1568 .Attr("T: numbertype") 1569 .SetIsCommutative() 1570 .SetIsAggregate() 1571 .SetShapeFn([](InferenceContext* c) { 1572 ShapeHandle cur = c->input(c->num_inputs() - 1); 1573 for (int i = c->num_inputs() - 2; i >= 0; --i) { 1574 TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur), 1575 "From merging shape ", i, 1576 " with other shapes."); 1577 } 1578 c->set_output(0, cur); 1579 return Status::OK(); 1580 }) 1581 .Doc(R"doc( 1582 Add two input tensors element wise using mkl kernel sum. 1583 inputs: Must all be the same size and shape. 1584 )doc"); 1585 1586 #endif // INTEL_MKL 1587 1588 } // namespace tensorflow 1589