1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_ 18 19 #include <cmath> 20 #include <type_traits> 21 22 #include "absl/algorithm/container.h" 23 #include "absl/base/casts.h" 24 #include "absl/container/inlined_vector.h" 25 #include "absl/memory/memory.h" 26 #include "absl/meta/type_traits.h" 27 #include "absl/types/optional.h" 28 #include "tensorflow/compiler/xla/array2d.h" 29 #include "tensorflow/compiler/xla/literal_util.h" 30 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" 31 #include "tensorflow/compiler/xla/service/hlo_evaluator.h" 32 #include "tensorflow/compiler/xla/service/hlo_instructions.h" 33 #include "tensorflow/compiler/xla/service/shape_inference.h" 34 35 namespace xla { 36 37 // TODO(b/79274244): We'd like these type traits to live inside of 38 // HloEvaluatorTypedVisitor so they don't pollute namespace xla, but that 39 // crashes clang in the frontend. 40 // 41 // Anyway this is relatively safe as-is because hlo_evaluator_typed_visitor.h is 42 // a "private" header that's not exposed outside of hlo_evaluator.cc. 43 template <typename T> 44 using is_complex_t = 45 absl::disjunction<std::is_same<T, complex64>, std::is_same<T, complex128>>; 46 47 // ToArithmeticSafeType(T t): 48 // - converts `t` to the bitwise-equivalent `unsigned T` if T is a signed 49 // integer, and 50 // - otherwise returns `t` unchanged. 51 // 52 // It's UB in C++ to under/overflow a signed integer, so we wrap all arithmetic 53 // in this type to force 2's complement behavior. 54 template <typename T, 55 typename std::enable_if<std::is_integral<T>::value && 56 std::is_signed<T>::value>::type* = nullptr> 57 typename std::make_unsigned<T>::type ToArithmeticSafeType(T t) { 58 return static_cast<typename std::make_unsigned<T>::type>(t); 59 } 60 template <typename T, 61 typename std::enable_if<!std::is_integral<T>::value || 62 !std::is_signed<T>::value>::type* = nullptr> 63 T ToArithmeticSafeType(T t) { 64 return std::move(t); 65 } 66 67 // Templated DfsHloVisitor for use by HloEvaluator. 68 // 69 // Typically ReturnT here indicates the resulting literal type of each evaluated 70 // Handle* method of a TypedVisitor. There are however a few notable exceptions 71 // to this rule, notably: 72 // - HandleCompare and HandleIsFinite: where the resulting literal type is 73 // always boolean. 74 // - HandleImag and HandleReal: where the resulting literal type is always float 75 // and the operand is always complex, or real in the case of HandleReal. 76 // These operations are handled outside of the parent HloEvaluator handlers 77 // instead of from within TypedVisitor. 78 // 79 // Type params: 80 // - ReturnT: The type of input and output of each operation. 
// - ElementwiseT: The type in which internal computations are done.
//
// This is logically a private part of HloEvaluator. It lives in this header
// file rather than in hlo_evaluator.cc because we use extern templates and a
// bunch of independent cc files to speed up compiling the many instantiations
// of this class.
template <typename ReturnT, typename ElementwiseT = ReturnT>
class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
 private:
  Status UnsupportedTypeError(HloInstruction* instruction) {
    return InvalidArgument(
        "Unsupported type for %s: %s", HloOpcodeString(instruction->opcode()),
        PrimitiveType_Name(instruction->shape().element_type()));
  }

  // Gets the value in the given literal, static_cast to a double.
  template <
      typename NativeT,
      typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
  double GetAsDouble(const Literal& literal,
                     absl::Span<const int64> input_index) {
    return static_cast<double>(literal.Get<NativeT>(input_index));
  }

  // Specialization for complex types. In this case it is not possible to
  // static_cast the value to a double, so just CHECK-fail. This method is not
  // used at run-time, but must be available at compile-time to keep the
  // compiler happy.
  template <
      typename NativeT,
      typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
  double GetAsDouble(const Literal& literal,
                     absl::Span<const int64> input_index) {
    LOG(FATAL) << "Trying to get complex literal as double: "
               << literal.ToString();
  }

 public:
  explicit HloEvaluatorTypedVisitor(HloEvaluator* p) : parent_(p) {}

  // The following higher-order functions convert a function with ElementwiseT
  // to a function with ReturnT.
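  // For example, a visitor instantiated with ReturnT = Eigen::half and
  // ElementwiseT = float would turn a float -> float op into a half -> half
  // op: each argument is widened to float, the op is applied, and the result
  // is narrowed back to half. (This particular type pairing is illustrative;
  // the actual pairing is chosen where the visitor is instantiated.)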
123 std::function<ReturnT(ReturnT)> ConvertUnaryFunction( 124 const std::function<ElementwiseT(ElementwiseT)>& unary_op) { 125 return [&unary_op](ReturnT arg) { 126 return static_cast<ReturnT>(unary_op(static_cast<ElementwiseT>(arg))); 127 }; 128 } 129 std::function<ReturnT(ReturnT, ReturnT)> ConvertBinaryFunction( 130 const std::function<ElementwiseT(ElementwiseT, ElementwiseT)>& 131 binary_op) { 132 return [&binary_op](ReturnT arg1, ReturnT arg2) { 133 return static_cast<ReturnT>(binary_op(static_cast<ElementwiseT>(arg1), 134 static_cast<ElementwiseT>(arg2))); 135 }; 136 } 137 std::function<ReturnT(ReturnT, ReturnT, ReturnT)> ConvertTernaryFunction( 138 const std::function<ElementwiseT(ElementwiseT, ElementwiseT, 139 ElementwiseT)>& ternary_op) { 140 return [&ternary_op](ReturnT arg1, ReturnT arg2, ReturnT arg3) { 141 return static_cast<ReturnT>(ternary_op(static_cast<ElementwiseT>(arg1), 142 static_cast<ElementwiseT>(arg2), 143 static_cast<ElementwiseT>(arg3))); 144 }; 145 } 146 147 Status DefaultAction(HloInstruction* hlo_instruction) override { 148 return Unimplemented("unhandled HLO ops for HloEvaluator: %s.", 149 HloOpcodeString(hlo_instruction->opcode())); 150 } 151 152 template <typename NativeT, 153 typename std::enable_if<std::is_unsigned<NativeT>::value>::type* = 154 nullptr> 155 Status HandleAbs(HloInstruction* abs) { 156 TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs], 157 ElementWiseUnaryOp(abs, [](NativeT elem_operand) { 158 return elem_operand; 159 })); 160 return Status::OK(); 161 } 162 163 template < 164 typename NativeT, 165 typename std::enable_if<std::is_signed<NativeT>::value>::type* = nullptr> 166 Status HandleAbs(HloInstruction* abs) { 167 TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs], 168 ElementWiseUnaryOp(abs, [](NativeT elem_operand) { 169 return std::abs(elem_operand); 170 })); 171 return Status::OK(); 172 } 173 174 template < 175 typename NativeT, 176 typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr> 177 Status HandleAbs(HloInstruction* abs) { 178 const Literal& operand_literal = 179 parent_->GetEvaluatedLiteralFor(abs->operand(0)); 180 TF_ASSIGN_OR_RETURN( 181 parent_->evaluated_[abs], 182 (HloEvaluator::ElementWiseUnaryOpImpl<float, NativeT>( 183 abs, [](NativeT elem_operand) { return std::abs(elem_operand); }, 184 operand_literal))); 185 186 return Status::OK(); 187 } 188 189 Status HandleAbs(HloInstruction* abs) override { 190 // If the operand is of C64 type, the return type of abs will be F32. 191 // However, ElementwiseT would still be the return type, F32, and thus 192 // specifying the ElementwiseT explicitly as C64 is needed below. 
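    // For example, abs of a c64[4] operand yields an f32[4] result, so ReturnT
    // and ElementwiseT are float for that instruction; the dispatch below
    // therefore selects HandleAbs<complex64> (or HandleAbs<complex128> for a
    // C128 operand) explicitly rather than defaulting to
    // HandleAbs<ElementwiseT>.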
193 if (abs->operand(0)->shape().element_type() == C64) { 194 return HandleAbs<complex64>(abs); 195 } else if (abs->operand(0)->shape().element_type() == C128) { 196 return HandleAbs<complex128>(abs); 197 } 198 return HandleAbs<ElementwiseT>(abs); 199 } 200 201 template < 202 typename NativeT, 203 typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr> 204 Status HandleRound(HloInstruction* round) { 205 TF_ASSIGN_OR_RETURN( 206 parent_->evaluated_[round], 207 ElementWiseUnaryOp(round, [](ElementwiseT elem_operand) { 208 return std::round(elem_operand); 209 })); 210 return Status::OK(); 211 } 212 213 template < 214 typename NativeT, 215 typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr> 216 Status HandleRound(HloInstruction* round) { 217 return UnsupportedTypeError(round); 218 } 219 220 Status HandleRound(HloInstruction* round) override { 221 return HandleRound<ReturnT>(round); 222 } 223 224 template < 225 typename NativeT, 226 typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr> 227 Status HandleCeil(HloInstruction* ceil) { 228 TF_ASSIGN_OR_RETURN(parent_->evaluated_[ceil], 229 ElementWiseUnaryOp(ceil, [](ElementwiseT elem_operand) { 230 return std::ceil(elem_operand); 231 })); 232 return Status::OK(); 233 } 234 235 template < 236 typename NativeT, 237 typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr> 238 Status HandleCeil(HloInstruction* ceil) { 239 return UnsupportedTypeError(ceil); 240 } 241 242 Status HandleCeil(HloInstruction* ceil) override { 243 return HandleCeil<ReturnT>(ceil); 244 } 245 246 Status HandleConvert(HloInstruction* convert) override { 247 const HloInstruction* operand = convert->operand(0); 248 TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape())); 249 TF_ASSIGN_OR_RETURN(Literal result, 250 parent_->GetEvaluatedLiteralFor(operand).Convert( 251 convert->shape().element_type())); 252 parent_->evaluated_[convert] = std::move(result); 253 return Status::OK(); 254 } 255 256 Status HandleBitcastConvert(HloInstruction* convert) override { 257 const HloInstruction* operand = convert->operand(0); 258 TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape())); 259 TF_ASSIGN_OR_RETURN(Literal result, 260 parent_->GetEvaluatedLiteralFor(operand).BitcastConvert( 261 convert->shape().element_type())); 262 263 parent_->evaluated_[convert] = std::move(result); 264 return Status::OK(); 265 } 266 267 Status HandleExp(HloInstruction* exp) override { 268 TF_ASSIGN_OR_RETURN(parent_->evaluated_[exp], 269 ElementWiseUnaryOp(exp, [](ElementwiseT elem_operand) { 270 return std::exp(elem_operand); 271 })); 272 return Status::OK(); 273 } 274 275 template < 276 typename NativeT, 277 typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr> 278 Status HandleExpm1(HloInstruction* expm1) { 279 TF_ASSIGN_OR_RETURN( 280 parent_->evaluated_[expm1], 281 ElementWiseUnaryOp(expm1, [](ElementwiseT elem_operand) { 282 return std::expm1(elem_operand); 283 })); 284 return Status::OK(); 285 } 286 287 template < 288 typename NativeT, 289 typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr> 290 Status HandleExpm1(HloInstruction* expm1) { 291 return UnsupportedTypeError(expm1); 292 } 293 294 Status HandleExpm1(HloInstruction* floor) override { 295 return HandleExpm1<ReturnT>(floor); 296 } 297 298 template < 299 typename NativeT, 300 typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr> 301 Status HandleFloor(HloInstruction* floor) { 302 
TF_ASSIGN_OR_RETURN( 303 parent_->evaluated_[floor], 304 ElementWiseUnaryOp(floor, [](ElementwiseT elem_operand) { 305 return std::floor(elem_operand); 306 })); 307 return Status::OK(); 308 } 309 310 template < 311 typename NativeT, 312 typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr> 313 Status HandleFloor(HloInstruction* floor) { 314 return UnsupportedTypeError(floor); 315 } 316 317 Status HandleFloor(HloInstruction* floor) override { 318 return HandleFloor<ReturnT>(floor); 319 } 320 321 Status HandleLog(HloInstruction* log) override { 322 TF_ASSIGN_OR_RETURN(parent_->evaluated_[log], 323 ElementWiseUnaryOp(log, [](ElementwiseT elem_operand) { 324 return std::log(elem_operand); 325 })); 326 return Status::OK(); 327 } 328 329 template < 330 typename NativeT, 331 typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr> 332 Status HandleLog1p(HloInstruction* log1p) { 333 TF_ASSIGN_OR_RETURN( 334 parent_->evaluated_[log1p], 335 ElementWiseUnaryOp(log1p, [](ElementwiseT elem_operand) { 336 return std::log1p(elem_operand); 337 })); 338 return Status::OK(); 339 } 340 341 template < 342 typename NativeT, 343 typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr> 344 Status HandleLog1p(HloInstruction* log1p) { 345 return UnsupportedTypeError(log1p); 346 } 347 348 Status HandleLog1p(HloInstruction* log1p) override { 349 return HandleLog1p<ReturnT>(log1p); 350 } 351 352 template <typename NativeT, 353 typename std::enable_if< 354 std::is_integral<NativeT>::value && 355 !std::is_same<NativeT, bool>::value>::type* = nullptr> 356 Status HandleNot(HloInstruction* not_) { 357 TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_], 358 ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) { 359 return ~elem_operand; 360 })); 361 return Status::OK(); 362 } 363 364 template <typename NativeT, typename std::enable_if<std::is_floating_point< 365 NativeT>::value>::type* = nullptr> 366 Status HandleNot(HloInstruction* not_) { 367 TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_], 368 ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) { 369 return !elem_operand; 370 })); 371 return Status::OK(); 372 } 373 374 template <typename NativeT, 375 typename std::enable_if<std::is_same<NativeT, bool>::value>::type* = 376 nullptr> 377 Status HandleNot(HloInstruction* not_) { 378 TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_], 379 ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) { 380 return !elem_operand; 381 })); 382 return Status::OK(); 383 } 384 385 template < 386 typename NativeT, 387 typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr> 388 Status HandleNot(HloInstruction* not_) { 389 return UnsupportedTypeError(not_); 390 } 391 392 Status HandleNot(HloInstruction* not_) override { 393 return HandleNot<ElementwiseT>(not_); 394 } 395 396 template <typename NativeT, 397 typename std::enable_if< 398 std::is_signed<NativeT>::value && 399 !std::is_floating_point<NativeT>::value>::type* = nullptr> 400 Status HandleNegate(HloInstruction* negate) { 401 using type = typename std::make_unsigned<NativeT>::type; 402 TF_ASSIGN_OR_RETURN( 403 parent_->evaluated_[negate], 404 ElementWiseUnaryOp(negate, [](ElementwiseT elem_operand) { 405 return NativeT(-type(elem_operand)); 406 })); 407 return Status::OK(); 408 } 409 410 template <typename NativeT, 411 typename std::enable_if< 412 !std::is_signed<NativeT>::value || 413 std::is_floating_point<NativeT>::value>::type* = nullptr> 414 Status HandleNegate(HloInstruction* negate) { 415 TF_ASSIGN_OR_RETURN( 416 
parent_->evaluated_[negate], 417 ElementWiseUnaryOp( 418 negate, [](ElementwiseT elem_operand) { return -elem_operand; })); 419 return Status::OK(); 420 } 421 422 Status HandleNegate(HloInstruction* negate) override { 423 return HandleNegate<ReturnT>(negate); 424 } 425 426 template <typename NativeT, 427 typename std::enable_if<std::is_integral<NativeT>::value>::type* = 428 nullptr> 429 Status HandleSign(HloInstruction* sign) { 430 TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign], 431 ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) { 432 return (ElementwiseT(0) < elem_operand) - 433 (elem_operand < ElementwiseT(0)); 434 })); 435 return Status::OK(); 436 } 437 438 template <typename NativeT, 439 typename std::enable_if< 440 std::is_same<NativeT, bfloat16>::value || 441 std::is_same<NativeT, Eigen::half>::value || 442 std::is_floating_point<NativeT>::value>::type* = nullptr> 443 Status HandleSign(HloInstruction* sign) { 444 TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign], 445 ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) { 446 return std::isnan(elem_operand) 447 ? elem_operand 448 : std::copysign( 449 elem_operand != ElementwiseT(0), 450 elem_operand); 451 })); 452 return Status::OK(); 453 } 454 455 template < 456 typename NativeT, 457 typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr> 458 Status HandleSign(HloInstruction* sign) { 459 TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign], 460 ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) { 461 auto abs_val = std::abs(elem_operand); 462 return 0 == abs_val ? ElementwiseT(0) 463 : elem_operand / abs_val; 464 })); 465 return Status::OK(); 466 } 467 468 Status HandleSign(HloInstruction* sign) override { 469 return HandleSign<ReturnT>(sign); 470 } 471 472 template <typename NativeT, typename std::enable_if<std::is_floating_point< 473 NativeT>::value>::type* = nullptr> 474 Status HandleAtan2(HloInstruction* atan2) { 475 TF_ASSIGN_OR_RETURN(parent_->evaluated_[atan2], 476 ElementWiseBinaryOp(atan2, [](ElementwiseT lhs_elem, 477 ElementwiseT rhs_elem) { 478 return std::atan2(lhs_elem, rhs_elem); 479 })); 480 return Status::OK(); 481 } 482 483 template <typename NativeT, typename std::enable_if<!std::is_floating_point< 484 NativeT>::value>::type* = nullptr> 485 Status HandleAtan2(HloInstruction* atan2) { 486 return UnsupportedTypeError(atan2); 487 } 488 489 Status HandleAtan2(HloInstruction* atan2) override { 490 return HandleAtan2<ElementwiseT>(atan2); 491 } 492 493 Status HandleTanh(HloInstruction* tanh) override { 494 TF_ASSIGN_OR_RETURN(parent_->evaluated_[tanh], 495 ElementWiseUnaryOp(tanh, [](ElementwiseT elem_operand) { 496 return std::tanh(elem_operand); 497 })); 498 return Status::OK(); 499 } 500 501 Status HandleMultiply(HloInstruction* multiply) override { 502 TF_ASSIGN_OR_RETURN( 503 parent_->evaluated_[multiply], 504 ElementWiseBinaryOp( 505 multiply, [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { 506 return ElementwiseT(ToArithmeticSafeType(lhs_elem) * 507 ToArithmeticSafeType(rhs_elem)); 508 })); 509 return Status::OK(); 510 } 511 512 Status HandleSubtract(HloInstruction* subtract) override { 513 TF_ASSIGN_OR_RETURN( 514 parent_->evaluated_[subtract], 515 ElementWiseBinaryOp( 516 subtract, [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { 517 return ElementwiseT(ToArithmeticSafeType(lhs_elem) - 518 ToArithmeticSafeType(rhs_elem)); 519 })); 520 return Status::OK(); 521 } 522 523 Status HandleAdd(HloInstruction* add) override { 524 TF_ASSIGN_OR_RETURN(parent_->evaluated_[add], 525 
ElementWiseBinaryOp(add, [](ElementwiseT lhs_elem, 526 ElementwiseT rhs_elem) { 527 return ElementwiseT(ToArithmeticSafeType(lhs_elem) + 528 ToArithmeticSafeType(rhs_elem)); 529 })); 530 return Status::OK(); 531 } 532 533 template < 534 typename NativeT, 535 typename std::enable_if<std::is_floating_point<NativeT>::value || 536 is_complex_t<NativeT>::value>::type* = nullptr> 537 Status HandleDivide(HloInstruction* divide) { 538 TF_ASSIGN_OR_RETURN(parent_->evaluated_[divide], 539 ElementWiseBinaryOp(divide, [](ElementwiseT lhs_elem, 540 ElementwiseT rhs_elem) { 541 return lhs_elem / rhs_elem; 542 })); 543 return Status::OK(); 544 } 545 546 template <typename NativeT, 547 typename std::enable_if<std::is_signed<NativeT>::value && 548 std::is_integral<NativeT>::value>::type* = 549 nullptr> 550 Status HandleDivide(HloInstruction* divide) { 551 TF_ASSIGN_OR_RETURN( 552 parent_->evaluated_[divide], 553 ElementWiseBinaryOp( 554 divide, 555 [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) -> ElementwiseT { 556 if (rhs_elem == 0) { 557 return static_cast<ElementwiseT>(-1); 558 } 559 if (rhs_elem == -1 && 560 lhs_elem == std::numeric_limits<ElementwiseT>::min()) { 561 return lhs_elem; 562 } 563 return lhs_elem / rhs_elem; 564 })); 565 return Status::OK(); 566 } 567 568 template <typename NativeT, 569 typename std::enable_if<std::is_unsigned<NativeT>::value>::type* = 570 nullptr> 571 Status HandleDivide(HloInstruction* divide) { 572 TF_ASSIGN_OR_RETURN(parent_->evaluated_[divide], 573 ElementWiseBinaryOp(divide, [](ElementwiseT lhs_elem, 574 ElementwiseT rhs_elem) { 575 return rhs_elem == 0 576 ? std::numeric_limits<ElementwiseT>::max() 577 : (lhs_elem / rhs_elem); 578 })); 579 return Status::OK(); 580 } 581 582 Status HandleDivide(HloInstruction* divide) override { 583 return HandleDivide<ElementwiseT>(divide); 584 } 585 586 template <typename NativeT, 587 typename std::enable_if<std::is_integral<NativeT>::value>::type* = 588 nullptr> 589 Status HandleMaximum(HloInstruction* maximum) { 590 TF_ASSIGN_OR_RETURN( 591 parent_->evaluated_[maximum], 592 ElementWiseBinaryOp(maximum, [](ElementwiseT lhs, ElementwiseT rhs) { 593 return std::max(lhs, rhs); 594 })); 595 return Status::OK(); 596 } 597 598 template <typename NativeT, typename std::enable_if<std::is_floating_point< 599 NativeT>::value>::type* = nullptr> 600 Status HandleMaximum(HloInstruction* maximum) { 601 TF_ASSIGN_OR_RETURN( 602 parent_->evaluated_[maximum], 603 ElementWiseBinaryOp(maximum, [](ElementwiseT lhs, ElementwiseT rhs) { 604 return ((lhs >= rhs) || std::isnan(lhs)) ? 
lhs : rhs; 605 })); 606 return Status::OK(); 607 } 608 609 template < 610 typename NativeT, 611 typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr> 612 Status HandleMaximum(HloInstruction* maximum) { 613 return UnsupportedTypeError(maximum); 614 } 615 616 Status HandleMaximum(HloInstruction* maximum) override { 617 return HandleMaximum<ElementwiseT>(maximum); 618 } 619 620 template <typename NativeT, 621 typename std::enable_if<std::is_integral<NativeT>::value>::type* = 622 nullptr> 623 Status HandleMinimum(HloInstruction* minimum) { 624 TF_ASSIGN_OR_RETURN(parent_->evaluated_[minimum], 625 ElementWiseBinaryOp(minimum, [](ElementwiseT lhs_el, 626 ElementwiseT rhs_el) { 627 return std::min(lhs_el, rhs_el); 628 })); 629 return Status::OK(); 630 } 631 632 template <typename NativeT, typename std::enable_if<std::is_floating_point< 633 NativeT>::value>::type* = nullptr> 634 Status HandleMinimum(HloInstruction* minimum) { 635 TF_ASSIGN_OR_RETURN( 636 parent_->evaluated_[minimum], 637 ElementWiseBinaryOp(minimum, [](ElementwiseT lhs_el, 638 ElementwiseT rhs_el) { 639 return ((lhs_el <= rhs_el) || std::isnan(lhs_el)) ? lhs_el : rhs_el; 640 })); 641 return Status::OK(); 642 } 643 644 template < 645 typename NativeT, 646 typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr> 647 Status HandleMinimum(HloInstruction* minimum) { 648 return UnsupportedTypeError(minimum); 649 } 650 651 Status HandleMinimum(HloInstruction* minimum) override { 652 return HandleMinimum<ElementwiseT>(minimum); 653 } 654 655 Status HandlePower(HloInstruction* power) override { 656 TF_ASSIGN_OR_RETURN( 657 parent_->evaluated_[power], 658 ElementWiseBinaryOp( 659 power, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { 660 return lhs_el == ElementwiseT(0) && rhs_el == ElementwiseT(0) 661 ? static_cast<ElementwiseT>(1) 662 : std::pow(lhs_el, rhs_el); 663 })); 664 return Status::OK(); 665 } 666 667 Status HandleSqrt(HloInstruction* sqrt) override { 668 TF_ASSIGN_OR_RETURN(parent_->evaluated_[sqrt], 669 ElementWiseUnaryOp(sqrt, [](ElementwiseT elem_operand) { 670 return std::sqrt(elem_operand); 671 })); 672 return Status::OK(); 673 } 674 675 Status HandleRsqrt(HloInstruction* rsqrt) override { 676 TF_ASSIGN_OR_RETURN( 677 parent_->evaluated_[rsqrt], 678 ElementWiseUnaryOp(rsqrt, [](ElementwiseT elem_operand) { 679 return static_cast<ElementwiseT>(1) / std::sqrt(elem_operand); 680 })); 681 return Status::OK(); 682 } 683 684 template <typename NativeT, typename std::enable_if<std::is_floating_point< 685 NativeT>::value>::type* = nullptr> 686 Status HandleRemainder(HloInstruction* remainder) { 687 TF_ASSIGN_OR_RETURN(parent_->evaluated_[remainder], 688 ElementWiseBinaryOp(remainder, [](ElementwiseT lhs_el, 689 ElementwiseT rhs_el) { 690 return std::fmod(lhs_el, rhs_el); 691 })); 692 return Status::OK(); 693 } 694 695 template <typename NativeT, 696 typename std::enable_if<std::is_unsigned<NativeT>::value>::type* = 697 nullptr> 698 Status HandleRemainder(HloInstruction* remainder) { 699 TF_ASSIGN_OR_RETURN(parent_->evaluated_[remainder], 700 ElementWiseBinaryOp(remainder, [](ElementwiseT lhs_el, 701 ElementwiseT rhs_el) { 702 return rhs_el == 0 ? 
lhs_el : (lhs_el % rhs_el); 703 })); 704 return Status::OK(); 705 } 706 707 template <typename NativeT, 708 typename std::enable_if<std::is_signed<NativeT>::value && 709 std::is_integral<NativeT>::value>::type* = 710 nullptr> 711 Status HandleRemainder(HloInstruction* remainder) { 712 TF_ASSIGN_OR_RETURN( 713 parent_->evaluated_[remainder], 714 ElementWiseBinaryOp( 715 remainder, 716 [](ElementwiseT lhs_el, ElementwiseT rhs_el) -> ElementwiseT { 717 if (rhs_el == 0) { 718 return lhs_el; 719 } 720 if (rhs_el == -1 && 721 lhs_el == std::numeric_limits<ElementwiseT>::min()) { 722 return 0; 723 } 724 return lhs_el % rhs_el; 725 })); 726 return Status::OK(); 727 } 728 729 template < 730 typename NativeT, 731 typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr> 732 Status HandleRemainder(HloInstruction* remainder) { 733 return UnsupportedTypeError(remainder); 734 } 735 736 Status HandleRemainder(HloInstruction* remainder) override { 737 return HandleRemainder<ElementwiseT>(remainder); 738 } 739 740 template <typename NativeT, 741 typename std::enable_if<std::is_integral<NativeT>::value>::type* = 742 nullptr> 743 Status HandleAnd(HloInstruction* and_) { 744 TF_ASSIGN_OR_RETURN( 745 parent_->evaluated_[and_], 746 ElementWiseBinaryOp(and_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { 747 return lhs_el & rhs_el; 748 })); 749 return Status::OK(); 750 } 751 752 template <typename NativeT, typename std::enable_if<std::is_floating_point< 753 NativeT>::value>::type* = nullptr> 754 Status HandleAnd(HloInstruction* and_) { 755 return UnsupportedTypeError(and_); 756 } 757 758 template < 759 typename NativeT, 760 typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr> 761 Status HandleAnd(HloInstruction* and_) { 762 return UnsupportedTypeError(and_); 763 } 764 765 Status HandleAnd(HloInstruction* and_) override { 766 return HandleAnd<ElementwiseT>(and_); 767 } 768 769 template <typename NativeT, 770 typename std::enable_if<std::is_integral<NativeT>::value>::type* = 771 nullptr> 772 Status HandleOr(HloInstruction* or_) { 773 TF_ASSIGN_OR_RETURN( 774 parent_->evaluated_[or_], 775 ElementWiseBinaryOp(or_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { 776 return lhs_el | rhs_el; 777 })); 778 return Status::OK(); 779 } 780 781 template <typename NativeT, typename std::enable_if<std::is_floating_point< 782 NativeT>::value>::type* = nullptr> 783 Status HandleOr(HloInstruction* or_) { 784 return UnsupportedTypeError(or_); 785 } 786 787 template < 788 typename NativeT, 789 typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr> 790 Status HandleOr(HloInstruction* or_) { 791 return InvalidArgument("Unsupported type for Or"); 792 } 793 794 Status HandleOr(HloInstruction* or_) override { 795 return HandleOr<ElementwiseT>(or_); 796 } 797 798 template <typename NativeT, 799 typename std::enable_if<std::is_integral<NativeT>::value>::type* = 800 nullptr> 801 Status HandleXor(HloInstruction* xor_) { 802 TF_ASSIGN_OR_RETURN( 803 parent_->evaluated_[xor_], 804 ElementWiseBinaryOp(xor_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { 805 return lhs_el ^ rhs_el; 806 })); 807 return Status::OK(); 808 } 809 810 template <typename NativeT, typename std::enable_if<std::is_floating_point< 811 NativeT>::value>::type* = nullptr> 812 Status HandleXor(HloInstruction* xor_) { 813 return UnsupportedTypeError(xor_); 814 } 815 816 template < 817 typename NativeT, 818 typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr> 819 Status HandleXor(HloInstruction* xor_) { 820 
return UnsupportedTypeError(xor_); 821 } 822 823 Status HandleXor(HloInstruction* xor_) override { 824 return HandleXor<ElementwiseT>(xor_); 825 } 826 827 template <typename NativeT, 828 typename std::enable_if< 829 std::is_integral<NativeT>::value && 830 !std::is_same<NativeT, bool>::value>::type* = nullptr> 831 Status HandleShiftLeft(HloInstruction* shl) { 832 TF_ASSIGN_OR_RETURN( 833 parent_->evaluated_[shl], 834 ElementWiseBinaryOp(shl, [](NativeT lhs_elem, NativeT rhs_elem) { 835 return IsShiftOutOfBounds<NativeT>(rhs_elem) ? 0 836 : (lhs_elem << rhs_elem); 837 })); 838 return Status::OK(); 839 } 840 841 template <typename NativeT, 842 typename std::enable_if<!std::is_integral<NativeT>::value || 843 std::is_same<NativeT, bool>::value>::type* = 844 nullptr> 845 Status HandleShiftLeft(HloInstruction* shift) { 846 return UnsupportedTypeError(shift); 847 } 848 849 Status HandleShiftLeft(HloInstruction* shl) override { 850 return HandleShiftLeft<ElementwiseT>(shl); 851 } 852 template <typename NativeT, 853 typename std::enable_if< 854 std::is_integral<NativeT>::value && 855 !std::is_same<NativeT, bool>::value>::type* = nullptr> 856 Status HandleShiftRightArithmetic(HloInstruction* shr) { 857 typedef typename std::make_signed<NativeT>::type SignedT; 858 TF_ASSIGN_OR_RETURN( 859 parent_->evaluated_[shr], 860 ElementWiseBinaryOp(shr, [](NativeT lhs_elem, NativeT rhs_elem) { 861 SignedT lhs_signed = static_cast<SignedT>(lhs_elem); 862 if (IsShiftOutOfBounds<NativeT>(rhs_elem)) { 863 return lhs_signed < 0 ? static_cast<SignedT>(-1) : 0; 864 } else { 865 return lhs_signed >> rhs_elem; 866 } 867 })); 868 return Status::OK(); 869 } 870 871 template <typename NativeT, 872 typename std::enable_if<!std::is_integral<NativeT>::value || 873 std::is_same<NativeT, bool>::value>::type* = 874 nullptr> 875 Status HandleShiftRightArithmetic(HloInstruction* shift) { 876 return UnsupportedTypeError(shift); 877 } 878 879 Status HandleShiftRightArithmetic(HloInstruction* shra) override { 880 return HandleShiftRightArithmetic<ElementwiseT>(shra); 881 } 882 883 template <typename NativeT, 884 typename std::enable_if< 885 std::is_integral<NativeT>::value && 886 !std::is_same<NativeT, bool>::value>::type* = nullptr> 887 Status HandleShiftRightLogical(HloInstruction* shr) { 888 typedef typename std::make_unsigned<NativeT>::type UnsignedT; 889 TF_ASSIGN_OR_RETURN( 890 parent_->evaluated_[shr], 891 ElementWiseBinaryOp(shr, [](NativeT lhs_elem, NativeT rhs_elem) { 892 // If shift amount is greater than the number of bits, then return 0. 893 if (IsShiftOutOfBounds<NativeT>(rhs_elem)) { 894 return static_cast<NativeT>(0); 895 } 896 return static_cast<NativeT>(static_cast<UnsignedT>(lhs_elem) >> 897 rhs_elem); 898 })); 899 return Status::OK(); 900 } 901 902 template <typename NativeT, 903 typename std::enable_if<!std::is_integral<NativeT>::value || 904 std::is_same<NativeT, bool>::value>::type* = 905 nullptr> 906 Status HandleShiftRightLogical(HloInstruction* shift) { 907 return UnsupportedTypeError(shift); 908 } 909 910 Status HandleShiftRightLogical(HloInstruction* shrl) override { 911 return HandleShiftRightLogical<ElementwiseT>(shrl); 912 } 913 914 // Special case for integral type due to MSVC's std::isnan being unable to 915 // handle integral type. 
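  // For integral ElementwiseT, the clamp below reduces to a plain min/max of
  // the three operands; the floating-point overload that follows additionally
  // returns NaN whenever the low or high bound is NaN.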
916 template <typename NativeT, 917 typename std::enable_if<!is_complex_t<NativeT>::value && 918 std::is_integral<NativeT>::value>::type* = 919 nullptr> 920 Status HandleClamp(HloInstruction* clamp) { 921 std::function<ElementwiseT(ElementwiseT, ElementwiseT, ElementwiseT)> 922 clamp_op = [](ElementwiseT low, ElementwiseT value, ElementwiseT high) { 923 return static_cast<ElementwiseT>( 924 std::min(high, std::max(value, low))); 925 }; 926 TF_ASSIGN_OR_RETURN( 927 parent_->evaluated_[clamp], 928 ElementwiseTernaryOp(clamp, 929 std::move(ConvertTernaryFunction(clamp_op)))); 930 return Status::OK(); 931 } 932 933 template <typename NativeT, 934 typename std::enable_if<!is_complex_t<NativeT>::value && 935 !std::is_integral<NativeT>::value>::type* = 936 nullptr> 937 Status HandleClamp(HloInstruction* clamp) { 938 std::function<ElementwiseT(ElementwiseT, ElementwiseT, ElementwiseT)> 939 clamp_op = [](ElementwiseT low, ElementwiseT value, ElementwiseT high) { 940 if (std::isnan(low) || std::isnan(high)) { 941 return static_cast<ElementwiseT>(NAN); 942 } 943 return static_cast<ElementwiseT>( 944 std::min<NativeT>(high, std::max<NativeT>(value, low))); 945 }; 946 TF_ASSIGN_OR_RETURN( 947 parent_->evaluated_[clamp], 948 ElementwiseTernaryOp(clamp, 949 std::move(ConvertTernaryFunction(clamp_op)))); 950 return Status::OK(); 951 } 952 953 template < 954 typename NativeT, 955 typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr> 956 Status HandleClamp(HloInstruction* clamp) { 957 return UnsupportedTypeError(clamp); 958 } 959 960 Status HandleClamp(HloInstruction* clamp) override { 961 return HandleClamp<ElementwiseT>(clamp); 962 } 963 964 Status HandleSelect(HloInstruction* select) override { 965 CHECK(!ShapeUtil::IsScalar(select->operand(0)->shape())); 966 CHECK(select->shape().IsArray()); 967 std::function<ReturnT(bool, ReturnT, ReturnT)> select_op = 968 [](bool pred, ReturnT on_true, ReturnT on_false) { 969 if (pred) { 970 return on_true; 971 } 972 return on_false; 973 }; 974 TF_ASSIGN_OR_RETURN(parent_->evaluated_[select], 975 ElementwiseTernaryOp(select, std::move(select_op))); 976 return Status::OK(); 977 } 978 979 Status HandleReverse(HloInstruction* reverse) override { 980 const auto result_shape = reverse->shape(); 981 const auto reverse_dimensions = reverse->dimensions(); 982 983 auto operand = reverse->operand(0); 984 TF_ASSIGN_OR_RETURN(auto inferred_return_shape, 985 ShapeInference::InferReverseShape(operand->shape(), 986 reverse_dimensions)); 987 988 TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) 989 << "return shape set to: " << ShapeUtil::HumanString(result_shape) 990 << " but is inferred to be: " 991 << ShapeUtil::HumanString(inferred_return_shape); 992 993 const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); 994 Literal result(result_shape); 995 996 TF_RETURN_IF_ERROR( 997 result.Populate<ReturnT>([&](absl::Span<const int64> out_index) { 998 std::vector<int64> from_index(out_index.begin(), out_index.end()); 999 for (const int64 dim : reverse_dimensions) { 1000 from_index[dim] = result_shape.dimensions(dim) - 1 - out_index[dim]; 1001 } 1002 return operand_literal.Get<ReturnT>(from_index); 1003 })); 1004 1005 parent_->evaluated_[reverse] = std::move(result); 1006 return Status::OK(); 1007 } 1008 1009 Status HandleConvolution(HloInstruction* conv) override { 1010 auto lhs = conv->operand(0); 1011 auto rhs = conv->operand(1); 1012 const auto& window = conv->window(); 1013 const Shape& result_shape = conv->shape(); 1014 
const Shape& lhs_shape = lhs->shape(); 1015 const Shape& rhs_shape = rhs->shape(); 1016 1017 TF_CHECK_OK(ShapeUtil::ValidateShape(lhs_shape)); 1018 TF_CHECK_OK(ShapeUtil::ValidateShape(rhs_shape)); 1019 CHECK(lhs_shape.IsArray()); 1020 CHECK(rhs_shape.IsArray()); 1021 CHECK(ShapeUtil::SameElementType(lhs_shape, rhs_shape)); 1022 CHECK(ShapeUtil::SameElementType(lhs_shape, result_shape)); 1023 1024 const auto& dnums = conv->convolution_dimension_numbers(); 1025 const int64 num_spatial_dims = dnums.output_spatial_dimensions_size(); 1026 CHECK_EQ(num_spatial_dims, dnums.input_spatial_dimensions_size()); 1027 CHECK_EQ(num_spatial_dims, dnums.kernel_spatial_dimensions_size()); 1028 CHECK_GE(num_spatial_dims, 0); 1029 CHECK_EQ(window.dimensions_size(), num_spatial_dims); 1030 1031 const auto lhs_rank = lhs_shape.rank(); 1032 const auto rhs_rank = rhs_shape.rank(); 1033 1034 CHECK_EQ(num_spatial_dims + 2, lhs_rank); 1035 CHECK_EQ(num_spatial_dims + 2, rhs_rank); 1036 1037 TF_ASSIGN_OR_RETURN(auto inferred_return_shape, 1038 ShapeInference::InferConvolveShape( 1039 lhs_shape, rhs_shape, conv->feature_group_count(), 1040 conv->batch_group_count(), window, dnums)); 1041 CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) 1042 << "return shape set to: " << ShapeUtil::HumanString(result_shape) 1043 << " but is inferred to be: " 1044 << ShapeUtil::HumanString(inferred_return_shape); 1045 1046 const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); 1047 const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); 1048 1049 std::vector<int64> window_dimension_sizes; 1050 for (auto i : dnums.kernel_spatial_dimensions()) { 1051 window_dimension_sizes.push_back(ShapeUtil::GetDimension(rhs_shape, i)); 1052 } 1053 1054 const Shape& window_shape = 1055 ShapeUtil::MakeShape(rhs_shape.element_type(), window_dimension_sizes); 1056 1057 DimensionVector lhs_dim_multipliers = MakeDimMultipliers(lhs_shape); 1058 DimensionVector rhs_dim_multipliers = MakeDimMultipliers(rhs_shape); 1059 1060 auto lhs_literal_data = lhs_literal.data<ReturnT>(); 1061 auto rhs_literal_data = rhs_literal.data<ReturnT>(); 1062 1063 const int64 feature_group_count = conv->feature_group_count(); 1064 const int64 batch_group_count = conv->batch_group_count(); 1065 1066 auto func = [&window_shape, &dnums, &lhs_shape, &rhs_shape, &window, 1067 &lhs_dim_multipliers, &rhs_dim_multipliers, lhs_literal_data, 1068 rhs_literal_data, feature_group_count, 1069 batch_group_count](const absl::Span<const int64> out_index) { 1070 // Dimension number applicable for input (lhs). 1071 const int64 input_batch_dim = dnums.input_batch_dimension(); 1072 const int64 input_z_dim = dnums.input_feature_dimension(); 1073 // Dimension number applicable for kernel (rhs). 1074 const int64 kernel_input_z_dim = dnums.kernel_input_feature_dimension(); 1075 const int64 kernel_output_z_dim = dnums.kernel_output_feature_dimension(); 1076 // Dimension number applicable for output. 1077 const int64 output_batch_dim = dnums.output_batch_dimension(); 1078 const int64 output_z_dim = dnums.output_feature_dimension(); 1079 1080 const int64 input_z_size = 1081 ShapeUtil::GetDimension(lhs_shape, input_z_dim); 1082 1083 const int64 input_batch_size = 1084 ShapeUtil::GetDimension(lhs_shape, input_batch_dim); 1085 1086 const int64 batch_group_size = input_batch_size / batch_group_count; 1087 1088 // The size of an input feature group. 
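      // For example, with feature_group_count == 2 and 8 input feature
      // channels, each group convolves its own 4 input channels and produces
      // its own contiguous slice of the output features.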
1089 const int64 input_feature_group_size = input_z_size / feature_group_count; 1090 1091 const int64 output_z_size = 1092 ShapeUtil::GetDimension(rhs_shape, kernel_output_z_dim); 1093 // The output feature dimension is a concatenation of convolution results 1094 // from the different groups. 1095 const int64 output_feature_group_size = 1096 output_z_size / feature_group_count; 1097 1098 // Calculate the group index to which the current output index 1099 // belongs. 1100 const int64 feature_group_index = 1101 out_index[output_z_dim] / output_feature_group_size; 1102 1103 const int64 batch_group_index = out_index[output_z_dim]; 1104 1105 ElementwiseT result_val = static_cast<ElementwiseT>(0); 1106 DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size(), 1107 0); 1108 1109 // Convolve input feature with kernel. 1110 // The mechanism indexes into the correct LHS (input) and RHS (kernel) 1111 // locations and accumulates multiplications for a given output index. 1112 do { 1113 // Find corresponding spatial dimension index for input (lhs). 1114 int64 lhs_linear_spatial_index = 0; 1115 int64 rhs_linear_spatial_index = 0; 1116 for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) { 1117 // Spatial dimension number for input (lhs) and output. 1118 const int64 input_spatial_dim = dnums.input_spatial_dimensions(ki); 1119 const int64 output_spatial_dim = dnums.output_spatial_dimensions(ki); 1120 1121 // Calculate lhs (input) index without taking base dilation into 1122 // account. 1123 const auto& window_dim = window.dimensions(ki); 1124 const int64 undilated_index = 1125 out_index[output_spatial_dim] * window_dim.stride() - 1126 window_dim.padding_low() + 1127 rhs_spatial_index[ki] * window_dim.window_dilation(); 1128 // Skip if the lhs (input) index is to be dilated. As an 1129 // optimization, skip this mod if there's no dilation. 1130 if (window_dim.base_dilation() > 1 && 1131 undilated_index % window_dim.base_dilation() != 0) { 1132 goto cnt; 1133 } 1134 1135 // Calculate the actual lhs (input) index after dilation. As an 1136 // optimization, skip this integer divide if there's no dilation. 1137 int64 lhs_spatial_index; 1138 if (window_dim.base_dilation() > 1) { 1139 lhs_spatial_index = undilated_index / window_dim.base_dilation(); 1140 } else { 1141 lhs_spatial_index = undilated_index; 1142 } 1143 1144 // Skip if input index is not in bounds. 1145 if (!(lhs_spatial_index >= 0 && 1146 lhs_spatial_index < lhs_shape.dimensions(input_spatial_dim))) { 1147 goto cnt; 1148 } 1149 1150 lhs_linear_spatial_index += 1151 lhs_spatial_index * lhs_dim_multipliers[input_spatial_dim]; 1152 rhs_linear_spatial_index += 1153 (window_dim.window_reversal() 1154 ? ((window_dim.size() - 1) - rhs_spatial_index[ki]) 1155 : rhs_spatial_index[ki]) * 1156 rhs_dim_multipliers[dnums.kernel_spatial_dimensions(ki)]; 1157 } 1158 1159 for (int64 rhs_iz = 0; rhs_iz < input_feature_group_size; ++rhs_iz) { 1160 const int64 iz = 1161 feature_group_index * input_feature_group_size + rhs_iz; 1162 1163 int64 lhs_linear_index = lhs_linear_spatial_index; 1164 1165 lhs_linear_index += out_index[output_batch_dim] * 1166 lhs_dim_multipliers[input_batch_dim]; 1167 1168 // We are scraping only the diagonal elements in the resultant 1169 // convolution output when batch_group_count is greater than 1, 1170 // where 1 is the default. No scraping is done in that case. 
1171 // This approach works out automatically for 'groups' in batches 1172 // with group_size > 1, because we already descend down the batch 1173 // dimension for the 'output_batch_dim' above. 1174 lhs_linear_index += 1175 ((batch_group_index * batch_group_size) % input_batch_size) * 1176 lhs_dim_multipliers[input_batch_dim]; 1177 1178 lhs_linear_index += iz * lhs_dim_multipliers[input_z_dim]; 1179 1180 int64 rhs_linear_index = rhs_linear_spatial_index; 1181 1182 rhs_linear_index += out_index[output_z_dim] * 1183 rhs_dim_multipliers[kernel_output_z_dim]; 1184 rhs_linear_index += rhs_iz * rhs_dim_multipliers[kernel_input_z_dim]; 1185 1186 result_val += 1187 static_cast<ElementwiseT>(lhs_literal_data[lhs_linear_index]) * 1188 static_cast<ElementwiseT>(rhs_literal_data[rhs_linear_index]); 1189 } 1190 cnt : {} 1191 } while (IndexUtil::BumpIndices(window_shape, 1192 absl::MakeSpan(rhs_spatial_index))); 1193 1194 return static_cast<ReturnT>(result_val); 1195 }; 1196 1197 Literal result(result_shape); 1198 TF_RETURN_IF_ERROR(result.PopulateParallel<ReturnT>(func)); 1199 1200 parent_->evaluated_[conv] = std::move(result); 1201 return Status::OK(); 1202 } 1203 1204 Status HandleDot(HloInstruction* dot) override { 1205 if (dot->dot_dimension_numbers().rhs_contracting_dimensions_size() == 1 && 1206 parent_->use_fast_path_) { 1207 return HandleDot<ReturnT>(dot); 1208 } 1209 return HandleDotSlowPath(dot); 1210 } 1211 1212 template <typename NativeT, typename std::enable_if<std::is_same< 1213 NativeT, float>::value>::type* = nullptr> 1214 Status HandleDot(HloInstruction* dot) { 1215 const HloInstruction* lhs = dot->operand(0); 1216 const HloInstruction* rhs = dot->operand(1); 1217 CHECK(dot->shape().IsArray()); 1218 CHECK(lhs->shape().IsArray()); 1219 CHECK(rhs->shape().IsArray()); 1220 1221 const auto& dnums = dot->dot_dimension_numbers(); 1222 1223 const int64 lhs_rank = lhs->shape().rank(); 1224 const int64 rhs_rank = rhs->shape().rank(); 1225 1226 CHECK(ShapeUtil::SameElementType(lhs->shape(), rhs->shape())); 1227 CHECK(ShapeUtil::SameElementType(lhs->shape(), dot->shape())); 1228 1229 // There must be 1 and only 1 Contracting dimension for lhs and rhs. 1230 const int64 lhs_contracting_dimension = dnums.lhs_contracting_dimensions(0); 1231 const int64 rhs_contracting_dimension = dnums.rhs_contracting_dimensions(0); 1232 // Contracted dimension sizes must be the same. 1233 CHECK_EQ(lhs->shape().dimensions(lhs_contracting_dimension), 1234 rhs->shape().dimensions(rhs_contracting_dimension)) 1235 << "lhs contracted dimension: " 1236 << lhs->shape().dimensions(lhs_contracting_dimension) 1237 << " rhs contracted dimension: " 1238 << rhs->shape().dimensions(rhs_contracting_dimension); 1239 1240 // The fast path is for a simple rank 2 dot with default layout operands. 
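    // That is, a plain [m, k] x [k, n] matmul with row-major (default)
    // layouts; anything with batch dimensions, multiple contracting
    // dimensions, or non-default layouts falls through to HandleDotSlowPath.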
1241 if (lhs_rank == 2 && rhs_rank == 2 && lhs_contracting_dimension == 1 && 1242 rhs_contracting_dimension == 0 && 1243 LayoutUtil::Equal(lhs->shape().layout(), 1244 LayoutUtil::GetDefaultLayoutForR2()) && 1245 LayoutUtil::Equal(rhs->shape().layout(), 1246 LayoutUtil::GetDefaultLayoutForR2()) && 1247 LayoutUtil::Equal(dot->shape().layout(), 1248 LayoutUtil::GetDefaultLayoutForR2())) { 1249 const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); 1250 const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); 1251 const int64 contracted_dimension_size = 1252 lhs->shape().dimensions(lhs_contracting_dimension); 1253 Array2D<NativeT> lhs_array(lhs->shape().dimensions(0), 1254 contracted_dimension_size); 1255 lhs_array.SetValues(lhs_literal.data<NativeT>()); 1256 Array2D<NativeT> rhs_array(contracted_dimension_size, 1257 rhs->shape().dimensions(1)); 1258 rhs_array.SetValues(rhs_literal.data<NativeT>()); 1259 std::unique_ptr<Array2D<NativeT>> result_array = 1260 HloEvaluator::MatmulArray2D(lhs_array, rhs_array); 1261 Literal result(dot->shape()); 1262 result.PopulateR2FromArray2D(*result_array); 1263 parent_->evaluated_[dot] = std::move(result); 1264 return Status::OK(); 1265 } 1266 return HandleDotSlowPath(dot); 1267 } 1268 1269 template <typename NativeT, typename std::enable_if<!std::is_same< 1270 NativeT, float>::value>::type* = nullptr> 1271 Status HandleDot(HloInstruction* dot) { 1272 return HandleDotSlowPath(dot); 1273 } 1274 1275 Status HandleDotSlowPath(HloInstruction* dot) { 1276 auto lhs = dot->operand(0); 1277 auto rhs = dot->operand(1); 1278 CHECK(dot->shape().IsArray()); 1279 CHECK(lhs->shape().IsArray()); 1280 CHECK(rhs->shape().IsArray()); 1281 1282 const auto& dnums = dot->dot_dimension_numbers(); 1283 1284 const auto lhs_rank = lhs->shape().rank(); 1285 const auto rhs_rank = rhs->shape().rank(); 1286 1287 CHECK(ShapeUtil::SameElementType(lhs->shape(), rhs->shape())); 1288 CHECK(ShapeUtil::SameElementType(lhs->shape(), dot->shape())); 1289 1290 const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); 1291 const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); 1292 1293 CHECK_EQ(dnums.lhs_batch_dimensions_size(), 1294 dnums.rhs_batch_dimensions_size()); 1295 1296 DimensionVector lhs_index(lhs_rank); 1297 DimensionVector rhs_index(rhs_rank); 1298 1299 // result_index_locations[i] contains one or two pointers to the locations 1300 // in lhs_index or rhs_index where the i'th result index should go. 
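    // For example, for a batched matmul [b, m, k] x [b, k, n] -> [b, m, n],
    // entry 0 pairs the lhs and rhs batch dimensions, entry 1 points only at
    // the lhs 'm' dimension, and entry 2 only at the rhs 'n' dimension.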
    absl::InlinedVector<std::pair<int64*, int64*>, kInlineRank>
        result_index_locations;
    result_index_locations.reserve(
        (lhs_rank - dnums.lhs_contracting_dimensions_size()) +
        (rhs_rank - dnums.rhs_contracting_dimensions_size()));

    // The first components in the output shape are the LHS and RHS batch
    // dimensions:
    for (int64 i = 0; i < dnums.lhs_batch_dimensions_size(); i++) {
      result_index_locations.push_back(
          {&lhs_index[dnums.lhs_batch_dimensions(i)],
           &rhs_index[dnums.rhs_batch_dimensions(i)]});
    }

    // Then we have the LHS and RHS non-contracting dimensions, if any:
    for (int64 i = 0; i < lhs_rank; i++) {
      if (!absl::c_linear_search(dnums.lhs_contracting_dimensions(), i) &&
          !absl::c_linear_search(dnums.lhs_batch_dimensions(), i)) {
        result_index_locations.push_back({&lhs_index[i], nullptr});
      }
    }
    for (int64 i = 0; i < rhs_rank; i++) {
      if (!absl::c_linear_search(dnums.rhs_contracting_dimensions(), i) &&
          !absl::c_linear_search(dnums.rhs_batch_dimensions(), i)) {
        result_index_locations.push_back({&rhs_index[i], nullptr});
      }
    }

    absl::InlinedVector<int64, kInlineRank> accumulate_index_sizes;
    accumulate_index_sizes.reserve(dnums.lhs_contracting_dimensions_size());
    absl::InlinedVector<std::pair<int64*, int64*>, kInlineRank>
        accumulate_index_locations;
    accumulate_index_locations.reserve(dnums.lhs_contracting_dimensions_size());
    for (int64 i = 0; i < dnums.lhs_contracting_dimensions_size(); ++i) {
      const int64 lhs_dnum = dnums.lhs_contracting_dimensions(i);
      const int64 rhs_dnum = dnums.rhs_contracting_dimensions(i);
      accumulate_index_locations.push_back(
          {&lhs_index[lhs_dnum], &rhs_index[rhs_dnum]});
      const int64 dim_size = lhs->shape().dimensions(lhs_dnum);
      accumulate_index_sizes.push_back(dim_size);
    }
    const int64 total_contraction_size = Product(accumulate_index_sizes);
    Literal result(dot->shape());
    TF_RETURN_IF_ERROR(
        result.Populate<ReturnT>([&](absl::Span<const int64> result_index) {
          ElementwiseT result_val = static_cast<ElementwiseT>(0);

          for (int64 i = 0; i < result_index.size(); i++) {
            *result_index_locations[i].first = result_index[i];
            if (result_index_locations[i].second) {
              *result_index_locations[i].second = result_index[i];
            }
          }

          // Accumulates the resulting product along the contracted dimensions.
          absl::InlinedVector<int64, kInlineRank> accumulate_index(
              accumulate_index_sizes.size(), 0);
          for (int64 k = 0; k < total_contraction_size; k++) {
            for (int64 i = 0; i < accumulate_index_sizes.size(); ++i) {
              *(accumulate_index_locations[i].first) = accumulate_index[i];
              *(accumulate_index_locations[i].second) = accumulate_index[i];
            }

            result_val +=
                static_cast<ElementwiseT>(lhs_literal.Get<ReturnT>(lhs_index)) *
                static_cast<ElementwiseT>(rhs_literal.Get<ReturnT>(rhs_index));

            // If there are no contracting dimensions, accumulate_index_sizes
            // is empty; do not try to count down from -1 to 0 in that case,
            // since it would be an infinite loop.
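            // The loop below advances accumulate_index like an odometer:
            // increment the last contracting index and carry into earlier
            // ones whenever a dimension size is reached.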
            if (!accumulate_index_sizes.empty()) {
              for (int64 i = accumulate_index_sizes.size() - 1; i >= 0; --i) {
                int64 value = ++accumulate_index[i];
                if (value != accumulate_index_sizes[i]) {
                  break;
                }
                accumulate_index[i] = 0;
              }
            }
          }

          return static_cast<ReturnT>(result_val);
        }));

    parent_->evaluated_[dot] = std::move(result);
    return Status::OK();
  }

  Status HandlePad(HloInstruction* pad) override {
    CHECK(pad->operand(0)->shape().IsArray());
    // Padding value must be scalar.
    CHECK(ShapeUtil::IsScalar(pad->operand(1)->shape()));
    CHECK_EQ(pad->operand(0)->shape().rank(),
             pad->padding_config().dimensions_size());

    TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
                        ShapeInference::InferPadShape(
                            /*operand_shape=*/pad->operand(0)->shape(),
                            /*padding_value_shape=*/pad->operand(1)->shape(),
                            /*padding_config=*/pad->padding_config()));
    CHECK(ShapeUtil::Compatible(pad->shape(), inferred_return_shape))
        << "return shape is set to: " << ShapeUtil::HumanString(pad->shape())
        << " but is inferred to be: "
        << ShapeUtil::HumanString(inferred_return_shape);

    // Create a new literal of the padded shape, filled with the padding value.
    ReturnT scalar =
        parent_->GetEvaluatedLiteralFor(pad->operand(1)).Get<ReturnT>({});
    Literal result(pad->shape());
    TF_RETURN_IF_ERROR(result.Populate<ReturnT>(
        [&scalar](absl::Span<const int64> multi_index) { return scalar; }));

    const Literal& evaluated_operand =
        parent_->GetEvaluatedLiteralFor(pad->operand(0));

    std::vector<int64> input_index(evaluated_operand.shape().rank(), 0);
    std::vector<int64> target_index(result.shape().rank(), 0);

    // Loop through each element of the operand, assigning it to the
    // corresponding index of the resulting padded literal.
    const PaddingConfig& pad_config = pad->padding_config();

    auto func = [&](absl::Span<const int64> input_index) {
      for (auto i = 0; i < input_index.size(); ++i) {
        // Interior padding occurs logically before edge padding, so in the
        // case of negative edge padding, elements are removed from the
        // interior-padded operand.
        target_index[i] =
            pad_config.dimensions(i).edge_padding_low() +
            input_index[i] * (pad_config.dimensions(i).interior_padding() + 1);

        // Account for negative low and high padding: skip the assignment if
        // any target index is out of range.
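        // For example, edge_padding_low == -1 maps operand index 0 to target
        // index -1, which is out of range and therefore dropped here.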
1434 if (!(target_index[i] >= 0 && 1435 target_index[i] < pad->shape().dimensions(i))) { 1436 return true; 1437 } 1438 } 1439 result.Set<ReturnT>(target_index, 1440 evaluated_operand.Get<ReturnT>(input_index)); 1441 return true; 1442 }; 1443 1444 std::vector<int64> zero_base(evaluated_operand.shape().dimensions_size(), 1445 0); 1446 std::vector<int64> step(evaluated_operand.shape().dimensions_size(), 1); 1447 1448 ShapeUtil::ForEachIndex( 1449 evaluated_operand.shape(), zero_base, 1450 AsInt64Slice(evaluated_operand.shape().dimensions()), step, func); 1451 1452 parent_->evaluated_[pad] = std::move(result); 1453 return Status::OK(); 1454 } 1455 1456 Status HandleDynamicSlice(HloInstruction* dynamic_slice) override { 1457 auto operand = dynamic_slice->operand(0); 1458 auto start_indices = dynamic_slice->operand(1); 1459 auto result_shape = dynamic_slice->shape(); 1460 TF_ASSIGN_OR_RETURN( 1461 auto inferred_return_shape, 1462 ShapeInference::InferDynamicSliceShape( 1463 operand->shape(), 1464 Cast<HloDynamicSliceInstruction>(dynamic_slice)->index_shapes(), 1465 dynamic_slice->dynamic_slice_sizes())); 1466 TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) 1467 << "return shape is set to: " << ShapeUtil::HumanString(result_shape) 1468 << " but is inferred to be: " 1469 << ShapeUtil::HumanString(inferred_return_shape); 1470 TF_RET_CHECK( 1471 primitive_util::IsIntegralType(start_indices->shape().element_type())); 1472 1473 const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); 1474 1475 switch (start_indices->shape().element_type()) { 1476 case S32: { 1477 TF_ASSIGN_OR_RETURN( 1478 parent_->evaluated_[dynamic_slice], 1479 DynamicSlice<int32>( 1480 operand_literal, 1481 absl::MakeConstSpan(dynamic_slice->operands()).subspan(1), 1482 result_shape)); 1483 } break; 1484 case S64: { 1485 TF_ASSIGN_OR_RETURN( 1486 parent_->evaluated_[dynamic_slice], 1487 DynamicSlice<int64>( 1488 operand_literal, 1489 absl::MakeConstSpan(dynamic_slice->operands()).subspan(1), 1490 result_shape)); 1491 } break; 1492 case U32: { 1493 TF_ASSIGN_OR_RETURN( 1494 parent_->evaluated_[dynamic_slice], 1495 DynamicSlice<uint32>( 1496 operand_literal, 1497 absl::MakeConstSpan(dynamic_slice->operands()).subspan(1), 1498 result_shape)); 1499 } break; 1500 case U64: { 1501 TF_ASSIGN_OR_RETURN( 1502 parent_->evaluated_[dynamic_slice], 1503 DynamicSlice<uint64>( 1504 operand_literal, 1505 absl::MakeConstSpan(dynamic_slice->operands()).subspan(1), 1506 result_shape)); 1507 } break; 1508 default: 1509 LOG(FATAL) << "HandleDynamicSlice: unhandled primitive type for " 1510 "start_indices: " 1511 << PrimitiveType_Name(start_indices->shape().element_type()); 1512 } 1513 1514 return Status::OK(); 1515 } 1516 1517 Status HandleDynamicUpdateSlice( 1518 HloInstruction* dynamic_update_slice) override { 1519 auto operand = dynamic_update_slice->operand(0); 1520 auto update = dynamic_update_slice->operand(1); 1521 auto start_indices = dynamic_update_slice->operand(2); 1522 auto result_shape = dynamic_update_slice->shape(); 1523 TF_ASSIGN_OR_RETURN( 1524 auto inferred_return_shape, 1525 ShapeInference::InferDynamicUpdateSliceShape( 1526 operand->shape(), update->shape(), 1527 Cast<HloDynamicUpdateSliceInstruction>(dynamic_update_slice) 1528 ->index_shapes())); 1529 TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) 1530 << "return shape is set to: " << ShapeUtil::HumanString(result_shape) 1531 << " but is inferred to be: " 1532 << ShapeUtil::HumanString(inferred_return_shape); 1533 
    TF_RET_CHECK(
        primitive_util::IsIntegralType(start_indices->shape().element_type()));
    TF_RET_CHECK(ShapeUtil::Compatible(result_shape, operand->shape()));

    const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
    const Literal& update_literal = parent_->GetEvaluatedLiteralFor(update);

    switch (start_indices->shape().element_type()) {
      case S32: {
        TF_ASSIGN_OR_RETURN(
            parent_->evaluated_[dynamic_update_slice],
            DynamicUpdateSlice<int32>(
                operand_literal, update_literal,
                absl::MakeConstSpan(dynamic_update_slice->operands())
                    .subspan(2)));
      } break;
      case S64: {
        TF_ASSIGN_OR_RETURN(
            parent_->evaluated_[dynamic_update_slice],
            DynamicUpdateSlice<int64>(
                operand_literal, update_literal,
                absl::MakeConstSpan(dynamic_update_slice->operands())
                    .subspan(2)));
      } break;
      case U32: {
        TF_ASSIGN_OR_RETURN(
            parent_->evaluated_[dynamic_update_slice],
            DynamicUpdateSlice<uint32>(
                operand_literal, update_literal,
                absl::MakeConstSpan(dynamic_update_slice->operands())
                    .subspan(2)));
      } break;
      case U64: {
        TF_ASSIGN_OR_RETURN(
            parent_->evaluated_[dynamic_update_slice],
            DynamicUpdateSlice<uint64>(
                operand_literal, update_literal,
                absl::MakeConstSpan(dynamic_update_slice->operands())
                    .subspan(2)));
      } break;
      default:
        LOG(FATAL) << "HandleDynamicUpdateSlice: unhandled primitive type for "
                      "start_indices: "
                   << PrimitiveType_Name(start_indices->shape().element_type());
    }

    return Status::OK();
  }

  template <typename NativeT>
  StatusOr<Literal> MapImpl(HloInstruction* map) {
    auto operands = map->operands();
    HloComputation* computation = map->to_apply();

    Literal result(map->shape());

    HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
    TF_RETURN_IF_ERROR(
        result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
          std::vector<Literal> arg_literals;
          arg_literals.reserve(operands.size());

          // Construct scalar literal parameters to be passed to the map
          // computation.
          for (auto operand : operands) {
            const Literal& arg_literal =
                parent_->GetEvaluatedLiteralFor(operand);

            auto curr_val = arg_literal.Get<NativeT>(multi_index);
            auto curr_val_literal = LiteralUtil::CreateR0<NativeT>(curr_val);

            arg_literals.push_back(std::move(curr_val_literal));
          }

          Literal computed_result =
              embedded_evaluator.Evaluate(*computation, arg_literals)
                  .ConsumeValueOrDie();
          // Clear visit states so that we can use the evaluator again on
          // the same computation.
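          // (Otherwise the embedded evaluator would still treat this
          // computation's instructions as already visited on the next element
          // and hand back the results cached from the previous invocation.)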
1612 embedded_evaluator.ResetVisitStates(); 1613 1614 return computed_result.Get<ReturnT>({}); 1615 })); 1616 return std::move(result); 1617 } 1618 1619 Status HandleMap(HloInstruction* map) override { 1620 switch (map->operand(0)->shape().element_type()) { 1621 case PRED: { 1622 TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<bool>(map)); 1623 break; 1624 } 1625 case U8: { 1626 TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<uint8>(map)); 1627 break; 1628 } 1629 case U32: { 1630 TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<uint32>(map)); 1631 break; 1632 } 1633 case U64: { 1634 TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<uint64>(map)); 1635 break; 1636 } 1637 case S8: { 1638 TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<int8>(map)); 1639 break; 1640 } 1641 case S32: { 1642 TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<int32>(map)); 1643 break; 1644 } 1645 case S64: { 1646 TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<int64>(map)); 1647 break; 1648 } 1649 case F16: { 1650 TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], 1651 MapImpl<Eigen::half>(map)); 1652 break; 1653 } 1654 case F32: { 1655 TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<float>(map)); 1656 break; 1657 } 1658 case F64: { 1659 TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<double>(map)); 1660 break; 1661 } 1662 case C64: { 1663 TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<complex64>(map)); 1664 break; 1665 } 1666 case C128: { 1667 TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<complex128>(map)); 1668 break; 1669 } 1670 default: 1671 LOG(FATAL) << "HandleMap: unhandled primitive type for " 1672 "input operand: " 1673 << PrimitiveType_Name( 1674 map->operand(0)->shape().element_type()); 1675 } 1676 1677 return Status::OK(); 1678 } 1679 1680 Status HandleSort(HloInstruction* sort) override { 1681 return UnsupportedTypeError(sort); 1682 } 1683 1684 Status HandleReduce(HloInstruction* hlo) override { 1685 HloReduceInstruction* reduce = Cast<HloReduceInstruction>(hlo); 1686 int64 num_args = reduce->inputs().size(); 1687 bool has_tuple_output = reduce->shape().IsTuple(); 1688 absl::Span<const int64> dimensions(reduce->dimensions()); 1689 HloComputation* function = reduce->to_apply(); 1690 1691 absl::InlinedVector<const Shape*, 1> operand_shapes; 1692 for (const HloInstruction* operand : reduce->operands()) { 1693 operand_shapes.push_back(&operand->shape()); 1694 } 1695 TF_ASSIGN_OR_RETURN(auto inferred_return_shape, 1696 ShapeInference::InferReduceShape( 1697 operand_shapes, 1698 /*dimensions_to_reduce=*/dimensions, 1699 /*to_apply=*/function->ComputeProgramShape())); 1700 TF_RET_CHECK(ShapeUtil::Compatible(reduce->shape(), inferred_return_shape)) 1701 << "return shape is set to: " << ShapeUtil::HumanString(reduce->shape()) 1702 << " but is inferred to be: " 1703 << ShapeUtil::HumanString(inferred_return_shape); 1704 1705 absl::InlinedVector<const Literal*, 1> arg_literals(num_args); 1706 absl::InlinedVector<const Literal*, 1> init_literals(num_args); 1707 for (int64 i = 0; i < num_args; ++i) { 1708 arg_literals[i] = &parent_->GetEvaluatedLiteralFor(reduce->inputs()[i]); 1709 VLOG(3) << "HandleReduce arg_literal: " << arg_literals[i]->ToString(); 1710 init_literals[i] = 1711 &parent_->GetEvaluatedLiteralFor(reduce->init_values()[i]); 1712 VLOG(3) << "HandleReduce init_literal: " << init_literals[i]->ToString(); 1713 TF_RET_CHECK(ShapeUtil::IsScalar(init_literals[i]->shape())); 1714 } 1715 1716 // All args and results have the same dimensions, so pick 
an arbitrary one. 1717 const Shape& arg_shape = arg_literals[0]->shape(); 1718 const Shape& result_shape = reduce->shape().IsTuple() 1719 ? reduce->shape().tuple_shapes(0) 1720 : reduce->shape(); 1721 const auto arg_dimensions = AsInt64Slice(arg_shape.dimensions()); 1722 std::vector<int64> arg_dim_steps(arg_dimensions.size()); 1723 std::vector<int64> arg_dim_counts(arg_dimensions.size()); 1724 for (const int64 dim : dimensions) { 1725 arg_dim_steps[dim] = 1; 1726 arg_dim_counts[dim] = arg_dimensions[dim]; 1727 } 1728 1729 // Map each dimension in the result to a dimension in arg that isn't 1730 // being reduced. 1731 std::vector<int64> result_to_arg_index; 1732 for (int64 i = 0; i < arg_dimensions.size(); ++i) { 1733 if (arg_dim_steps[i] == 0) { 1734 result_to_arg_index.push_back(i); 1735 } 1736 } 1737 1738 HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); 1739 absl::InlinedVector<Literal, 1> results(num_args); 1740 for (int64 i = 0; i < num_args; ++i) { 1741 results[i] = Literal(result_shape); 1742 } 1743 1744 Status eval_status; 1745 // For each resulting dimension, calculate and assign computed values. 1746 // This is really wasteful when num_args > 1, since we re-run the 1747 // reduction num_args time. The alternative is to teach Populate() about 1748 // tuples, which we should probably do. 1749 absl::InlinedVector<ReturnT, 1> init_scalars(num_args); 1750 for (int i = 0; i < num_args; ++i) { 1751 init_scalars[i] = init_literals[i]->Get<ReturnT>({}); 1752 } 1753 1754 for (int64 input = 0; input < num_args; ++input) { 1755 TF_RETURN_IF_ERROR(results[input].Populate<ReturnT>( 1756 [&](absl::Span<const int64> multi_index) { 1757 if (!eval_status.ok()) { 1758 return init_scalars[input]; 1759 } 1760 absl::InlinedVector<ReturnT, 1> result_values(init_scalars.begin(), 1761 init_scalars.end()); 1762 std::vector<int64> base(arg_dimensions.size()); 1763 for (int64 i = 0; i < multi_index.size(); ++i) { 1764 base[result_to_arg_index[i]] = multi_index[i]; 1765 } 1766 1767 // When the reduction is addition of floats, accumulate in a double 1768 // for better precision. Also, avoid creating Literals for the 1769 // intermediate results; it's much faster. 1770 if (ShapeUtil::ElementIsFloating(init_literals[0]->shape()) && 1771 IsScalarAdd(function)) { 1772 CHECK_EQ(num_args, 1); 1773 double computed_result = 0; 1774 auto func = [&](absl::Span<const int64> input_index) { 1775 computed_result += 1776 GetAsDouble<ReturnT>(*arg_literals[0], input_index); 1777 return true; 1778 }; 1779 ShapeUtil::ForEachIndex(arg_literals[0]->shape(), base, 1780 arg_dim_counts, arg_dim_steps, func); 1781 return static_cast<ReturnT>(computed_result); 1782 } 1783 auto func = 1784 [&](absl::Span<const int64> input_index) -> StatusOr<bool> { 1785 absl::InlinedVector<ReturnT, 1> arg_values(num_args); 1786 for (int64 i = 0; i < num_args; ++i) { 1787 arg_values[i] = arg_literals[i]->Get<ReturnT>(input_index); 1788 } 1789 1790 // Evaluate computation with specified literal operands. 
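  // A sketch of what the generic path below does, assuming for illustration a
  // single-input F32 reduce whose to_apply computation is (acc, x) -> acc + x:
  // the running accumulator(s) and the current element(s) are wrapped into R0
  // literals, the embedded evaluator runs to_apply on them, and the R0
  // result(s) become the new accumulator value(s).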
1791 absl::InlinedVector<Literal, 1> embedded_operands; 1792 for (ReturnT value : result_values) { 1793 embedded_operands.push_back( 1794 LiteralUtil::CreateR0<ReturnT>(value)); 1795 } 1796 for (ReturnT value : arg_values) { 1797 embedded_operands.push_back( 1798 LiteralUtil::CreateR0<ReturnT>(value)); 1799 } 1800 absl::InlinedVector<Literal*, 1> embedded_operands_ptrs( 1801 embedded_operands.size()); 1802 std::transform(embedded_operands.begin(), embedded_operands.end(), 1803 embedded_operands_ptrs.begin(), 1804 [](Literal& literal) { return &literal; }); 1805 1806 TF_ASSIGN_OR_RETURN(Literal computed_result, 1807 embedded_evaluator.Evaluate( 1808 *function, embedded_operands_ptrs)); 1809 // Clear visit states so that we can use the evaluator again on 1810 // the same computation. 1811 embedded_evaluator.ResetVisitStates(); 1812 // Assign computed result to result_val. 1813 if (!has_tuple_output) { 1814 result_values[0] = computed_result.Get<ReturnT>({}); 1815 } else { 1816 for (int64 i = 0; i < num_args; ++i) { 1817 result_values[i] = computed_result.Get<ReturnT>( 1818 /*multi_index=*/{}, /*shape_index=*/{i}); 1819 } 1820 } 1821 return true; 1822 }; 1823 // Computes one element of the result, reducing all dimensions that 1824 // contribute to that element. 1825 eval_status = ShapeUtil::ForEachIndexWithStatus( 1826 arg_shape, base, arg_dim_counts, arg_dim_steps, func); 1827 return result_values[input]; 1828 })); 1829 } 1830 if (!has_tuple_output) { 1831 parent_->evaluated_[reduce] = std::move(results[0]); 1832 } else { 1833 Literal tuple_result(reduce->shape()); 1834 for (int64 i = 0; i < num_args; ++i) { 1835 TF_CHECK_OK(tuple_result.MoveFrom(std::move(results[i]), {i})); 1836 } 1837 parent_->evaluated_[reduce] = std::move(tuple_result); 1838 } 1839 return eval_status; 1840 } 1841 1842 bool IsScalarAdd(HloComputation* computation) { 1843 HloInstruction* instruction = computation->root_instruction(); 1844 if (instruction->opcode() == HloOpcode::kAdd && 1845 computation->num_parameters() == 2) { 1846 const HloInstruction* lhs = instruction->operand(0); 1847 const HloInstruction* rhs = instruction->operand(1); 1848 return lhs->opcode() == HloOpcode::kParameter && 1849 ShapeUtil::IsScalar(lhs->shape()) && 1850 rhs->opcode() == HloOpcode::kParameter && 1851 ShapeUtil::IsScalar(rhs->shape()) && lhs != rhs; 1852 } 1853 return false; 1854 } 1855 1856 Status HandleSelectAndScatter(HloInstruction* select_and_scatter) override { 1857 auto operand = select_and_scatter->operand(0); 1858 auto source = select_and_scatter->operand(1); 1859 const Window& window = select_and_scatter->window(); 1860 1861 const Literal& init_literal = 1862 parent_->GetEvaluatedLiteralFor(select_and_scatter->operand(2)); 1863 TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); 1864 auto init_scalar = init_literal.Get<ReturnT>({}); 1865 1866 Literal result(select_and_scatter->shape()); 1867 1868 // Initialize result array with the init value. 
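  // Positions that no window ever selects keep this init value. For instance,
  // in a hypothetical max-pool-style configuration with 2x2 windows of stride
  // 2, exactly one operand position per window receives a scattered value and
  // the other three stay at init_scalar.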
1869 TF_RETURN_IF_ERROR(result.Populate<ReturnT>( 1870 [&](absl::Span<const int64> output_index) { return init_scalar; })); 1871 1872 std::vector<int64> window_dimension_sizes; 1873 for (const auto& window_dimension : window.dimensions()) { 1874 window_dimension_sizes.push_back(window_dimension.size()); 1875 } 1876 const Shape window_shape = ShapeUtil::MakeShape( 1877 operand->shape().element_type(), window_dimension_sizes); 1878 1879 HloComputation* select = select_and_scatter->select(); 1880 HloComputation* scatter = select_and_scatter->scatter(); 1881 1882 const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); 1883 const Literal& source_literal = parent_->GetEvaluatedLiteralFor(source); 1884 1885 int64 rank = operand_literal.shape().rank(); 1886 1887 HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); 1888 DimensionVector source_index(rank, 0); 1889 1890 // Used in the dual IterateThroughWindow lambdas below. Hoisted to avoid 1891 // dynamic memory allocations. 1892 auto curr_val_literal = LiteralUtil::CreateR0<ReturnT>(ReturnT()); 1893 auto selected_val_literal = LiteralUtil::CreateR0<ReturnT>(ReturnT()); 1894 auto source_literal_scatter = LiteralUtil::CreateR0<ReturnT>(ReturnT()); 1895 auto scattered_literal = LiteralUtil::CreateR0<ReturnT>(ReturnT()); 1896 do { 1897 // For each element in `source`, we place a window in `operand`. For each 1898 // window placement, we iterate inside the window twice: 1899 // 1900 // 1. Find the selected index by applying `select` function to all 1901 // elements. E.g., If the `select` function is GreaterEqual, the first 1902 // iteration through the window finds the biggest value and returns its 1903 // index. 1904 // 1905 // 2. Using the selected index, scatter value from `source` to result. We 1906 // do this by iterating through the window, and compare each index with 1907 // the selected index. 1908 absl::optional<ReturnT> selected_val; 1909 absl::optional<std::vector<int64>> selected_index; 1910 1911 IterateThroughWindow( 1912 window_shape, window, operand_literal.shape(), source_index, 1913 [&](const std::vector<int64>& operand_index) { 1914 auto curr_val = operand_literal.Get<ReturnT>(operand_index); 1915 if (!selected_val) { 1916 selected_val = curr_val; 1917 selected_index = operand_index; 1918 } 1919 curr_val_literal.Set({}, curr_val); 1920 selected_val_literal.Set({}, *selected_val); 1921 Literal computed_result = 1922 embedded_evaluator 1923 .Evaluate(*select, 1924 {&selected_val_literal, &curr_val_literal}) 1925 .ConsumeValueOrDie(); 1926 bool selected = !computed_result.Get<bool>({}); 1927 if (selected) { 1928 selected_val = curr_val; 1929 selected_index = operand_index; 1930 } 1931 embedded_evaluator.ResetVisitStates(); 1932 }); 1933 1934 IterateThroughWindow( 1935 window_shape, window, operand_literal.shape(), source_index, 1936 [&](const std::vector<int64>& operand_index) { 1937 if (std::equal(operand_index.begin(), operand_index.end(), 1938 selected_index->begin())) { 1939 auto source = source_literal.Get<ReturnT>(source_index); 1940 auto scattered = result.Get<ReturnT>(operand_index); 1941 source_literal_scatter.Set({}, source); 1942 scattered_literal.Set({}, scattered); 1943 Literal computed_result = 1944 embedded_evaluator 1945 .Evaluate(*scatter, 1946 {&source_literal_scatter, &scattered_literal}) 1947 .ConsumeValueOrDie(); 1948 result.Set(operand_index, computed_result.Get<ReturnT>({})); 1949 // Clear visit states so that the we can use the evaluator again 1950 // on the same computation. 
1951 embedded_evaluator.ResetVisitStates(); 1952 } 1953 }); 1954 } while ( 1955 IndexUtil::BumpIndices(source->shape(), absl::MakeSpan(source_index))); 1956 1957 parent_->evaluated_[select_and_scatter] = std::move(result); 1958 return Status::OK(); 1959 } 1960 1961 Status HandleReduceWindow(HloInstruction* reduce_window) override { 1962 auto operand = reduce_window->operand(0); 1963 const Window& window = reduce_window->window(); 1964 HloComputation* function = reduce_window->to_apply(); 1965 TF_ASSIGN_OR_RETURN( 1966 auto inferred_return_shape, 1967 ShapeInference::InferReduceWindowShape( 1968 /*operand_shape=*/reduce_window->operand(0)->shape(), 1969 /*init_value=*/reduce_window->operand(1)->shape(), window, 1970 /*to_apply_shape=*/function->ComputeProgramShape())); 1971 TF_RET_CHECK( 1972 ShapeUtil::Compatible(reduce_window->shape(), inferred_return_shape)) 1973 << "return shape is set to: " 1974 << ShapeUtil::HumanStringWithLayout(reduce_window->shape()) 1975 << " but is inferred to be: " 1976 << ShapeUtil::HumanStringWithLayout(inferred_return_shape); 1977 1978 const Literal& operand_literal = 1979 parent_->GetEvaluatedLiteralFor(reduce_window->operand(0)); 1980 VLOG(3) << "HandleReduceWindow arg_literal: " << operand_literal.ToString(); 1981 const Literal& init_literal = 1982 parent_->GetEvaluatedLiteralFor(reduce_window->operand(1)); 1983 VLOG(3) << "HandleReduceWindow init_literal: " << init_literal.ToString(); 1984 TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); 1985 auto init_scalar = init_literal.Get<ReturnT>({}); 1986 1987 // Creates a Shape object from window, for iteration below. 1988 std::vector<int64> window_dimension_sizes; 1989 for (const auto& window_dimension : window.dimensions()) { 1990 window_dimension_sizes.push_back(window_dimension.size()); 1991 } 1992 const Shape window_shape = ShapeUtil::MakeShape( 1993 operand->shape().element_type(), window_dimension_sizes); 1994 1995 DimensionVector window_index(window.dimensions_size()); 1996 DimensionVector operand_index(operand_literal.shape().rank()); 1997 1998 HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); 1999 Literal result(reduce_window->shape()); 2000 // For each resulting dimension, calculate and assign computed value. 2001 TF_RETURN_IF_ERROR( 2002 result.Populate<ReturnT>([&](absl::Span<const int64> output_index) { 2003 ReturnT result_val = init_scalar; 2004 2005 std::fill(window_index.begin(), window_index.end(), 0); 2006 std::fill(operand_index.begin(), operand_index.end(), 0); 2007 2008 IterateThroughWindow( 2009 window_shape, window, operand_literal.shape(), output_index, 2010 [&](const std::vector<int64>& operand_index) { 2011 auto curr_val = operand_literal.Get<ReturnT>(operand_index); 2012 2013 // Evaluate computation with specified literal operands. 2014 const auto curr_val_literal = 2015 LiteralUtil::CreateR0<ReturnT>(curr_val); 2016 const auto result_val_literal = 2017 LiteralUtil::CreateR0<ReturnT>(result_val); 2018 Literal computed_result = 2019 embedded_evaluator 2020 .Evaluate(*function, 2021 {&result_val_literal, &curr_val_literal}) 2022 .ConsumeValueOrDie(); 2023 2024 // Clear visit states so that the we can use the evaluate again 2025 // on the same computation. 
2026 embedded_evaluator.ResetVisitStates(); 2027 2028 result_val = computed_result.Get<ReturnT>({}); 2029 }); 2030 2031 return result_val; 2032 })); 2033 2034 parent_->evaluated_[reduce_window] = std::move(result); 2035 return Status::OK(); 2036 } 2037 2038 // Reshapes the scatter indices input to have a trailing degenerate `1` 2039 // dimension if necessary. Hands over the ownership of the newly created 2040 // literal (if there is one) to `reshaped_indices`. 2041 StatusOr<std::reference_wrapper<const Literal>> ReshapedScatterIndices( 2042 int64 index_vector_dim, const Literal& indices, 2043 Literal* reshaped_indices) { 2044 if (indices.shape().dimensions_size() != index_vector_dim) { 2045 return std::cref(indices); 2046 } 2047 2048 std::vector<int64> new_shape(indices.shape().dimensions().begin(), 2049 indices.shape().dimensions().end()); 2050 new_shape.push_back(1); 2051 TF_ASSIGN_OR_RETURN(*reshaped_indices, indices.Reshape(new_shape)); 2052 return std::cref(*reshaped_indices); 2053 } 2054 2055 // Returns an ShapeUtil::IndexIterationSpace that iterates over the update 2056 // scatter dimensions while keeping the rest of the update dimensions clamped 2057 // to 0. 2058 ShapeUtil::IndexIterationSpace IterationSpaceForUpdateScatterIndices( 2059 const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) { 2060 int64 updates_rank = updates_shape.dimensions_size(); 2061 std::vector<int64> index_base(updates_rank, 0); 2062 std::vector<int64> index_count(updates_rank, 1); 2063 for (int64 i = 0; i < updates_rank; i++) { 2064 bool is_update_scatter_dim = 2065 !absl::c_binary_search(dim_numbers.update_window_dims(), i); 2066 if (is_update_scatter_dim) { 2067 index_count[i] = updates_shape.dimensions(i); 2068 } 2069 } 2070 return {std::move(index_base), std::move(index_count), 2071 std::vector<int64>(updates_rank, 1)}; 2072 } 2073 2074 // Return an ShapeUtil::IndexIterationSpace that iterates over the update 2075 // window dimensions while keeping the rest of the update dimensions clamped 2076 // to 0. 2077 ShapeUtil::IndexIterationSpace IterationSpaceForUpdateWindowIndices( 2078 const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) { 2079 int64 updates_rank = updates_shape.dimensions_size(); 2080 std::vector<int64> index_base(updates_rank, 0); 2081 std::vector<int64> index_count(updates_rank, 1); 2082 for (int64 i = 0; i < updates_rank; i++) { 2083 bool is_update_window_dim = 2084 absl::c_binary_search(dim_numbers.update_window_dims(), i); 2085 if (is_update_window_dim) { 2086 index_count[i] = updates_shape.dimensions(i); 2087 } 2088 } 2089 return {std::move(index_base), std::move(index_count), 2090 std::vector<int64>(updates_rank, 1)}; 2091 } 2092 2093 // This functor computes the contribution of scatter_indices to an input index 2094 // corresponding to an update index. That is, given an update index I, it 2095 // picks out the scatter indices in I and uses them to look up a scatter 2096 // index, S, from the scatter indices tensor, and expands S into the input 2097 // space according to scatter_dims_to_operand_dims. 2098 // 2099 // This is similar to the class HloEvaluator::OutputGatherIndexToInputIndex 2100 // that does the corresponding function for Gather. 2101 class UpdateScatterIndexToInputIndex { 2102 public: 2103 // The constructor does some setup work that is amortized across all 2104 // iterations. 
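  // A hypothetical illustration (dimension numbers invented for this comment):
  // with a rank-2 operand, rank-2 updates, update_window_dims = {1} and
  // scatter_dims_to_operand_dims = {0}, the constructor computes
  // update_dim_is_scatter_dims_ = {true, false} and
  // input_dim_value_to_index_vector_ = {0, -1}; a fetched index vector of {3}
  // then propagates to input_index_ = {3, 0}.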
2105 explicit UpdateScatterIndexToInputIndex( 2106 const ScatterDimensionNumbers* dim_numbers, const Shape& input_shape, 2107 const Shape& updates_shape, const Literal* scatter_indices) 2108 : dim_numbers_(*dim_numbers), scatter_indices_(*scatter_indices) { 2109 for (int64 i = 0; i < updates_shape.dimensions_size(); i++) { 2110 update_dim_is_scatter_dims_.push_back( 2111 !absl::c_binary_search(dim_numbers_.update_window_dims(), i)); 2112 } 2113 2114 for (int64 i = 0; i < input_shape.dimensions_size(); i++) { 2115 int64 index_of_input_dim_in_index_vector = 2116 FindIndex(dim_numbers_.scatter_dims_to_operand_dims(), i); 2117 if (index_of_input_dim_in_index_vector == 2118 dim_numbers_.scatter_dims_to_operand_dims_size()) { 2119 input_dim_value_to_index_vector_.push_back(-1); 2120 } else { 2121 input_dim_value_to_index_vector_.push_back( 2122 index_of_input_dim_in_index_vector); 2123 } 2124 } 2125 2126 index_vector_index_.resize(scatter_indices_.shape().dimensions_size()); 2127 input_index_.resize(input_shape.dimensions_size()); 2128 int64 index_vector_size = 2129 scatter_indices_.shape().dimensions(dim_numbers_.index_vector_dim()); 2130 index_vector_.resize(index_vector_size); 2131 } 2132 2133 // Returns the contribution of scatter_indices to the input index 2134 // corresponding to update_index. See scatter_inner_loop_body. 2135 // 2136 // This is conceptually a stateless transformation from update_index to the 2137 // scatter input index, but: 2138 // 2139 // - Instead of allocating memory to represent the scatter input index on 2140 // every invocation we reuse the same storage for the result 2141 // (input_index_), mutating it in place. 2142 // - Instead of allocating buffers for temporary values like 2143 // index_vector_index_ and index_vector on every invocation, we reuse the 2144 // same storage for all invocations. 2145 // 2146 // This returns a Span into memory owned by the class. 2147 StatusOr<absl::Span<const int64>> operator()( 2148 absl::Span<const int64> update_index) { 2149 PropagateUpdateIndexScatterDimsToIndexVectorIndex(update_index); 2150 TF_RETURN_IF_ERROR(FetchIndexVector()); 2151 PropagateIndexVectorToInputIndex(); 2152 return absl::Span<const int64>(input_index_); 2153 } 2154 2155 private: 2156 // Propagates the scatter index dimensions from the update index into 2157 // index_vector_index_ by mutating index_vector_index_ in place. Does not 2158 // update the dim_numbers.index_vector_dim() dimension -- that's the 2159 // dimension we iterate over in FetchIndexVector. 2160 void PropagateUpdateIndexScatterDimsToIndexVectorIndex( 2161 absl::Span<const int64> update_index) { 2162 int64 index_vector_index_i = 0; 2163 for (int64 i = 0, e = update_index.size(); i < e; i++) { 2164 if (!update_dim_is_scatter_dims_[i]) { 2165 continue; 2166 } 2167 2168 if (index_vector_index_i == dim_numbers_.index_vector_dim()) { 2169 index_vector_index_i++; 2170 } 2171 2172 index_vector_index_[index_vector_index_i++] = update_index[i]; 2173 } 2174 } 2175 2176 // Populates index_vector_ by iterating over scatter_indices_ according to 2177 // index_vector_index_. 2178 Status FetchIndexVector() { 2179 int64 index_vector_dim = dim_numbers_.index_vector_dim(); 2180 for (int64 i = 0, e = index_vector_.size(); i < e; i++) { 2181 index_vector_index_[index_vector_dim] = i; 2182 TF_ASSIGN_OR_RETURN(index_vector_[i], scatter_indices_.GetIntegralAsS64( 2183 index_vector_index_)); 2184 } 2185 return Status::OK(); 2186 } 2187 2188 // Populates input_index_. 
2189 void PropagateIndexVectorToInputIndex() { 2190 for (int64 i = 0, e = input_index_.size(); i < e; i++) { 2191 if (input_dim_value_to_index_vector_[i] != -1) { 2192 input_index_[i] = index_vector_[input_dim_value_to_index_vector_[i]]; 2193 } 2194 2195 // If input_dim_value_to_index_vector_[i] == -1 then input_index_[i] 2196 // remains 0, as set by the constructor. 2197 } 2198 } 2199 2200 // input_dim_value_to_index_vector_[i] tells us how to compute dimension i 2201 // of the input index from the index vector. See 2202 // PropagateIndexVectorToInputIndex. 2203 std::vector<int64> input_dim_value_to_index_vector_; 2204 2205 // update_dim_is_scatter_dims_[i] is true iff the update index i is a 2206 // scatter dimension. 2207 std::vector<bool> update_dim_is_scatter_dims_; 2208 2209 // The buffer into which we construct an index into scatter_indices_ to 2210 // fetch the index vector. 2211 std::vector<int64> index_vector_index_; 2212 2213 // The index vector fetched from scatter_indices_. 2214 std::vector<int64> index_vector_; 2215 2216 // The result computed by this functor. operator() returns a Span 2217 // into this vector. 2218 std::vector<int64> input_index_; 2219 2220 const ScatterDimensionNumbers& dim_numbers_; 2221 const Literal& scatter_indices_; 2222 }; 2223 2224 // This functor computes the contribution of the window indices in an update 2225 // index to an input index. That is, given an update index I it picks out the 2226 // update window indices in I and expands it into a window index into the 2227 // input shape. 2228 // 2229 // This is similar to the class HloEvaluator::OutputWindowIndexToInputIndex 2230 // that does the corresponding function for Gather. 2231 class UpdateWindowIndexToInputIndex { 2232 public: 2233 // The constructor does some setup work that is amortized across all 2234 // iterations. 2235 explicit UpdateWindowIndexToInputIndex( 2236 const ScatterDimensionNumbers& dim_numbers, const Shape& input_shape, 2237 const Shape& updates_shape) { 2238 std::vector<int64> window_index_to_update_index; 2239 int64 update_index_count = 0; 2240 for (int64 i = 0; i < updates_shape.dimensions_size(); i++) { 2241 if (absl::c_binary_search(dim_numbers.update_window_dims(), i)) { 2242 window_index_to_update_index.push_back(update_index_count++); 2243 } else { 2244 update_index_count++; 2245 } 2246 } 2247 2248 int64 window_dim_count = 0; 2249 for (int64 i = 0; i < input_shape.dimensions_size(); i++) { 2250 if (absl::c_binary_search(dim_numbers.inserted_window_dims(), i)) { 2251 input_dim_value_to_update_index_.push_back(-1); 2252 } else { 2253 input_dim_value_to_update_index_.push_back( 2254 window_index_to_update_index[window_dim_count++]); 2255 } 2256 } 2257 2258 input_index_.resize(input_shape.dimensions_size()); 2259 } 2260 2261 // Returns the contribution of the window indices to the input index 2262 // corresponding to update_index. See scatter_inner_loop_body. 2263 // 2264 // This is conceptually a stateless transformation from update_index to the 2265 // window input index, but instead of allocating memory to represent the 2266 // scatter input index on every invocation we reuse the same storage for the 2267 // result (input_index_), mutating it in place. 2268 // 2269 // This returns a Span into memory owned by the class. 
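  // Hypothetical example (dimension numbers invented for this comment): with a
  // rank-2 operand, rank-2 updates, update_window_dims = {1} and
  // inserted_window_dims = {0}, input_dim_value_to_update_index_ is {-1, 1},
  // so an update index {i, j} contributes {0, j} to the input index.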
2270 StatusOr<absl::Span<const int64>> operator()( 2271 absl::Span<const int64> update_index) { 2272 PropagateUpdateIndexWindowDimsToInputIndex(update_index); 2273 return absl::Span<const int64>(input_index_); 2274 } 2275 2276 // Returns for a given 'input_dim' the corresponding update dimension index, 2277 // or -1 if 'input_dim' is an elided window dimension. 2278 int64 input_dim_value_to_update_index(int64 input_dim) { 2279 return input_dim_value_to_update_index_[input_dim]; 2280 } 2281 2282 private: 2283 // Propagates window dimensions from the update index to input_index_ by 2284 // mutating input_index_ in place. 2285 void PropagateUpdateIndexWindowDimsToInputIndex( 2286 absl::Span<const int64> update_index) { 2287 for (int64 i = 0, e = input_index_.size(); i < e; i++) { 2288 if (input_dim_value_to_update_index_[i] != -1) { 2289 input_index_[i] = update_index[input_dim_value_to_update_index_[i]]; 2290 } 2291 2292 // If input_dim_value_to_update_index_[i] == -1 then input_index_[i] 2293 // remains 0, as set by the constructor. 2294 } 2295 } 2296 2297 // input_dim_value_to_update_index_[i] tells us how to compute dimension i 2298 // of the input index from the update index. See 2299 // PropagateUpdateIndexWindowDimsToInputIndex. 2300 std::vector<int64> input_dim_value_to_update_index_; 2301 2302 // The result computed by this functor. operator() returns a Span 2303 // into this vector. 2304 std::vector<int64> input_index_; 2305 }; 2306 2307 Status HandleScatter(HloInstruction* scatter) override { 2308 const ScatterDimensionNumbers& dim_numbers = 2309 scatter->scatter_dimension_numbers(); 2310 const Literal& operand = 2311 parent_->GetEvaluatedLiteralFor(scatter->operand(0)); 2312 Literal reshaped_scatter_indices; 2313 TF_ASSIGN_OR_RETURN(const Literal& scatter_indices, 2314 ReshapedScatterIndices(dim_numbers.index_vector_dim(), 2315 parent_->GetEvaluatedLiteralFor( 2316 scatter->operand(1)), 2317 &reshaped_scatter_indices)); 2318 const Literal& updates = 2319 parent_->GetEvaluatedLiteralFor(scatter->operand(2)); 2320 const Shape& updates_shape = updates.shape(); 2321 const Shape& operand_shape = operand.shape(); 2322 2323 ShapeUtil::IndexIterationSpace scatter_indices_iteration_space = 2324 IterationSpaceForUpdateScatterIndices(updates_shape, dim_numbers); 2325 ShapeUtil::IndexIterationSpace window_indices_iteration_space = 2326 IterationSpaceForUpdateWindowIndices(updates_shape, dim_numbers); 2327 2328 std::vector<int64> input_index(operand_shape.dimensions_size()); 2329 std::vector<int64> update_index(updates_shape.dimensions_size()); 2330 std::vector<int64> input_scatter_index_clamped( 2331 operand_shape.dimensions_size()); 2332 2333 UpdateScatterIndexToInputIndex update_scatter_index_to_input_index( 2334 &scatter->scatter_dimension_numbers(), /*input_shape=*/operand_shape, 2335 updates_shape, &scatter_indices); 2336 UpdateWindowIndexToInputIndex update_window_index_to_input_index( 2337 scatter->scatter_dimension_numbers(), /*input_shape=*/operand_shape, 2338 updates_shape); 2339 2340 // Initialize the result with the operand. This makes it easier to handle 2341 // the updates even when the indices are repeated. 
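  // As a concrete (hypothetical) shape of the loops below: for a row-scatter
  // in which each row of a [2,3] `updates` literal is written into the row of
  // a [5,3] operand named by the corresponding scatter index, the outer loop
  // walks the two scatter indices, the inner loop walks the three elements of
  // each update row, and to_apply combines the existing operand value with the
  // update value, which is what keeps repeated indices well defined.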
2342 Literal result = operand.Clone(); 2343 HloEvaluator embedded_evaluator; 2344 auto scatter_inner_loop_body = 2345 [&](absl::Span<const int64> update_window_index, 2346 absl::Span<const int64> input_scatter_index, 2347 absl::Span<const int64> update_scatter_index) -> StatusOr<bool> { 2348 TF_ASSIGN_OR_RETURN( 2349 absl::Span<const int64> input_window_index, 2350 update_window_index_to_input_index(update_window_index)); 2351 for (int i = 0, e = update_index.size(); i < e; i++) { 2352 update_index[i] = update_scatter_index[i] + update_window_index[i]; 2353 DCHECK_LT(update_index[i], updates_shape.dimensions(i)); 2354 } 2355 for (int i = 0, e = input_scatter_index.size(); i < e; i++) { 2356 int64 update_dim = 2357 update_window_index_to_input_index.input_dim_value_to_update_index( 2358 i); 2359 // If 'update_dim' is -1, it means 'i' is an elided window dim. This 2360 // means we set the iteration index to 0, so for the purpose of the 2361 // following calculations we can consider the update dimension size to 2362 // be 1. 2363 int64 update_dim_size = 2364 update_dim == -1 ? 1 : updates_shape.dimensions(update_dim); 2365 // If any part of the update region is out-of-bounds, then do not 2366 // perform any update on the input. 2367 if ((input_scatter_index[i] < 0) || 2368 (input_scatter_index[i] > 2369 operand_shape.dimensions(i) - update_dim_size)) { 2370 return true; 2371 } 2372 } 2373 for (int i = 0, e = input_index.size(); i < e; i++) { 2374 input_index[i] = input_scatter_index[i] + input_window_index[i]; 2375 } 2376 2377 auto result_value_literal = 2378 LiteralUtil::CreateR0<ReturnT>(result.Get<ReturnT>(input_index)); 2379 auto update_value_literal = 2380 LiteralUtil::CreateR0<ReturnT>(updates.Get<ReturnT>(update_index)); 2381 Literal updated_result = 2382 embedded_evaluator 2383 .Evaluate(*scatter->to_apply(), 2384 {&result_value_literal, &update_value_literal}) 2385 .ConsumeValueOrDie(); 2386 // Clear visit states so that we can use the evaluator again on the 2387 // same computation. 
2388 embedded_evaluator.ResetVisitStates(); 2389 result.Set<ReturnT>(input_index, updated_result.Get<ReturnT>({})); 2390 return true; 2391 }; 2392 2393 auto scatter_outer_loop_body = 2394 [&](absl::Span<const int64> update_scatter_index) -> StatusOr<bool> { 2395 TF_ASSIGN_OR_RETURN( 2396 absl::Span<const int64> input_scatter_index, 2397 update_scatter_index_to_input_index(update_scatter_index)); 2398 TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( 2399 updates_shape, window_indices_iteration_space, 2400 [&](absl::Span<const int64> update_window_index) { 2401 return scatter_inner_loop_body( 2402 update_window_index, input_scatter_index, update_scatter_index); 2403 })); 2404 return true; 2405 }; 2406 2407 TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( 2408 updates_shape, scatter_indices_iteration_space, 2409 scatter_outer_loop_body)); 2410 parent_->evaluated_[scatter] = std::move(result); 2411 return Status::OK(); 2412 } 2413 2414 Status HandleSlice(HloInstruction* slice) override { 2415 auto operand = slice->operand(0); 2416 const Shape& shape = slice->shape(); 2417 TF_ASSIGN_OR_RETURN(auto inferred_return_shape, 2418 ShapeInference::InferSliceShape( 2419 operand->shape(), slice->slice_starts(), 2420 slice->slice_limits(), slice->slice_strides())); 2421 TF_RET_CHECK(ShapeUtil::Compatible(shape, inferred_return_shape)) 2422 << "return shape set to: " << ShapeUtil::HumanString(shape) 2423 << " but is inferred to be: " 2424 << ShapeUtil::HumanString(inferred_return_shape); 2425 2426 const int64 rank = operand->shape().rank(); 2427 const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); 2428 auto func = [&](absl::Span<const int64> out_index) { 2429 DimensionVector operand_index(rank); 2430 for (int64 i = 0; i < rank; ++i) { 2431 operand_index[i] = 2432 slice->slice_starts(i) + out_index[i] * slice->slice_strides(i); 2433 } 2434 return operand_literal.Get<ReturnT>(operand_index); 2435 }; 2436 2437 Literal result(shape); 2438 TF_RETURN_IF_ERROR(result.Populate<ReturnT>(func)); 2439 parent_->evaluated_[slice] = std::move(result); 2440 return Status::OK(); 2441 } 2442 2443 // Enable CLZ only for int32, uint32, int64 and uint64. 
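  // The overloads below derive count-leading-zeros from Log2Floor: for a
  // non-zero 32-bit value x, clz(x) == 31 - floor(log2(x)). For example,
  // x == 256 has floor(log2(x)) == 8, so clz(x) == 23. (Log2Floor returns -1
  // for 0, which makes clz(0) come out as 32; the 64-bit overloads use 63 and
  // Log2Floor64 in the same way.)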
2444 template < 2445 typename NativeT, 2446 typename std::enable_if< 2447 (std::is_floating_point<NativeT>::value || 2448 std::is_integral<NativeT>::value || is_complex_t<NativeT>::value) && 2449 !(std::is_same<NativeT, uint32>::value || 2450 std::is_same<NativeT, int32>::value || 2451 std::is_same<NativeT, int64>::value || 2452 std::is_same<NativeT, uint64>::value)>::type* = nullptr> 2453 Status HandleClz(HloInstruction* clz) { 2454 return UnsupportedTypeError(clz); 2455 } 2456 2457 template <typename NativeT, 2458 typename std::enable_if< 2459 std::is_same<NativeT, uint32>::value || 2460 std::is_same<NativeT, int32>::value>::type* = nullptr> 2461 Status HandleClz(HloInstruction* clz) { 2462 TF_ASSIGN_OR_RETURN(parent_->evaluated_[clz], 2463 ElementWiseUnaryOp(clz, [](ElementwiseT elem_operand) { 2464 return 31 - tensorflow::Log2Floor(elem_operand); 2465 })); 2466 return Status::OK(); 2467 } 2468 2469 template <typename NativeT, 2470 typename std::enable_if< 2471 std::is_same<NativeT, uint64>::value || 2472 std::is_same<NativeT, int64>::value>::type* = nullptr> 2473 Status HandleClz(HloInstruction* clz) { 2474 TF_ASSIGN_OR_RETURN(parent_->evaluated_[clz], 2475 ElementWiseUnaryOp(clz, [](ElementwiseT elem_operand) { 2476 return 63 - tensorflow::Log2Floor64(elem_operand); 2477 })); 2478 return Status::OK(); 2479 } 2480 2481 Status HandleClz(HloInstruction* clz) override { 2482 return HandleClz<ElementwiseT>(clz); 2483 } 2484 2485 template <typename NativeT, typename std::enable_if<std::is_floating_point< 2486 NativeT>::value>::type* = nullptr> 2487 Status HandleSin(HloInstruction* sin) { 2488 TF_ASSIGN_OR_RETURN(parent_->evaluated_[sin], 2489 ElementWiseUnaryOp(sin, [](ElementwiseT elem_operand) { 2490 return std::sin(elem_operand); 2491 })); 2492 return Status::OK(); 2493 } 2494 2495 template < 2496 typename NativeT, 2497 typename std::enable_if<std::is_integral<NativeT>::value || 2498 is_complex_t<NativeT>::value>::type* = nullptr> 2499 Status HandleSin(HloInstruction* sin) { 2500 return UnsupportedTypeError(sin); 2501 } 2502 2503 Status HandleSin(HloInstruction* sin) override { 2504 return HandleSin<ElementwiseT>(sin); 2505 } 2506 2507 template <typename NativeT, typename std::enable_if<std::is_floating_point< 2508 NativeT>::value>::type* = nullptr> 2509 Status HandleCos(HloInstruction* cos) { 2510 TF_ASSIGN_OR_RETURN(parent_->evaluated_[cos], 2511 ElementWiseUnaryOp(cos, [](ElementwiseT elem_operand) { 2512 return std::cos(elem_operand); 2513 })); 2514 return Status::OK(); 2515 } 2516 2517 template < 2518 typename NativeT, 2519 typename std::enable_if<std::is_integral<NativeT>::value || 2520 is_complex_t<NativeT>::value>::type* = nullptr> 2521 Status HandleCos(HloInstruction* cos) { 2522 return UnsupportedTypeError(cos); 2523 } 2524 2525 Status HandleCos(HloInstruction* cos) override { 2526 return HandleCos<ElementwiseT>(cos); 2527 } 2528 2529 template <typename NativeT, typename std::enable_if<std::is_same< 2530 float, NativeT>::value>::type* = nullptr> 2531 Status HandleReducePrecision(HloInstruction* reduce_precision) { 2532 TF_ASSIGN_OR_RETURN( 2533 parent_->evaluated_[reduce_precision], 2534 ElementWiseUnaryOp(reduce_precision, [reduce_precision]( 2535 ElementwiseT elem) { 2536 uint32_t value_as_int = absl::bit_cast<uint32_t>(elem); 2537 const uint32_t mantissa_bits = reduce_precision->mantissa_bits(); 2538 const uint32_t exponent_bits = reduce_precision->exponent_bits(); 2539 2540 // Code is based on the CPU/GPU implementation in LLVM-emitting code. 
2541 // 2542 // Bits in float type: 2543 // mantissa : bits [0:22] 2544 // exponent : bits [23:30] 2545 // sign : bits [31] 2546 if (mantissa_bits < 23) { 2547 const uint32_t last_mantissa_bit_mask = 1u << (23 - mantissa_bits); 2548 2549 // Compute rounding bias for round-to-nearest with ties to even. 2550 // This is equal to a base value of 0111... plus one bit if the last 2551 // remaining mantissa bit is 1. 2552 const uint32_t base_rounding_bias = 2553 (last_mantissa_bit_mask >> 1) - 1; 2554 const uint32_t x_last_mantissa_bit = 2555 (value_as_int & last_mantissa_bit_mask) >> (23 - mantissa_bits); 2556 const uint32_t x_rounding_bias = 2557 x_last_mantissa_bit + base_rounding_bias; 2558 2559 // Add rounding bias, and mask out truncated bits. Note that the 2560 // case where adding the rounding bias overflows into the exponent 2561 // bits is correct; the non-masked mantissa bits will all be zero, 2562 // and the exponent will be incremented by one. 2563 const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1); 2564 value_as_int = value_as_int + x_rounding_bias; 2565 value_as_int = value_as_int & truncation_mask; 2566 } 2567 if (exponent_bits < 8) { 2568 // Masks for f32 values. 2569 const uint32_t f32_sign_bit_mask = 1u << 31; 2570 const uint32_t f32_exp_bits_mask = 0xffu << 23; 2571 2572 // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the 2573 // most- significant bit -- is equal to 1.0f for all exponent sizes. 2574 // Adding 2^(n-1)-1 to this gives us the highest non-infinite 2575 // exponent for a bit- size of n, and subtracting 2^(n-1)-1 from 2576 // this gives us the lowest' exponent (corresponding to 0.0f). 2577 // 2578 // Thus, the f32 exponent corresponding to the highest non-infinite 2579 // exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32 2580 // exponent corresponding to the lowest exponent for a bit size of n 2581 // is (2^7-1) - 2^(n-1)-1. 2582 // 2583 // Note that we have already checked that exponents_bits >= 1. 2584 const uint32_t f32_exponent_bias = (1 << 7) - 1; 2585 const uint32_t reduced_exponent_bias = 2586 (1 << (exponent_bits - 1)) - 1; 2587 const uint32_t reduced_max_exponent = 2588 f32_exponent_bias + reduced_exponent_bias; 2589 const uint32_t reduced_min_exponent = 2590 f32_exponent_bias - reduced_exponent_bias; 2591 2592 // Do we overflow or underflow? 2593 const uint32_t x_exponent = value_as_int & f32_exp_bits_mask; 2594 const bool x_overflows = x_exponent > (reduced_max_exponent << 23); 2595 const bool x_underflows = 2596 x_exponent <= (reduced_min_exponent << 23); 2597 2598 // Compute appropriately-signed values of zero and infinity. 2599 const uint32_t x_signed_zero = value_as_int & f32_sign_bit_mask; 2600 const uint32_t x_signed_inf = x_signed_zero | f32_exp_bits_mask; 2601 2602 // Force to zero or infinity if overflow or underflow. (Note that 2603 // this truncates all denormal values to zero, rather than rounding 2604 // them.) 2605 value_as_int = x_overflows ? x_signed_inf : value_as_int; 2606 value_as_int = x_underflows ? x_signed_zero : value_as_int; 2607 } 2608 2609 float reduced_result = absl::bit_cast<float>(value_as_int); 2610 if (std::isnan(elem)) { 2611 reduced_result = mantissa_bits > 0 2612 ? 
elem 2613 : std::numeric_limits<float>::infinity(); 2614 } 2615 return reduced_result; 2616 })); 2617 return Status::OK(); 2618 } 2619 2620 template <typename NativeT, typename std::enable_if<std::is_same< 2621 double, NativeT>::value>::type* = nullptr> 2622 Status HandleReducePrecision(HloInstruction* reduce_precision) { 2623 return InvalidArgument("Double is not supported for reduce precision"); 2624 } 2625 2626 template < 2627 typename NativeT, 2628 typename std::enable_if<std::is_integral<NativeT>::value || 2629 is_complex_t<NativeT>::value>::type* = nullptr> 2630 Status HandleReducePrecision(HloInstruction* reduce_precision) { 2631 return UnsupportedTypeError(reduce_precision); 2632 } 2633 2634 Status HandleReducePrecision(HloInstruction* reduce_precision) override { 2635 return HandleReducePrecision<ElementwiseT>(reduce_precision); 2636 } 2637 2638 template < 2639 typename NativeT, 2640 typename std::enable_if< 2641 std::is_same<NativeT, bfloat16>::value || 2642 std::is_same<NativeT, Eigen::half>::value || 2643 std::is_integral<NativeT>::value || is_complex_t<NativeT>::value || 2644 std::is_floating_point<NativeT>::value>::type* = nullptr> 2645 Status HandleIota(HloInstruction* instruction) { 2646 auto* iota = Cast<HloIotaInstruction>(instruction); 2647 const int64 iota_size = iota->shape().dimensions(iota->iota_dimension()); 2648 // Avoid using std::vector since std::vector<bool> does not convert to 2649 // absl::Span<bool>. 2650 absl::InlinedVector<NativeT, 1> data(iota_size); 2651 // We don't use std::iota for two reasons: 2652 // 2653 // (1) std::iota does not support bfloat16 and float16. 2654 // 2655 // (2) std::iota saturates for floating point types when the value is not 2656 // representable, but the definition of HLO iota is the value as a 2657 // 64-bit integer cast to the native type. 2658 for (int64 i = 0; i < iota_size; ++i) { 2659 // static_cast is required for Eigen::half (F16). 
2660 data[i] = static_cast<NativeT>(i); 2661 } 2662 auto result = LiteralUtil::CreateR1<NativeT>(data); 2663 2664 if (iota->shape().rank() > 1) { 2665 TF_ASSIGN_OR_RETURN( 2666 parent_->evaluated_[iota], 2667 result.Broadcast(iota->shape(), {iota->iota_dimension()})); 2668 } else { 2669 TF_RET_CHECK(iota->shape().rank() == 1); 2670 parent_->evaluated_[iota] = std::move(result); 2671 } 2672 2673 return Status::OK(); 2674 } 2675 template < 2676 typename NativeT, 2677 typename std::enable_if< 2678 !(std::is_same<NativeT, bfloat16>::value || 2679 std::is_same<NativeT, Eigen::half>::value || 2680 std::is_integral<NativeT>::value || is_complex_t<NativeT>::value || 2681 std::is_floating_point<NativeT>::value)>::type* = nullptr> 2682 Status HandleIota(HloInstruction* iota) { 2683 return UnsupportedTypeError(iota); 2684 } 2685 Status HandleIota(HloInstruction* iota) override { 2686 return HandleIota<ReturnT>(iota); 2687 } 2688 2689 template <typename NativeT, 2690 typename std::enable_if< 2691 !(std::is_integral<NativeT>::value || 2692 std::is_floating_point<NativeT>::value)>::type* = nullptr> 2693 Status HandleRng(HloInstruction* random) { 2694 return UnsupportedTypeError(random); 2695 } 2696 template <typename NativeT, 2697 typename std::enable_if< 2698 (std::is_floating_point<NativeT>::value)>::type* = nullptr> 2699 Status HandleRng(HloInstruction* random) { 2700 RandomDistribution distribution = random->random_distribution(); 2701 const auto result_shape = random->shape(); 2702 Literal result(result_shape); 2703 2704 switch (distribution) { 2705 case RNG_UNIFORM: { 2706 const Literal& low = 2707 parent_->GetEvaluatedLiteralFor(random->operand(0)); 2708 const Literal& high = 2709 parent_->GetEvaluatedLiteralFor(random->operand(1)); 2710 2711 // std::uniform_real_distribution(a, b) can sometimes return a value 2712 // equal to b. Unclear if this is a spec bug or an implementation bug 2713 // or WAI [0] [1] [2]. Anyway for our purposes we want a half-open 2714 // interval, so we have to re-sample if we get `b` out. 
2715 // 2716 // [0] https://gcc.gnu.org/bugzilla/show_bug.cgi?id=63176 2717 // [1] https://bugs.llvm.org/show_bug.cgi?id=18767 2718 // [2] http://open-std.org/JTC1/SC22/WG21/docs/lwg-active.html#2524 2719 auto low_val = low.Get<NativeT>({}); 2720 auto high_val = high.Get<NativeT>({}); 2721 std::uniform_real_distribution<NativeT> generator(low_val, high_val); 2722 TF_RETURN_IF_ERROR( 2723 result.Populate<NativeT>([&](absl::Span<const int64> /*indexes*/) { 2724 while (true) { 2725 NativeT v = generator(parent_->engine_); 2726 if (v != high_val) { 2727 return v; 2728 } 2729 } 2730 })); 2731 break; 2732 } 2733 case RNG_NORMAL: { 2734 const Literal& mean = 2735 parent_->GetEvaluatedLiteralFor(random->operand(0)); 2736 const Literal& stddev = 2737 parent_->GetEvaluatedLiteralFor(random->operand(1)); 2738 2739 std::normal_distribution<NativeT> generator(mean.Get<NativeT>({}), 2740 stddev.Get<NativeT>({})); 2741 2742 TF_RETURN_IF_ERROR( 2743 result.Populate<NativeT>([&](absl::Span<const int64> /*indexes*/) { 2744 return generator(parent_->engine_); 2745 })); 2746 break; 2747 } 2748 default: 2749 return UnimplementedStrCat("The distribution ", 2750 RandomDistribution_Name(distribution), 2751 " is not implemented."); 2752 } 2753 parent_->evaluated_[random] = std::move(result); 2754 return Status::OK(); 2755 } 2756 template <typename NativeT, 2757 typename std::enable_if<(std::is_integral<NativeT>::value)>::type* = 2758 nullptr> 2759 Status HandleRng(HloInstruction* random) { 2760 RandomDistribution distribution = random->random_distribution(); 2761 const auto result_shape = random->shape(); 2762 Literal result(result_shape); 2763 2764 switch (distribution) { 2765 case RNG_UNIFORM: { 2766 const Literal& low = 2767 parent_->GetEvaluatedLiteralFor(random->operand(0)); 2768 const Literal& high = 2769 parent_->GetEvaluatedLiteralFor(random->operand(1)); 2770 2771 // Note std::uniform_int_distribution assumes interval is closed, i.e., 2772 // [low, high], but we want [low, high) instead. Hence high-1 is used as 2773 // the upper range. 2774 std::uniform_int_distribution<int64> generator( 2775 low.Get<NativeT>({}), high.Get<NativeT>({}) - 1); 2776 2777 TF_RETURN_IF_ERROR( 2778 result.Populate<NativeT>([&](absl::Span<const int64> /*indexes*/) { 2779 return static_cast<NativeT>(generator(parent_->engine_)); 2780 })); 2781 break; 2782 } 2783 case RNG_NORMAL: { 2784 return Unimplemented( 2785 "Normal distribution is not supported for integral types."); 2786 } 2787 default: 2788 return UnimplementedStrCat("The distribution ", 2789 RandomDistribution_Name(distribution), 2790 " is not implemented."); 2791 } 2792 parent_->evaluated_[random] = std::move(result); 2793 return Status::OK(); 2794 } 2795 Status HandleRng(HloInstruction* random) override { 2796 return HandleRng<ReturnT>(random); 2797 } 2798 2799 private: 2800 // Creates a vector of multipliers which can be used to create a linear index 2801 // into shape. 2802 // 2803 // Given the multidimensional index {i1, ..., iN} and 2804 // M = MakeDimMultipliers(shape), the corresponding linear index LI is simply 2805 // 2806 // LI = i1 * M[1] + i2 * M[2] + ... + iN * M[N]. 2807 // 2808 // This lets you calculate LI given the multidimensional indices in any order. 
2809 static DimensionVector MakeDimMultipliers(const Shape& shape) { 2810 DimensionVector v(shape.rank()); 2811 int64 scale = 1; 2812 for (auto dim : LayoutUtil::MinorToMajor(shape)) { 2813 v[dim] = scale; 2814 scale *= shape.dimensions(dim); 2815 } 2816 return v; 2817 } 2818 2819 // For one particular placement of a window in a base shape (the placement is 2820 // represented as `window_count_index`), iterates inside the window. 2821 // Translates the window index into base index. If the base index is within 2822 // bound, call `f` with the base index. 2823 static void IterateThroughWindow( 2824 const Shape& window_shape, const Window& window, const Shape& base_shape, 2825 const absl::Span<const int64>& window_count_index, 2826 const std::function<void(const std::vector<int64>&)>& f) { 2827 const int64 rank = base_shape.rank(); 2828 DimensionVector window_index(rank); 2829 std::fill(window_index.begin(), window_index.end(), 0); 2830 do { 2831 std::vector<int64> base_index(rank); 2832 bool out_of_bound = false; 2833 for (int64 i = 0; i < rank; ++i) { 2834 base_index[i] = 2835 window_count_index[i] * window.dimensions(i).stride() + 2836 window_index[i] * window.dimensions(i).window_dilation() - 2837 window.dimensions(i).padding_low(); 2838 // We are not in the base area if the dilation placed us out of bounds. 2839 if (base_index[i] % window.dimensions(i).base_dilation() != 0) { 2840 out_of_bound = true; 2841 break; 2842 } 2843 // Apply the dilation to the base area. 2844 base_index[i] /= window.dimensions(i).base_dilation(); 2845 if (base_index[i] < 0 || base_index[i] >= base_shape.dimensions(i)) { 2846 out_of_bound = true; 2847 break; 2848 } 2849 } 2850 if (!out_of_bound) { 2851 f(base_index); 2852 } 2853 } while ( 2854 IndexUtil::BumpIndices(window_shape, absl::MakeSpan(window_index))); 2855 } 2856 2857 template <typename IndexT> 2858 StatusOr<Literal> DynamicSlice( 2859 const Literal& operand_literal, 2860 absl::Span<HloInstruction* const> start_indices, 2861 const Shape& result_shape) { 2862 std::vector<int64> start; 2863 2864 for (HloInstruction* index : start_indices) { 2865 start.push_back( 2866 parent_->GetEvaluatedLiteralFor(index).GetFirstElement<IndexT>()); 2867 } 2868 2869 // Clamp the start indices so the slice is in-bounds w.r.t the operand. 
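  // For example (sizes invented for this comment): slicing 3 elements from a
  // dimension of size 10 with a requested start of 9 clamps the start to
  // min(max(0, 9), 10 - 3) = 7, i.e. an out-of-bounds start is shifted back
  // so the slice lies entirely inside the operand.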
2870 for (int64 i = 0; i < start.size(); ++i) { 2871 start[i] = std::min<int64>( 2872 std::max(int64{0}, start[i]), 2873 operand_literal.shape().dimensions(i) - result_shape.dimensions(i)); 2874 } 2875 2876 std::vector<int64> operand_indices(start.size()); 2877 Literal result(result_shape); 2878 TF_RETURN_IF_ERROR( 2879 result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) { 2880 for (int64 i = 0; i < operand_indices.size(); ++i) { 2881 CHECK_GE(multi_index[i] + start[i], 0); 2882 operand_indices[i] = multi_index[i] + start[i]; 2883 } 2884 2885 auto result = operand_literal.Get<ReturnT>(operand_indices); 2886 return result; 2887 })); 2888 2889 return std::move(result); 2890 } 2891 2892 template <typename IndexT> 2893 StatusOr<Literal> DynamicUpdateSlice( 2894 const Literal& operand_literal, const Literal& update_literal, 2895 absl::Span<HloInstruction* const> start_indices) { 2896 auto result = operand_literal.Clone(); 2897 const auto rank = result.shape().rank(); 2898 std::vector<int64> start; 2899 for (HloInstruction* index : start_indices) { 2900 start.push_back( 2901 parent_->GetEvaluatedLiteralFor(index).GetFirstElement<IndexT>()); 2902 } 2903 2904 // Clamp the update start indices so the slice is in-bounds w.r.t the 2905 // operand. 2906 for (int64 i = 0; i < rank; ++i) { 2907 start[i] = std::min<int64>( 2908 std::max<int64>(0, start[i]), 2909 result.shape().dimensions(i) - update_literal.shape().dimensions(i)); 2910 } 2911 std::vector<int64> result_index(rank, 0); 2912 2913 auto func = [&](absl::Span<const int64> update_index) { 2914 std::transform(update_index.begin(), update_index.end(), start.begin(), 2915 result_index.begin(), std::plus<int64>()); 2916 result.Set<ReturnT>(result_index, 2917 update_literal.Get<ReturnT>(update_index)); 2918 return true; 2919 }; 2920 2921 std::vector<int64> base(update_literal.shape().dimensions_size(), 0); 2922 std::vector<int64> step(update_literal.shape().dimensions_size(), 1); 2923 ShapeUtil::ForEachIndex(update_literal.shape(), base, 2924 AsInt64Slice(update_literal.shape().dimensions()), 2925 step, func); 2926 2927 return std::move(result); 2928 } 2929 2930 StatusOr<Literal> ElementWiseUnaryOp( 2931 HloInstruction* instruction, 2932 const std::function<ElementwiseT(ElementwiseT)>& unary_op) { 2933 const Literal& operand_literal = 2934 parent_->GetEvaluatedLiteralFor(instruction->operand(0)); 2935 TF_ASSIGN_OR_RETURN( 2936 auto result_literal, 2937 (HloEvaluator::ElementWiseUnaryOpImpl<ReturnT, ReturnT>( 2938 instruction, ConvertUnaryFunction(unary_op), operand_literal))); 2939 2940 return std::move(result_literal); 2941 } 2942 2943 StatusOr<Literal> ElementWiseBinaryOp( 2944 HloInstruction* instruction, 2945 const std::function<ElementwiseT(ElementwiseT, ElementwiseT)>& 2946 binary_op) { 2947 const auto shape = instruction->shape(); 2948 const auto* lhs = instruction->operand(0); 2949 const auto* rhs = instruction->operand(1); 2950 TF_RET_CHECK(ShapeUtil::SameDimensions(shape, rhs->shape())); 2951 TF_RET_CHECK(ShapeUtil::SameDimensions(lhs->shape(), rhs->shape())); 2952 2953 const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); 2954 const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); 2955 2956 Literal result(shape); 2957 2958 TF_RETURN_IF_ERROR( 2959 result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) { 2960 return ConvertBinaryFunction(binary_op)( 2961 lhs_literal.Get<ReturnT>(multi_index), 2962 rhs_literal.Get<ReturnT>(multi_index)); 2963 })); 2964 return std::move(result); 2965 } 2966 
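  // Note on the ternary helper below: the separate LhsType/RhsType/EhsType
  // parameters exist because a ternary HLO may mix element types; a
  // select-style use, for example, reads its predicate operand as bool while
  // the value operands use ReturnT. (Illustrative note, not an exhaustive
  // list of callers.)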
2967 template <typename LhsType, typename RhsType, typename EhsType> 2968 StatusOr<Literal> ElementwiseTernaryOp( 2969 HloInstruction* instruction, 2970 const std::function<ReturnT(LhsType, RhsType, EhsType)>& ternary_op) { 2971 const auto shape = instruction->shape(); 2972 const auto* lhs = instruction->operand(0); 2973 const auto* rhs = instruction->operand(1); 2974 const auto* ehs = instruction->operand(2); 2975 TF_RET_CHECK(ShapeUtil::SameDimensions(shape, lhs->shape())); 2976 TF_RET_CHECK(ShapeUtil::SameDimensions(lhs->shape(), rhs->shape())); 2977 TF_RET_CHECK(ShapeUtil::SameDimensions(rhs->shape(), ehs->shape())); 2978 2979 const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); 2980 const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); 2981 const Literal& ehs_literal = parent_->GetEvaluatedLiteralFor(ehs); 2982 2983 Literal result(shape); 2984 2985 TF_RETURN_IF_ERROR( 2986 result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) { 2987 return ternary_op(lhs_literal.Get<LhsType>(multi_index), 2988 rhs_literal.Get<RhsType>(multi_index), 2989 ehs_literal.Get<EhsType>(multi_index)); 2990 })); 2991 2992 return std::move(result); 2993 } 2994 2995 template <typename NativeT> 2996 static bool IsShiftOutOfBounds(NativeT rhs) { 2997 typedef typename std::make_unsigned<NativeT>::type UnsignedT; 2998 UnsignedT lhs_size_unsigned = sizeof(NativeT) * CHAR_BIT; 2999 UnsignedT rhs_unsigned = static_cast<UnsignedT>(rhs); 3000 return rhs_unsigned >= lhs_size_unsigned; 3001 } 3002 3003 HloEvaluator* parent_; 3004 }; 3005 3006 // These extern templates prevent users of this class from implicitly 3007 // instantiating it. We explicitly instantiate this class in the various 3008 // hlo_evaluator_typed_visitor*.cc files. 3009 extern template class HloEvaluatorTypedVisitor<bool>; 3010 extern template class HloEvaluatorTypedVisitor<uint8>; 3011 extern template class HloEvaluatorTypedVisitor<uint32>; 3012 extern template class HloEvaluatorTypedVisitor<uint64>; 3013 extern template class HloEvaluatorTypedVisitor<int8>; 3014 extern template class HloEvaluatorTypedVisitor<int32>; 3015 extern template class HloEvaluatorTypedVisitor<int64>; 3016 extern template class HloEvaluatorTypedVisitor<Eigen::half, float>; 3017 extern template class HloEvaluatorTypedVisitor<float>; 3018 extern template class HloEvaluatorTypedVisitor<double>; 3019 extern template class HloEvaluatorTypedVisitor<complex64>; 3020 extern template class HloEvaluatorTypedVisitor<complex128>; 3021 extern template class HloEvaluatorTypedVisitor<bfloat16, float>; 3022 3023 } // namespace xla 3024 3025 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_ 3026